Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,11 +613,11 @@ def get_config(name: str) -> Any:
# TODO: remove this in the future
new_cls.model_config["read_with_orm_mode"] = True # ty: ignore[invalid-key]

config_registry = get_config("registry")
config_registry = kwargs.get("registry", Undefined)
if config_registry is not Undefined:
config_registry = cast(registry, config_registry)
# If it was passed by kwargs, ensure it's also set in config
new_cls.model_config["registry"] = config_table
new_cls.model_config["registry"] = config_registry
setattr(new_cls, "_sa_registry", config_registry) # noqa: B010
setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010
setattr(new_cls, "__abstract__", True) # noqa: B010
Expand Down
30 changes: 30 additions & 0 deletions tests/test_registry_kwarg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Tests for SQLModelMetaclass.__new__ registry kwarg handling."""

from sqlalchemy.orm import registry
from sqlmodel import SQLModel


def test_custom_registry_base_stores_registry_in_model_config() -> None:
"""model_config['registry'] must hold the registry object passed as kwarg."""
custom_registry = registry()

class MyBase(SQLModel, registry=custom_registry):
pass

stored = MyBase.model_config.get("registry")
assert stored is custom_registry, (
f"model_config['registry'] should be the custom registry, got {stored!r}"
)


def test_custom_registry_base_sets_sa_registry() -> None:
"""_sa_registry must reference the registry object passed as kwarg."""
custom_registry = registry()

class MyBase2(SQLModel, registry=custom_registry):
pass

sa_registry = getattr(MyBase2, "_sa_registry", None)
assert sa_registry is custom_registry, (
f"_sa_registry should be the custom registry, got {sa_registry!r}"
)
Loading