🌐 AI搜索 & 代理 主页
Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: add IntEnum for sqltypes
  • Loading branch information
KunxiSun committed Apr 3, 2025
commit 07bf10cc76f0f9630c010a758fbc2a96abfb3efc
1 change: 1 addition & 0 deletions sqlmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,4 @@
from .sql.expression import type_coerce as type_coerce
from .sql.expression import within_group as within_group
from .sql.sqltypes import AutoString as AutoString
from .sql.sqltypes import IntEnum as IntEnum
57 changes: 56 additions & 1 deletion sqlmodel/sql/sqltypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, cast
from typing import Any, cast, Optional
from enum import IntEnum as _IntEnum

from sqlalchemy import types
from sqlalchemy.engine.interfaces import Dialect
Expand All @@ -14,3 +15,57 @@ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
if impl.length is None and dialect.name == "mysql":
return dialect.type_descriptor(types.String(self.mysql_default_length))
return super().load_dialect_impl(dialect)

class IntEnum(types.TypeDecorator): # type: ignore
"""TypeDecorator for Integer-enum conversion.

Automatically converts Python enum.IntEnum <-> database integers.

Args:
enum_type (enum.IntEnum): Integer enum class (subclass of enum.IntEnum)

Example:
>>> class HeroStatus(enum.IntEnum):
... ACTIVE = 1
... DISABLE = 2
>>>>
>>> from sqlmodel import IntEnum
>>> class Hero(SQLModel):
... hero_status: HeroStatus = Field(sa_type=sqlmodel.IntEnum(HeroStatus))
>>> user.hero_status == Status.ACTIVE # Loads back as enum

Returns:
Optional[enum.IntEnum]: Converted enum instance (None if database value is NULL)

Raises:
TypeError: For invalid enum types
"""

impl = types.Integer

def __init__(self, enum_type: _IntEnum, *args, **kwargs):
super().__init__(*args, **kwargs)

# validate the input enum type
if not issubclass(enum_type, _IntEnum):
raise TypeError(
f"Input must be enum.IntEnum"
)

self.enum_type = enum_type

def process_result_value(self, value: Optional[int], dialect) -> Optional[_IntEnum]:

if value is None:
return None

result = self.enum_type(value)
return result

def process_bind_param(self, value: Optional[_IntEnum], dialect) -> Optional[int]:

if value is None:
return None

result = value.value
return result
30 changes: 24 additions & 6 deletions tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,21 @@ def test_json_schema_flat_model_pydantic_v1():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/definitions/MyEnum1"},
"int_enum_field": {"$ref": "#/definitions/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"definitions": {
"MyEnum1": {
"title": "MyEnum1",
"description": "An enumeration.",
"enum": ["A", "B"],
"type": "string",
},
"MyEnum3": {
"title": "MyEnum3",
"description": "An enumeration.",
"enum": [1, 3],
"type": "int",
}
},
}
Expand All @@ -84,14 +91,21 @@ def test_json_schema_inherit_model_pydantic_v1():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/definitions/MyEnum2"},
"int_enum_field": {"$ref": "#/definitions/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"definitions": {
"MyEnum2": {
"title": "MyEnum2",
"description": "An enumeration.",
"enum": ["C", "D"],
"type": "string",
},
"MyEnum3": {
"title": "MyEnum3",
"description": "An int enumeration.",
"enum": [1, 3],
"type": "int",
}
},
}
Expand All @@ -105,10 +119,12 @@ def test_json_schema_flat_model_pydantic_v2():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/$defs/MyEnum1"},
"int_enum_field": {"$ref": "#/$defs/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"$defs": {
"MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"}
"MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"},
"MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"},
},
}

Expand All @@ -121,9 +137,11 @@ def test_json_schema_inherit_model_pydantic_v2():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/$defs/MyEnum2"},
"int_enum_field": {"$ref": "#/$defs/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"$defs": {
"MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"}
"MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"},
"MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"},
},
}
8 changes: 6 additions & 2 deletions tests/test_enums_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
import uuid

from sqlmodel import Field, SQLModel
from sqlmodel import Field, SQLModel, IntEnum


class MyEnum1(str, enum.Enum):
Expand All @@ -13,15 +13,19 @@ class MyEnum2(str, enum.Enum):
C = "C"
D = "D"

class MyEnum3(enum.IntEnum):
E = 1
F = 2

class BaseModel(SQLModel):
id: uuid.UUID = Field(primary_key=True)
enum_field: MyEnum2

int_enum_field: MyEnum3

class FlatModel(SQLModel, table=True):
id: uuid.UUID = Field(primary_key=True)
enum_field: MyEnum1
int_enum_field: MyEnum3 = Field(sa_type=IntEnum(MyEnum3))


class InheritModel(BaseModel, table=True):
Expand Down
Loading