Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,11 +613,11 @@ def get_config(name: str) -> Any:
# TODO: remove this in the future
new_cls.model_config["read_with_orm_mode"] = True # ty: ignore[invalid-key]

config_registry = get_config("registry")
config_registry = kwargs.get("registry", Undefined)
if config_registry is not Undefined:
config_registry = cast(registry, config_registry)
# If it was passed by kwargs, ensure it's also set in config
new_cls.model_config["registry"] = config_table
new_cls.model_config["registry"] = config_registry
setattr(new_cls, "_sa_registry", config_registry) # noqa: B010
setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010
setattr(new_cls, "__abstract__", True) # noqa: B010
Expand Down
30 changes: 30 additions & 0 deletions tests/test_registry_kwarg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Tests for SQLModelMetaclass.__new__ registry kwarg handling."""
import pytest
from sqlalchemy.orm import registry
from sqlmodel import SQLModel


def test_custom_registry_base_stores_registry_in_model_config() -> None:
"""model_config['registry'] must hold the registry object passed as kwarg."""
custom_registry = registry()

class MyBase(SQLModel, registry=custom_registry):
pass

stored = MyBase.model_config.get("registry")
assert stored is custom_registry, (
f"model_config['registry'] should be the custom registry, got {stored!r}"
)


def test_custom_registry_base_sets_sa_registry() -> None:
"""_sa_registry must reference the registry object passed as kwarg."""
custom_registry = registry()

class MyBase2(SQLModel, registry=custom_registry):
pass

sa_registry = getattr(MyBase2, "_sa_registry", None)
assert sa_registry is custom_registry, (
f"_sa_registry should be the custom registry, got {sa_registry!r}"
)
Loading