from __future__ import annotations

import builtins
import ipaddress
import uuid
import weakref
from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Literal,
    TypeAlias,
    TypeVar,
    Union,
    cast,
    get_origin,
    overload,
)

from pydantic import BaseModel, EmailStr
from pydantic.fields import FieldInfo as PydanticFieldInfo
from sqlalchemy import (
    Boolean,
    Column,
    Date,
    DateTime,
    Float,
    ForeignKey,
    Integer,
    Interval,
    Numeric,
    inspect,
)
from sqlalchemy import Enum as sa_Enum
from sqlalchemy.orm import (
    Mapped,
    RelationshipProperty,
    declared_attr,
    registry,
    relationship,
)
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
from typing_extensions import deprecated

from ._compat import (  # type: ignore[attr-defined]
    PYDANTIC_MINOR_VERSION,
    BaseConfig,
    ModelMetaclass,
    Representation,
    SQLModelConfig,
    Undefined,
    UndefinedType,
    finish_init,
    get_annotations,
    get_field_metadata,
    get_model_fields,
    get_relationship_to,
    get_sa_type_from_field,
    init_pydantic_private_attrs,
    is_field_noneable,
    is_table_model_class,
    sqlmodel_init,
    sqlmodel_validate,
)
from .sql.sqltypes import AutoString

if TYPE_CHECKING:
    # For static type checkers only: resolve these names to their real
    # Pydantic definitions instead of the ._compat re-exports imported above.
    from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass
    from pydantic._internal._repr import Representation as Representation
    from pydantic_core import PydanticUndefined as Undefined
    from pydantic_core import PydanticUndefinedType as UndefinedType

_T = TypeVar("_T")

# Zero-argument callable, the type accepted for ``default_factory``.
NoArgAnyCallable = Callable[[], Any]
# Recursive include/exclude specification accepted by ``model_dump``/``dict``:
# a set of keys/indexes, or a mapping whose values are nested specs or bools.
IncEx: TypeAlias = (
    set[int]
    | set[str]
    | Mapping[int, Union["IncEx", bool]]
    | Mapping[str, Union["IncEx", bool]]
)
# Values accepted by ``Field(ondelete=...)``; forwarded to SQLAlchemy's
# ``ForeignKey(ondelete=...)`` when the column is built.
OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"]


def __dataclass_transform__(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_descriptors: tuple[type | Callable[..., Any], ...] = (()),
) -> Callable[[_T], _T]:
    """Return an identity decorator that marks a class as dataclass-like.

    At runtime the keyword arguments are ignored and the decorated object is
    returned unchanged. The name and parameters follow the early
    ``__dataclass_transform__`` spelling from PEP 681 (later
    ``typing.dataclass_transform``), which static type checkers are expected
    to recognise by name. Verify support in the checkers you target.
    """
    return lambda a: a


class FieldInfo(PydanticFieldInfo):  # type: ignore[misc]
    # Pydantic FieldInfo extended with SQLModel's column options. The options
    # are popped from **kwargs before Pydantic's __init__ sees them and are
    # stored as plain attributes on the instance.
    # mypy - ignore that PydanticFieldInfo is @final
    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
        # NOTE(review): primary_key and unique default to False (not Undefined),
        # so the "not supported when also passing a sa_column" checks below
        # fire for them unless the caller passes Undefined explicitly, as
        # Field() in this module does. Confirm before constructing FieldInfo
        # directly with sa_column.
        primary_key = kwargs.pop("primary_key", False)
        nullable = kwargs.pop("nullable", Undefined)
        foreign_key = kwargs.pop("foreign_key", Undefined)
        ondelete = kwargs.pop("ondelete", Undefined)
        unique = kwargs.pop("unique", False)
        index = kwargs.pop("index", Undefined)
        sa_type = kwargs.pop("sa_type", Undefined)
        sa_column = kwargs.pop("sa_column", Undefined)
        sa_column_args = kwargs.pop("sa_column_args", Undefined)
        sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
        # A ready-made sa_column is used as-is, so every option that would
        # otherwise be used to build the column is rejected alongside it.
        if sa_column is not Undefined:
            if sa_column_args is not Undefined:
                raise RuntimeError(
                    "Passing sa_column_args is not supported when "
                    "also passing a sa_column"
                )
            if sa_column_kwargs is not Undefined:
                raise RuntimeError(
                    "Passing sa_column_kwargs is not supported when "
                    "also passing a sa_column"
                )
            if primary_key is not Undefined:
                raise RuntimeError(
                    "Passing primary_key is not supported when also passing a sa_column"
                )
            if nullable is not Undefined:
                raise RuntimeError(
                    "Passing nullable is not supported when also passing a sa_column"
                )
            if foreign_key is not Undefined:
                raise RuntimeError(
                    "Passing foreign_key is not supported when also passing a sa_column"
                )
            if ondelete is not Undefined:
                raise RuntimeError(
                    "Passing ondelete is not supported when also passing a sa_column"
                )
            if unique is not Undefined:
                raise RuntimeError(
                    "Passing unique is not supported when also passing a sa_column"
                )
            if index is not Undefined:
                raise RuntimeError(
                    "Passing index is not supported when also passing a sa_column"
                )
            if sa_type is not Undefined:
                raise RuntimeError(
                    "Passing sa_type is not supported when also passing a sa_column"
                )
        # ondelete only makes sense on a foreign-key column.
        if ondelete is not Undefined:
            if foreign_key is Undefined:
                raise RuntimeError("ondelete can only be used with foreign_key")
        super().__init__(default=default, **kwargs)
        self.primary_key = primary_key
        self.nullable = nullable
        self.foreign_key = foreign_key
        self.ondelete = ondelete
        self.unique = unique
        self.index = index
        self.sa_type = sa_type
        self.sa_column = sa_column
        self.sa_column_args = sa_column_args
        self.sa_column_kwargs = sa_column_kwargs


class RelationshipInfo(Representation):
    # Plain container for the options given to Relationship().
    # SQLModelMetaclass turns it into a SQLAlchemy relationship() for table
    # models.
    def __init__(
        self,
        *,
        back_populates: str | None = None,
        cascade_delete: bool | None = False,
        passive_deletes: bool | Literal["all"] | None = False,
        link_model: Any | None = None,
        sa_relationship: RelationshipProperty | None = None,  # type: ignore
        sa_relationship_args: Sequence[Any] | None = None,
        sa_relationship_kwargs: Mapping[str, Any] | None = None,
    ) -> None:
        # A ready-made sa_relationship cannot be combined with extra
        # relationship() arguments.
        if sa_relationship is not None:
            if sa_relationship_args is not None:
                raise RuntimeError(
                    "Passing sa_relationship_args is not supported when "
                    "also passing a sa_relationship"
                )
            if sa_relationship_kwargs is not None:
                raise RuntimeError(
                    "Passing sa_relationship_kwargs is not supported when "
                    "also passing a sa_relationship"
                )
        self.back_populates = back_populates
        self.cascade_delete = cascade_delete
        self.passive_deletes = passive_deletes
        self.link_model = link_model
        self.sa_relationship = sa_relationship
        self.sa_relationship_args = sa_relationship_args
        self.sa_relationship_kwargs = sa_relationship_kwargs


# SQLModel column options carried in a Pydantic FieldInfo's ``metadata`` list.
# Field() appends one of these, and _get_sqlmodel_field_value reads it before
# falling back to attributes on the FieldInfo itself.
# NOTE(review): presumably kept in ``metadata`` so the options survive when
# Pydantic copies or rebuilds FieldInfo objects. Confirm.
@dataclass
class FieldInfoMetadata:
    primary_key: bool | UndefinedType = Undefined
    nullable: bool | UndefinedType = Undefined
    foreign_key: Any = Undefined
    ondelete: OnDeleteType | UndefinedType = Undefined
    unique: bool | UndefinedType = Undefined
    index: bool | UndefinedType = Undefined
    sa_type: type[Any] | UndefinedType = Undefined
    sa_column: Column[Any] | UndefinedType = Undefined
    sa_column_args: Sequence[Any] | UndefinedType = Undefined
    sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined


def _get_sqlmodel_field_metadata(field_info: Any) -> FieldInfoMetadata | None:
    """Return the first FieldInfoMetadata in ``field_info.metadata``, or None.

    Objects without a ``metadata`` attribute, or with an empty one, yield None.
    """
    metadata_items = getattr(field_info, "metadata", None)
    if metadata_items:
        for meta in metadata_items:
            if isinstance(meta, FieldInfoMetadata):
                return meta
    return None


def _get_sqlmodel_field_value(
    field_info: Any, attribute: str, default: Any = Undefined
) -> Any:
    """Read a SQLModel column option for a field.

    The value stored on the field's FieldInfoMetadata wins. Otherwise the
    attribute is read from ``field_info`` itself, and ``default`` is returned
    when neither has it.
    """
    metadata = _get_sqlmodel_field_metadata(field_info)
    if metadata is not None and hasattr(metadata, attribute):
        return getattr(metadata, attribute)
    return getattr(field_info, attribute, default)


# Overload 1 of Field: column options without sa_column. ondelete is only
# offered by the overload where foreign_key is a str.
# include sa_type, sa_column_args, sa_column_kwargs
@overload
def Field(
    default: Any = Undefined,
    *,
    default_factory: NoArgAnyCallable | None = None,
    alias: str | None = None,
    validation_alias: str | None = None,
    serialization_alias: str | None = None,
    title: str | None = None,
    description: str | None = None,
    exclude: Set[int | str] | Mapping[int | str, Any] | Any = None,
    include: Set[int | str] | Mapping[int | str, Any] | Any = None,
    const: bool | None = None,
    gt: float | None = None,
    ge: float | None = None,
    lt: float | None = None,
    le: float | None = None,
    multiple_of: float | None = None,
    max_digits: int | None = None,
    decimal_places: int | None = None,
    min_items: int | None = None,
    max_items: int | None = None,
    unique_items: bool | None = None,
    min_length: int | None = None,
    max_length: int | None = None,
    allow_mutation: bool = True,
    regex: str | None = None,
    discriminator: str | None = None,
    repr: bool = True,
    primary_key: bool | UndefinedType = Undefined,
    foreign_key: Any = Undefined,
    unique: bool | UndefinedType = Undefined,
    nullable: bool | UndefinedType = Undefined,
    index: bool | UndefinedType = Undefined,
    sa_type: type[Any] | UndefinedType = Undefined,
    sa_column_args: Sequence[Any] | UndefinedType = Undefined,
    sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined,
    schema_extra: dict[str, Any] | None = None,
) -> Any: ...
# When foreign_key is str, include ondelete # include sa_type, sa_column_args, sa_column_kwargs @overload def Field( default: Any = Undefined, *, default_factory: NoArgAnyCallable | None = None, alias: str | None = None, validation_alias: str | None = None, serialization_alias: str | None = None, title: str | None = None, description: str | None = None, exclude: Set[int | str] | Mapping[int | str, Any] | Any = None, include: Set[int | str] | Mapping[int | str, Any] | Any = None, const: bool | None = None, gt: float | None = None, ge: float | None = None, lt: float | None = None, le: float | None = None, multiple_of: float | None = None, max_digits: int | None = None, decimal_places: int | None = None, min_items: int | None = None, max_items: int | None = None, unique_items: bool | None = None, min_length: int | None = None, max_length: int | None = None, allow_mutation: bool = True, regex: str | None = None, discriminator: str | None = None, repr: bool = True, primary_key: bool | UndefinedType = Undefined, foreign_key: str, ondelete: OnDeleteType | UndefinedType = Undefined, unique: bool | UndefinedType = Undefined, nullable: bool | UndefinedType = Undefined, index: bool | UndefinedType = Undefined, sa_type: type[Any] | UndefinedType = Undefined, sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, schema_extra: dict[str, Any] | None = None, ) -> Any: ... 
# Include sa_column, don't include # primary_key # foreign_key # ondelete # unique # nullable # index # sa_type # sa_column_args # sa_column_kwargs @overload def Field( default: Any = Undefined, *, default_factory: NoArgAnyCallable | None = None, alias: str | None = None, validation_alias: str | None = None, serialization_alias: str | None = None, title: str | None = None, description: str | None = None, exclude: Set[int | str] | Mapping[int | str, Any] | Any = None, include: Set[int | str] | Mapping[int | str, Any] | Any = None, const: bool | None = None, gt: float | None = None, ge: float | None = None, lt: float | None = None, le: float | None = None, multiple_of: float | None = None, max_digits: int | None = None, decimal_places: int | None = None, min_items: int | None = None, max_items: int | None = None, unique_items: bool | None = None, min_length: int | None = None, max_length: int | None = None, allow_mutation: bool = True, regex: str | None = None, discriminator: str | None = None, repr: bool = True, sa_column: Column[Any] | UndefinedType = Undefined, schema_extra: dict[str, Any] | None = None, ) -> Any: ... 
def Field(
    default: Any = Undefined,
    *,
    default_factory: NoArgAnyCallable | None = None,
    alias: str | None = None,
    validation_alias: str | None = None,
    serialization_alias: str | None = None,
    title: str | None = None,
    description: str | None = None,
    exclude: Set[int | str] | Mapping[int | str, Any] | Any = None,
    include: Set[int | str] | Mapping[int | str, Any] | Any = None,
    const: bool | None = None,
    gt: float | None = None,
    ge: float | None = None,
    lt: float | None = None,
    le: float | None = None,
    multiple_of: float | None = None,
    max_digits: int | None = None,
    decimal_places: int | None = None,
    min_items: int | None = None,
    max_items: int | None = None,
    unique_items: bool | None = None,
    min_length: int | None = None,
    max_length: int | None = None,
    allow_mutation: bool = True,
    regex: str | None = None,
    discriminator: str | None = None,
    repr: bool = True,
    primary_key: bool | UndefinedType = Undefined,
    foreign_key: Any = Undefined,
    ondelete: OnDeleteType | UndefinedType = Undefined,
    unique: bool | UndefinedType = Undefined,
    nullable: bool | UndefinedType = Undefined,
    index: bool | UndefinedType = Undefined,
    sa_type: type[Any] | UndefinedType = Undefined,
    sa_column: Column | UndefinedType = Undefined,  # type: ignore
    sa_column_args: Sequence[Any] | UndefinedType = Undefined,
    sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined,
    schema_extra: dict[str, Any] | None = None,
) -> Any:
    """Declare a model field with Pydantic options and SQLAlchemy column options.

    Pydantic-level options (alias, title, description, constraints, ...) are
    forwarded to ``FieldInfo``. The SQLModel column options (``primary_key``,
    ``foreign_key``, ``ondelete``, ``unique``, ``nullable``, ``index``,
    ``sa_type``, ``sa_column``, ``sa_column_args``, ``sa_column_kwargs``) are
    stored on the ``FieldInfo``. They are also stored in a
    ``FieldInfoMetadata`` entry appended to ``field_info.metadata``. Any other
    keys in ``schema_extra`` are forwarded to ``FieldInfo`` as extra keyword
    arguments.

    Alias precedence for ``validation_alias`` / ``serialization_alias``:
    explicit argument > value in ``schema_extra`` > ``alias``.

    ``schema_extra`` is never modified; the caller's dict can safely be
    shared between several fields.

    Returns:
        The ``FieldInfo`` instance. It is typed ``Any`` so it can be used as
        the default of any annotated attribute.

    Raises:
        RuntimeError: from ``FieldInfo`` when ``sa_column`` is combined with
            other column options, or when ``ondelete`` is given without
            ``foreign_key``.
    """
    # Work on a shallow copy: the alias keys are popped below, and popping
    # from the caller's dict would strip them from every other field that
    # shares the same schema_extra mapping.
    current_schema_extra = dict(schema_extra or {})
    # Extract possible alias settings from schema_extra so we can control precedence
    schema_validation_alias = current_schema_extra.pop("validation_alias", None)
    schema_serialization_alias = current_schema_extra.pop("serialization_alias", None)
    field_info_kwargs = {
        "alias": alias,
        "title": title,
        "description": description,
        "exclude": exclude,
        "include": include,
        "const": const,
        "gt": gt,
        "ge": ge,
        "lt": lt,
        "le": le,
        "multiple_of": multiple_of,
        "max_digits": max_digits,
        "decimal_places": decimal_places,
        "min_items": min_items,
        "max_items": max_items,
        "unique_items": unique_items,
        "min_length": min_length,
        "max_length": max_length,
        "allow_mutation": allow_mutation,
        "regex": regex,
        "discriminator": discriminator,
        "repr": repr,
        "primary_key": primary_key,
        "foreign_key": foreign_key,
        "ondelete": ondelete,
        "unique": unique,
        "nullable": nullable,
        "index": index,
        "sa_type": sa_type,
        "sa_column": sa_column,
        "sa_column_args": sa_column_args,
        "sa_column_kwargs": sa_column_kwargs,
        # Remaining schema_extra keys come last, so they override the above.
        **current_schema_extra,
    }
    # explicit params > schema_extra > alias propagation
    field_info_kwargs["validation_alias"] = (
        validation_alias or schema_validation_alias or alias
    )
    field_info_kwargs["serialization_alias"] = (
        serialization_alias or schema_serialization_alias or alias
    )
    field_info = FieldInfo(
        default,
        default_factory=default_factory,
        **field_info_kwargs,
    )
    # Second copy of the column options in ``metadata``; this is the copy
    # _get_sqlmodel_field_value consults first.
    field_metadata = FieldInfoMetadata(
        primary_key=primary_key,
        nullable=nullable,
        foreign_key=foreign_key,
        ondelete=ondelete,
        unique=unique,
        index=index,
        sa_type=sa_type,
        sa_column=sa_column,
        sa_column_args=sa_column_args,
        sa_column_kwargs=sa_column_kwargs,
    )
    if hasattr(field_info, "metadata"):
        field_info.metadata.append(field_metadata)
    return field_info


# Overload 1 of Relationship (typing only): options plus extra arguments for
# SQLAlchemy's relationship().
@overload
def Relationship(
    *,
    back_populates: str | None = None,
    cascade_delete: bool | None = False,
    passive_deletes: bool | Literal["all"] | None = False,
    link_model: Any | None = None,
    sa_relationship_args: Sequence[Any] | None = None,
    sa_relationship_kwargs: Mapping[str, Any] | None = None,
) -> Any: ...


# Overload 2 of Relationship (typing only): a ready-made SQLAlchemy
# relationship, which excludes the extra relationship() arguments.
@overload
def Relationship(
    *,
    back_populates: str | None = None,
    cascade_delete: bool | None = False,
    passive_deletes: bool | Literal["all"] | None = False,
    link_model: Any | None = None,
    sa_relationship: RelationshipProperty[Any] | None = None,
) -> Any: ...
def Relationship(
    *,
    back_populates: str | None = None,
    cascade_delete: bool | None = False,
    passive_deletes: bool | Literal["all"] | None = False,
    link_model: Any | None = None,
    sa_relationship: RelationshipProperty[Any] | None = None,
    sa_relationship_args: Sequence[Any] | None = None,
    sa_relationship_kwargs: Mapping[str, Any] | None = None,
) -> Any:
    """Declare a relationship attribute on a SQLModel class.

    Returns a ``RelationshipInfo`` holding the given options.
    ``SQLModelMetaclass.__new__`` removes such values from the class namespace
    (so Pydantic does not see them as fields). For table models,
    ``SQLModelMetaclass.__init__`` replaces them with SQLAlchemy
    ``relationship()`` objects. The return type is ``Any`` so the call
    type-checks as the default of any annotated attribute.

    Raises:
        RuntimeError: from ``RelationshipInfo`` when ``sa_relationship`` is
            combined with ``sa_relationship_args`` or ``sa_relationship_kwargs``.
    """
    relationship_info = RelationshipInfo(
        back_populates=back_populates,
        cascade_delete=cascade_delete,
        passive_deletes=passive_deletes,
        link_model=link_model,
        sa_relationship=sa_relationship,
        sa_relationship_args=sa_relationship_args,
        sa_relationship_kwargs=sa_relationship_kwargs,
    )
    return relationship_info


# Metaclass combining Pydantic's ModelMetaclass with SQLAlchemy's
# DeclarativeMeta. Table models (is_table_model_class) are routed through
# SQLAlchemy for class attribute set/delete and class initialisation; other
# models behave as plain Pydantic models.
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
    __sqlmodel_relationships__: dict[str, RelationshipInfo]
    model_config: SQLModelConfig
    model_fields: ClassVar[dict[str, FieldInfo]]

    # Replicate SQLAlchemy
    def __setattr__(cls, name: str, value: Any) -> None:
        if is_table_model_class(cls):
            DeclarativeMeta.__setattr__(cls, name, value)
        else:
            super().__setattr__(name, value)

    def __delattr__(cls, name: str) -> None:
        if is_table_model_class(cls):
            DeclarativeMeta.__delattr__(cls, name)
        else:
            super().__delattr__(name)

    # From Pydantic
    def __new__(
        cls,
        name: str,
        bases: tuple[type[Any], ...],
        class_dict: dict[str, Any],
        **kwargs: Any,
    ) -> Any:
        # Split Relationship() attributes (and their annotations) out of the
        # namespace so Pydantic never treats them as fields. They are kept
        # in __sqlmodel_relationships__, and their annotations are merged back
        # onto the created class below.
        relationships: dict[str, RelationshipInfo] = {}
        dict_for_pydantic = {}
        original_annotations = get_annotations(class_dict)
        pydantic_annotations = {}
        relationship_annotations = {}
        for k, v in class_dict.items():
            if isinstance(v, RelationshipInfo):
                relationships[k] = v
            else:
                dict_for_pydantic[k] = v
        for k, v in original_annotations.items():
            if k in relationships:
                relationship_annotations[k] = v
            else:
                pydantic_annotations[k] = v
        dict_used = {
            **dict_for_pydantic,
            "__weakref__": None,
            "__sqlmodel_relationships__": relationships,
            "__annotations__": pydantic_annotations,
        }
        # Duplicate logic from Pydantic to filter config kwargs because if they are
        # passed directly including the registry Pydantic will pass them over to the
        # superclass causing an error
        allowed_config_kwargs: set[str] = {
            key
            for key in dir(BaseConfig)
            if not (
                key.startswith("__") and key.endswith("__")
            )  # skip dunder methods and attributes
        }
        config_kwargs = {
            key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
        }
        new_cls = cast(
            "SQLModel", super().__new__(cls, name, bases, dict_used, **config_kwargs)
        )
        new_cls.__annotations__ = {
            **relationship_annotations,
            **pydantic_annotations,
            **new_cls.__annotations__,
        }

        # Look up a config value: model_config (as built by Pydantic) wins over
        # the raw class keyword arguments; Undefined when neither has it.
        def get_config(name: str) -> Any:
            config_class_value = new_cls.model_config.get(name, Undefined)
            if config_class_value is not Undefined:
                return config_class_value
            kwarg_value = kwargs.get(name, Undefined)
            if kwarg_value is not Undefined:
                return kwarg_value
            return Undefined

        config_table = get_config("table")
        if config_table is True:
            # If it was passed by kwargs, ensure it's also set in config
            new_cls.model_config["table"] = config_table
            # Table model: expose every Pydantic field as a SQLAlchemy Column
            # class attribute.
            for k, v in get_model_fields(new_cls).items():
                col = get_column_from_field(v)
                setattr(new_cls, k, col)
            # Set a config flag to tell FastAPI that this should be read with a field
            # in orm_mode instead of preemptively converting it to a dict.
            # This could be done by reading new_cls.model_config['table'] in FastAPI, but
            # that's very specific about SQLModel, so let's have another config that
            # other future tools based on Pydantic can use.
            new_cls.model_config["read_from_attributes"] = True  # type: ignore[typeddict-unknown-key]
            # For compatibility with older versions
            # TODO: remove this in the future
            new_cls.model_config["read_with_orm_mode"] = True  # type: ignore[typeddict-unknown-key]

        config_registry = get_config("registry")
        if config_registry is not Undefined:
            config_registry = cast(registry, config_registry)
            # If it was passed by kwargs, ensure it's also set in config
            # NOTE(review): this stores config_table, not config_registry. For
            # the base SQLModel class (no "table" config) that value is
            # Undefined. If Pydantic copies parent model_config keys into
            # subclasses (verify), storing the real registry here would make
            # every subclass re-enter this branch and be marked __abstract__.
            # Do not "fix" without checking that interaction.
            new_cls.model_config["registry"] = config_table
            setattr(new_cls, "_sa_registry", config_registry)  # noqa: B010
            setattr(new_cls, "metadata", config_registry.metadata)  # noqa: B010
            setattr(new_cls, "__abstract__", True)  # noqa: B010
        return new_cls

    # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
    def __init__(
        cls, classname: str, bases: tuple[type, ...], dict_: dict[str, Any], **kw: Any
    ) -> None:
        # Only one of the base classes (or the current one) should be a table model
        # this allows FastAPI cloning a SQLModel for the response_model without
        # trying to create a new SQLAlchemy, for a new table, with the same name, that
        # triggers an error
        base_is_table = any(is_table_model_class(base) for base in bases)
        if is_table_model_class(cls) and not base_is_table:
            # Turn each RelationshipInfo into a SQLAlchemy relationship() on
            # the class before SQLAlchemy's declarative processing runs.
            for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
                if rel_info.sa_relationship:
                    # There's a SQLAlchemy relationship declared, that takes precedence
                    # over anything else, use that and continue with the next attribute
                    setattr(cls, rel_name, rel_info.sa_relationship)  # Fix #315
                    continue
                raw_ann = cls.__annotations__[rel_name]
                origin: Any = get_origin(raw_ann)
                if origin is Mapped:
                    ann = raw_ann.__args__[0]
                else:
                    ann = raw_ann
                    # Plain forward references, for models not yet defined, are not
                    # handled well by SQLAlchemy without Mapped, so, wrap the
                    # annotations in Mapped here
                    cls.__annotations__[rel_name] = Mapped[ann]  # type: ignore[valid-type]
                relationship_to = get_relationship_to(
                    name=rel_name, rel_info=rel_info, annotation=ann
                )
                # Map SQLModel's relationship options onto relationship() kwargs.
                rel_kwargs: dict[str, Any] = {}
                if rel_info.back_populates:
                    rel_kwargs["back_populates"] = rel_info.back_populates
                if rel_info.cascade_delete:
                    rel_kwargs["cascade"] = "all, delete-orphan"
                if rel_info.passive_deletes:
                    rel_kwargs["passive_deletes"] = rel_info.passive_deletes
                if rel_info.link_model:
                    # Many-to-many: the link model's table becomes ``secondary``.
                    ins = inspect(rel_info.link_model)
                    local_table = getattr(ins, "local_table")  # noqa: B009
                    if local_table is None:
                        raise RuntimeError(
                            "Couldn't find the secondary table for "
                            f"model {rel_info.link_model}"
                        )
                    rel_kwargs["secondary"] = local_table
                rel_args: list[Any] = []
                if rel_info.sa_relationship_args:
                    rel_args.extend(rel_info.sa_relationship_args)
                # User-supplied sa_relationship_kwargs override the derived ones.
                if rel_info.sa_relationship_kwargs:
                    rel_kwargs.update(rel_info.sa_relationship_kwargs)
                rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
                setattr(cls, rel_name, rel_value)  # Fix #315
            # SQLAlchemy no longer uses dict_
            # Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
            # Tag: 1.4.36
            DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
        else:
            ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)


def get_sqlalchemy_type(field: Any) -> Any:
    """Map a model field to a SQLAlchemy column type.

    An explicit ``sa_type`` on the field wins. Otherwise the Python type from
    ``get_sa_type_from_field`` is matched with ``issubclass`` checks, in
    order. The order matters: Enum before str, bool before int (bool is an
    int subclass) and datetime before date (datetime is a date subclass).
    String-like types use ``AutoString`` (with the field's ``max_length`` when
    set). Decimal uses ``Numeric`` with ``max_digits``/``decimal_places``.

    Presumably ``get_sa_type_from_field`` always returns a class;
    ``issubclass`` would raise TypeError otherwise. TODO confirm.

    Raises:
        ValueError: if no SQLAlchemy type matches.
    """
    field_info = field
    sa_type = _get_sqlmodel_field_value(field_info, "sa_type", Undefined)  # noqa: B009
    if sa_type is not Undefined:
        return sa_type

    type_ = get_sa_type_from_field(field)
    metadata = get_field_metadata(field)

    # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
    if issubclass(type_, Enum):
        return sa_Enum(type_)
    if issubclass(
        type_,
        (
            str,
            ipaddress.IPv4Address,
            ipaddress.IPv4Network,
            ipaddress.IPv6Address,
            ipaddress.IPv6Network,
            Path,
            EmailStr,
        ),
    ):
        max_length = getattr(metadata, "max_length", None)
        if max_length:
            return AutoString(length=max_length)
        return AutoString
    if issubclass(type_, float):
        return Float
    if issubclass(type_, bool):
        return Boolean
    if issubclass(type_, int):
        return Integer
    if issubclass(type_, datetime):
        return DateTime
    if issubclass(type_, date):
        return Date
    if issubclass(type_, timedelta):
        return Interval
    if issubclass(type_, time):
        return Time
    if issubclass(type_, bytes):
        return LargeBinary
    if issubclass(type_, Decimal):
        return Numeric(
            precision=getattr(metadata, "max_digits", None),
            scale=getattr(metadata, "decimal_places", None),
        )
    if issubclass(type_, uuid.UUID):
        return Uuid
    raise ValueError(f"{type_} has no matching SQLAlchemy type")


def get_column_from_field(field: Any) -> Column:  # type: ignore
    """Build the SQLAlchemy Column for a model field.

    A ready-made ``sa_column`` is returned as-is. Otherwise the column is
    assembled from the field's SQLModel options, read via
    ``_get_sqlmodel_field_value``:

    - ``primary_key``, ``index`` and ``unique`` default to False.
    - Nullability is derived (never for primary keys, otherwise from
      ``is_field_noneable``) unless ``nullable`` is set explicitly.
    - ``foreign_key``/``ondelete`` become a ``ForeignKey`` argument.
    - The field's ``default_factory`` or ``default`` becomes the column
      default.
    - ``sa_column_args`` are appended and ``sa_column_kwargs`` applied last,
      so they override the derived keyword values.

    Raises:
        RuntimeError: if ``ondelete="SET NULL"`` is used on a non-nullable
            column.
    """
    field_info = field
    sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined)
    if isinstance(sa_column, Column):
        return sa_column
    sa_type = get_sqlalchemy_type(field)
    primary_key = _get_sqlmodel_field_value(field_info, "primary_key", Undefined)
    if primary_key is Undefined:
        primary_key = False
    index = _get_sqlmodel_field_value(field_info, "index", Undefined)
    if index is Undefined:
        index = False
    nullable = not primary_key and is_field_noneable(field)
    # Override derived nullability if the nullable property is set explicitly
    # on the field
    field_nullable = _get_sqlmodel_field_value(field_info, "nullable", Undefined)
    if field_nullable is not Undefined:
        assert not isinstance(field_nullable, UndefinedType)
        nullable = field_nullable
    args = []
    foreign_key = _get_sqlmodel_field_value(field_info, "foreign_key", Undefined)
    if foreign_key is Undefined:
        foreign_key = None
    unique = _get_sqlmodel_field_value(field_info, "unique", Undefined)
    if unique is Undefined:
        unique = False
    if foreign_key:
        ondelete_value = _get_sqlmodel_field_value(field_info, "ondelete", Undefined)
        if ondelete_value is Undefined:
            ondelete_value = None
        if ondelete_value == "SET NULL" and not nullable:
            raise RuntimeError('ondelete="SET NULL" requires nullable=True')
        # NOTE(review): these asserts only narrow types; under `python -O`
        # they are skipped, so a non-str foreign_key reaches ForeignKey()
        # unchecked.
        assert isinstance(foreign_key, str)
        assert isinstance(ondelete_value, (str, type(None)))  # for typing
        args.append(ForeignKey(foreign_key, ondelete=ondelete_value))
    kwargs = {
        "primary_key": primary_key,
        "nullable": nullable,
        "index": index,
        "unique": unique,
    }
    # Reuse the Pydantic default (factory first) as the column default.
    sa_default = Undefined
    if field_info.default_factory:
        sa_default = field_info.default_factory
    elif field_info.default is not Undefined:
        sa_default = field_info.default
    if sa_default is not Undefined:
        kwargs["default"] = sa_default
    sa_column_args = _get_sqlmodel_field_value(field_info, "sa_column_args", Undefined)
    if sa_column_args is not Undefined:
        args.extend(list(cast(Sequence[Any], sa_column_args)))
    sa_column_kwargs = _get_sqlmodel_field_value(
        field_info, "sa_column_kwargs", Undefined
    )
    if sa_column_kwargs is not Undefined:
        kwargs.update(cast(dict[Any, Any], sa_column_kwargs))
    return Column(sa_type, *args, **kwargs)


class_registry = weakref.WeakValueDictionary()  # type: ignore

# Registry passed to the base SQLModel class below via the `registry=`
# class keyword.
default_registry = registry()

_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")


# Base class for all models: a Pydantic BaseModel whose subclasses declared
# with table=True are also mapped as SQLAlchemy declarative classes.
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
    # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
    __slots__ = ("__weakref__",)
    __tablename__: ClassVar[str | Callable[..., str]]
    __sqlmodel_relationships__: ClassVar[builtins.dict[str, RelationshipProperty[Any]]]
    __name__: ClassVar[str]
    metadata: ClassVar[MetaData]
    __allow_unmapped__ = True  # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six

    model_config = SQLModelConfig(from_attributes=True)

    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
        new_object = super().__new__(cls)
        # SQLAlchemy doesn't call __init__ on the base class when querying from DB
        # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
        # Set __fields_set__ here, that would have been set when calling __init__
        # in the Pydantic model so that when SQLAlchemy sets attributes that are
        # added (e.g. when querying from DB) to the __fields_set__, this already exists
        init_pydantic_private_attrs(new_object)
        return new_object

    def __init__(__pydantic_self__, **data: Any) -> None:
        # Uses something other than `self` the first arg to allow "self" as a
        # settable attribute

        # SQLAlchemy does very dark black magic and modifies the __init__ method in
        # sqlalchemy.orm.instrumentation._generate_init()
        # so, to make SQLAlchemy work, it's needed to explicitly call __init__ to
        # trigger all the SQLAlchemy logic, it doesn't work using cls.__new__, setting
        # attributes obj.__dict__, etc. The __init__ method has to be called. But
        # there are cases where calling all the default logic is not ideal, e.g.
        # when calling Model.model_validate(), as the validation is done outside
        # of instance creation.
        # At the same time, __init__ is what users would normally call, by creating
        # a new instance, which should have validation and all the default logic.
        # So, to be able to set up the internal SQLAlchemy logic alone without
        # executing the rest, and support things like Model.model_validate(), we
        # use a contextvar to know if we should execute everything.
        if finish_init.get():
            sqlmodel_init(self=__pydantic_self__, data=data)

    def __setattr__(self, name: str, value: Any) -> None:
        # SQLAlchemy's instance state is written straight into __dict__,
        # bypassing both SQLAlchemy instrumentation and Pydantic.
        if name in {"_sa_instance_state"}:
            self.__dict__[name] = value
            return
        else:
            # Set in SQLAlchemy, before Pydantic to trigger events and updates
            if is_table_model_class(self.__class__) and is_instrumented(self, name):  # type: ignore[no-untyped-call]
                set_attribute(self, name, value)
            # Set in Pydantic model to trigger possible validation changes, only for
            # non relationship values
            if name not in self.__sqlmodel_relationships__:
                super().__setattr__(name, value)

    def __repr_args__(self) -> Sequence[tuple[str | None, Any]]:
        # Don't show SQLAlchemy private attributes
        return [
            (k, v)
            for k, v in super().__repr_args__()
            if not (isinstance(k, str) and k.startswith("_sa_"))
        ]

    # Default table name: the lower-cased class name.
    @declared_attr  # type: ignore
    def __tablename__(cls) -> str:
        return cls.__name__.lower()

    # Validation entry point; delegates to sqlmodel_validate and also accepts
    # an `update` dict that is forwarded with the other arguments.
    @classmethod
    def model_validate(  # type: ignore[override]
        cls: type[_TSQLModel],
        obj: Any,
        *,
        strict: bool | None = None,
        from_attributes: bool | None = None,
        context: builtins.dict[str, Any] | None = None,
        update: builtins.dict[str, Any] | None = None,
    ) -> _TSQLModel:
        return sqlmodel_validate(
            cls=cls,
            obj=obj,
            strict=strict,
            from_attributes=from_attributes,
            context=context,
            update=update,
        )

    # Wrapper around BaseModel.model_dump that only forwards keyword arguments
    # the installed Pydantic version accepts: `fallback` from 2.11 and
    # `exclude_computed_fields` from 2.12. Before 2.11, by_alias=None is
    # coerced to False. `context` and `serialize_as_any` are always forwarded,
    # so this presumably requires Pydantic >= 2.7 (per the inline notes).
    def model_dump(
        self,
        *,
        mode: Literal["json", "python"] | str = "python",
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        context: Any | None = None,  # v2.7
        by_alias: bool | None = None,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        exclude_computed_fields: bool = False,  # v2.12
        round_trip: bool = False,
        warnings: bool | Literal["none", "warn", "error"] = True,
        fallback: Callable[[Any], Any] | None = None,  # v2.11
        serialize_as_any: bool = False,  # v2.7
    ) -> builtins.dict[str, Any]:
        if PYDANTIC_MINOR_VERSION < (2, 11):
            by_alias = by_alias or False
        extra_kwargs: dict[str, Any] = {}
        extra_kwargs["context"] = context
        extra_kwargs["serialize_as_any"] = serialize_as_any
        if PYDANTIC_MINOR_VERSION >= (2, 11):
            extra_kwargs["fallback"] = fallback
        if PYDANTIC_MINOR_VERSION >= (2, 12):
            extra_kwargs["exclude_computed_fields"] = exclude_computed_fields
        return super().model_dump(
            mode=mode,
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            round_trip=round_trip,
            warnings=warnings,
            **extra_kwargs,
        )

    # Deprecated alias of model_dump(), kept for backwards compatibility.
    @deprecated(
        """ 🚨 `obj.dict()` was deprecated in SQLModel 0.0.14, you should instead use `obj.model_dump()`. """
    )
    def dict(
        self,
        *,
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        by_alias: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> builtins.dict[str, Any]:
        return self.model_dump(
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )

    # Deprecated alias of model_validate(obj, update=update).
    @classmethod
    @deprecated(
        """ 🚨 `obj.from_orm(data)` was deprecated in SQLModel 0.0.14, you should instead use `obj.model_validate(data)`. """
    )
    def from_orm(
        cls: type[_TSQLModel],
        obj: Any,
        update: builtins.dict[str, Any] | None = None,
    ) -> _TSQLModel:
        return cls.model_validate(obj, update=update)

    # Deprecated alias of model_validate(obj, update=update).
    @classmethod
    @deprecated(
        """ 🚨 `obj.parse_obj(data)` was deprecated in SQLModel 0.0.14, you should instead use `obj.model_validate(data)`. """
    )
    def parse_obj(
        cls: type[_TSQLModel],
        obj: Any,
        update: builtins.dict[str, Any] | None = None,
    ) -> _TSQLModel:
        return cls.model_validate(obj, update=update)

    # Update this instance in place and return it. Values in `update` take
    # precedence over those in `obj`.
    # - dict input: keys (from obj and update) that are not fields of this
    #   model are ignored.
    # - BaseModel input: every field of `obj` is set on self, even one that
    #   is not a field of this model (no membership check on that loop).
    #   Leftover `update` keys are applied only if they are fields of self.
    # - Anything else raises ValueError.
    def sqlmodel_update(
        self: _TSQLModel,
        obj: builtins.dict[str, Any] | BaseModel,
        *,
        update: builtins.dict[str, Any] | None = None,
    ) -> _TSQLModel:
        use_update = (update or {}).copy()
        if isinstance(obj, dict):
            for key, value in {**obj, **use_update}.items():
                if key in get_model_fields(self):
                    setattr(self, key, value)
        elif isinstance(obj, BaseModel):
            for key in get_model_fields(obj):
                if key in use_update:
                    value = use_update.pop(key)
                else:
                    value = getattr(obj, key)
                setattr(self, key, value)
            for remaining_key, value in use_update.items():
                if remaining_key in get_model_fields(self):
                    setattr(self, remaining_key, value)
        else:
            raise ValueError(
                "Can't use sqlmodel_update() with something that "
                f"is not a dict or SQLModel or Pydantic model: {obj}"
            )
        return self