From c2945249e46dc04d693071f22b0f4aec854d1d59 Mon Sep 17 00:00:00 2001 From: Hadrien David Date: Sun, 16 Feb 2025 16:42:04 -0500 Subject: [PATCH 1/3] feat: add support for SQLModel --- .github/tox-uv/action.yml | 9 ++ .github/workflows/ci.yml | 13 ++- pyproject.toml | 20 +++- src/fastsqla.py | 11 ++- tests/conftest.py | 24 ++++- tests/integration/test_sqlmodel.py | 151 +++++++++++++++++++++++++++++ uv.lock | 21 +++- 7 files changed, 240 insertions(+), 9 deletions(-) create mode 100644 .github/tox-uv/action.yml create mode 100644 tests/integration/test_sqlmodel.py diff --git a/.github/tox-uv/action.yml b/.github/tox-uv/action.yml new file mode 100644 index 0000000..cd667b5 --- /dev/null +++ b/.github/tox-uv/action.yml @@ -0,0 +1,9 @@ +name: Setup tox-uv +description: Setup tox-uv tool so tox uses uv to install dependencies +runs: + using: composite + steps: + - name: โšก๏ธ setup uv + uses: .github/uv + run: uv tool install tox --with tox-uv + diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0f8ebf0..700c14e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,16 +6,21 @@ on: jobs: Tests: + strategy: + matrix: + tox_env: [default, sqlmodel] runs-on: ubuntu-latest steps: - name: ๐Ÿ“ฅ checkout uses: actions/checkout@v4 - - name: ๐Ÿ”ง setup uv - uses: ./.github/uv - - name: ๐Ÿงช pytest - run: uv run pytest --cov fastsqla --cov-report=term-missing --cov-report=xml + - name: ๐Ÿ”ง setup tox-uv + uses: ./.github/tox-uv + - name: ๐Ÿงช tox -e ${{ matrix.tox_env }} + run: uv run tox -e ${{ matrix.tox_env }} - name: "๐Ÿ” codecov: upload test coverage" uses: codecov/codecov-action@v4.2.0 + with: + flags: ${{ matrix.tox_env }} env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/pyproject.toml b/pyproject.toml index cb07363..8477397 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ docs = [ "mkdocs-material>=9.5.50", "mkdocstrings[python]>=0.27.0", ] +sqlmodel = ["sqlmodel>=0.0.22"] [tool.uv] package = true @@ -71,7 +72,10 @@ dev-dependencies = [ pytest-watch = { git = "https://github.com/styleseat/pytest-watch", rev = "0342193" } [tool.pytest.ini_options] -asyncio_mode = 'auto' +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" + +filterwarnings = ["ignore::DeprecationWarning:"] [tool.coverage.run] branch = true @@ -86,3 +90,17 @@ version_toml = ["pyproject.toml:project.version"] [tool.semantic_release.changelog.default_templates] changelog_file = "./docs/changelog.md" + +[tool.tox] +legacy_tox_ini = """ +[tox] +envlist = { default, sqlmodel } + +[testenv] +passenv = CI +runner = uv-venv-lock-runner +commands = + pytest --cov fastsqla --cov-report=term-missing --cov-report=xml +extras: + sqlmodel: sqlmodel +""" diff --git a/src/fastsqla.py b/src/fastsqla.py index 636afd4..1b91e66 100644 --- a/src/fastsqla.py +++ b/src/fastsqla.py @@ -17,6 +17,15 @@ from sqlalchemy.orm import DeclarativeBase from structlog import get_logger +logger = get_logger(__name__) + +try: + from sqlmodel.ext.asyncio.session import AsyncSession + +except ImportError: + pass + + __all__ = [ "Base", "Collection", @@ -30,7 +39,7 @@ "open_session", ] -SessionFactory = async_sessionmaker(expire_on_commit=False) +SessionFactory = async_sessionmaker(expire_on_commit=False, class_=AsyncSession) logger = get_logger(__name__) diff --git a/tests/conftest.py b/tests/conftest.py index 1d6335b..0daebf9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,15 @@ from unittest.mock import patch -from pytest import fixture +from pytest import fixture, skip from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +def pytest_configure(config): + config.addinivalue_line( + "markers", "require_sqlmodel: skip test when sqlmodel is not installed." + ) + + @fixture def environ(tmp_path): values = { @@ -38,3 +44,19 @@ def tear_down(): Base.metadata.clear() clear_mappers() + + +try: + import sqlmodel # noqa +except ImportError: + is_sqlmodel_installed = False +else: + is_sqlmodel_installed = True + + +@fixture(autouse=True) +def check_sqlmodel(request): + """Skip test marked with mark.require_sqlmodel if sqlmodel is not installed.""" + marker = request.node.get_closest_marker("require_sqlmodel") + if marker and not is_sqlmodel_installed: + skip(f"{request.node.nodeid} requires sqlmodel which is not installed.") diff --git a/tests/integration/test_sqlmodel.py b/tests/integration/test_sqlmodel.py new file mode 100644 index 0000000..7fb8c30 --- /dev/null +++ b/tests/integration/test_sqlmodel.py @@ -0,0 +1,151 @@ +from http import HTTPStatus + +from fastapi import HTTPException +from pytest import fixture, mark +from sqlalchemy import insert, select, text +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.automap import automap_base + + +pytestmark = mark.require_sqlmodel + + +@fixture +def heros_data(): + return [ + ("Superman", "Clark Kent", 30), + ("Batman", "Bruce Wayne", 35), + ("Wonder Woman", "Diana Prince", 30), + ("Iron Man", "Tony Stark", 45), + ("Spider-Man", "Peter Parker", 25), + ("Captain America", "Steve Rogers", 100), + ("Black Widow", "Natasha Romanoff", 35), + ("Thor", "Thor Odinson", 1500), + ("Scarlet Witch", "Wanda Maximoff", 30), + ("Doctor Strange", "Stephen Strange", 40), + ("The Flash", "Barry Allen", 28), + ("Green Lantern", "Hal Jordan", 35), + ] + + +@fixture(autouse=True) +async def setup_tear_down(engine, heros_data): + Base = automap_base() + async with engine.connect() as conn: + await conn.execute( + text(""" + CREATE TABLE hero ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + secret_identity TEXT NOT NULL, + age INTEGER NOT NULL + ) + """) + ) + + await conn.run_sync(Base.prepare) + + Hero = Base.classes.hero + + stmt = insert(Hero).values( + [ + dict(name=name, secret_identity=secret_identity, age=age) + for name, secret_identity, age in heros_data + ] + ) + await conn.execute(stmt) + await conn.commit() + yield + await conn.execute(text("DROP TABLE hero")) + + +@fixture +async def app(setup_tear_down, app): + from fastsqla import Item, Page, Paginate, Session + from sqlmodel import Field, SQLModel + + class Hero(SQLModel, table=True): + __table_args__ = {"extend_existing": True} + id: int | None = Field(default=None, primary_key=True) + name: str + secret_identity: str + age: int + + @app.get("/heroes", response_model=Page[Hero]) + async def get_heroes(paginate: Paginate): + return await paginate(select(Hero)) + + @app.get("/heroes/{hero_id}", response_model=Item[Hero]) + async def get_hero(session: Session, hero_id: int): + hero = await session.get(Hero, hero_id) + if hero is None: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND) + return {"data": hero} + + @app.post("/heroes", response_model=Item[Hero]) + async def create_hero(session: Session, hero: Hero): + session.add(hero) + try: + await session.flush() + except IntegrityError: + raise HTTPException(status_code=HTTPStatus.CONFLICT) + return {"data": hero} + + return app + + +@mark.parametrize("offset, page_number, items_count", [[0, 1, 10], [10, 2, 2]]) +async def test_pagination(client, heros_data, offset, page_number, items_count): + res = await client.get("/heroes", params={"offset": offset}) + assert res.status_code == 200, (res.status_code, res.content) + + payload = res.json() + assert "data" in payload + data = payload["data"] + assert len(data) == items_count + + for i, hero in enumerate(data): + name, secret_identity, age = heros_data[i + offset] + assert hero["id"] + assert hero["name"] == name + assert hero["secret_identity"] == secret_identity + assert hero["age"] == age + + assert "meta" in payload + assert payload["meta"]["total_items"] == 12 + assert payload["meta"]["total_pages"] == 2 + assert payload["meta"]["offset"] == offset + assert payload["meta"]["page_number"] == page_number + + +async def test_getting_an_entity_with_session_dependency(client, heros_data): + res = await client.get("/heroes/1") + assert res.status_code == 200, (res.status_code, res.content) + + payload = res.json() + assert "data" in payload + data = payload["data"] + + name, secret_identity, age = heros_data[0] + assert data["id"] == 1 + assert data["name"] == name + assert data["secret_identity"] == secret_identity + assert data["age"] == age + + +async def test_creating_an_entity_with_session_dependency(client): + hero = {"name": "Hulk", "secret_identity": "Bruce Banner", "age": 37} + res = await client.post("/heroes", json=hero) + assert res.status_code == 200, (res.status_code, res.content) + + data = res.json()["data"] + assert data["id"] == 13 + assert data["name"] == hero["name"] + assert data["secret_identity"] == hero["secret_identity"] + assert data["age"] == hero["age"] + + +async def test_creating_an_entity_with_conflict(client): + hero = {"name": "Superman", "secret_identity": "Clark Kent", "age": 30} + res = await client.post("/heroes", json=hero) + assert res.status_code == HTTPStatus.CONFLICT, (res.status_code, res.content) diff --git a/uv.lock b/uv.lock index 3976e36..3220be3 100644 --- a/uv.lock +++ b/uv.lock @@ -123,7 +123,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -278,6 +278,9 @@ docs = [ { name = "mkdocs-material" }, { name = "mkdocstrings", extra = ["python"] }, ] +sqlmodel = [ + { name = "sqlmodel" }, +] [package.dev-dependencies] dev = [ @@ -303,6 +306,7 @@ requires-dist = [ { name = "mkdocs-material", marker = "extra == 'docs'", specifier = ">=9.5.50" }, { name = "mkdocstrings", extras = ["python"], marker = "extra == 'docs'", specifier = ">=0.27.0" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.37" }, + { name = "sqlmodel", marker = "extra == 'sqlmodel'", specifier = ">=0.0.22" }, { name = "structlog", specifier = ">=24.4.0" }, ] @@ -615,7 +619,7 @@ version = "1.6.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "ghp-import" }, { name = "jinja2" }, { name = "markdown" }, @@ -1261,6 +1265,19 @@ asyncio = [ { name = "greenlet" }, ] +[[package]] +name = "sqlmodel" +version = "0.0.22" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "sqlalchemy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b5/39/8641040ab0d5e1d8a1c2325ae89a01ae659fc96c61a43d158fb71c9a0bf0/sqlmodel-0.0.22.tar.gz", hash = "sha256:7d37c882a30c43464d143e35e9ecaf945d88035e20117bf5ec2834a23cbe505e", size = 116392 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dd/b1/3af5104b716c420e40a6ea1b09886cae3a1b9f4538343875f637755cae5b/sqlmodel-0.0.22-py3-none-any.whl", hash = "sha256:a1ed13e28a1f4057cbf4ff6cdb4fc09e85702621d3259ba17b3c230bfb2f941b", size = 28276 }, +] + [[package]] name = "starlette" version = "0.41.3" From f179e1fb79f7204742052f82f1e9c9fe7a5a7354 Mon Sep 17 00:00:00 2001 From: Hadrien David Date: Sun, 16 Feb 2025 16:53:50 -0500 Subject: [PATCH 2/3] trying to fix tox-uv github action --- .github/tox-uv/action.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/tox-uv/action.yml b/.github/tox-uv/action.yml index cd667b5..c356cb8 100644 --- a/.github/tox-uv/action.yml +++ b/.github/tox-uv/action.yml @@ -4,6 +4,7 @@ runs: using: composite steps: - name: โšก๏ธ setup uv - uses: .github/uv + uses: ./.github/uv + - name: โš™๏ธ install tox-uv + shell: bash run: uv tool install tox --with tox-uv - From 5a19ea3c5e621a5a65b6ad32a34a1b8568325700 Mon Sep 17 00:00:00 2001 From: Hadrien David Date: Mon, 17 Feb 2025 13:53:41 -0500 Subject: [PATCH 3/3] document it --- README.md | 40 +++++++++++++++++++---- docs/orm.md | 4 +-- docs/pagination.md | 70 ++++++++++++++++++++++++++++++++++++++++ src/fastsqla.py | 79 ++++++++++++---------------------------------- 4 files changed, 127 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index c41dfce..cdc1f7c 100644 --- a/README.md +++ b/README.md @@ -15,8 +15,10 @@ _Async SQLAlchemy 2.0+ for FastAPI โ€” boilerplate, pagination, and seamless ses ----------------------------------------------------------------------------------------- -`FastSQLA` is an [`SQLAlchemy 2.0+`](https://docs.sqlalchemy.org/en/20/) extension for -[`FastAPI`](https://fastapi.tiangolo.com/). +`FastSQLA` is an async [`SQLAlchemy 2.0+`](https://docs.sqlalchemy.org/en/20/) +extension for [`FastAPI`](https://fastapi.tiangolo.com/) with built-in pagination, +[`SQLModel`](http://sqlmodel.tiangolo.com/) support and more. + It streamlines the configuration and asynchronous connection to relational databases by providing boilerplate and intuitive helpers. Additionally, it offers built-in customizable pagination and automatically manages the `SQLAlchemy` session lifecycle @@ -74,13 +76,13 @@ following [`SQLAlchemy`'s best practices](https://docs.sqlalchemy.org/en/20/orm/ async def get_heros(paginate:Paginate): return await paginate(select(Hero)) ``` - +
- + ๐Ÿ‘‡ `/heros?offset=10&limit=10` ๐Ÿ‘‡ - +
- + ```json { "data": [ @@ -119,6 +121,32 @@ following [`SQLAlchemy`'s best practices](https://docs.sqlalchemy.org/en/20/orm/ * Session lifecycle management: session is commited on request success or rollback on failure. +* [`SQLModel`](http://sqlmodel.tiangolo.com/) support: + ```python + ... + from fastsqla import Item, Page, Paginate, Session + from sqlmodel import Field, SQLModel + ... + + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + secret_identity: str + age: int + + + @app.get("/heroes", response_model=Page[Hero]) + async def get_heroes(paginate: Paginate): + return await paginate(select(Hero)) + + + @app.get("/heroes/{hero_id}", response_model=Item[Hero]) + async def get_hero(session: Session, hero_id: int): + hero = await session.get(Hero, hero_id) + if hero is None: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND) + return {"data": hero} + ``` ## Installing diff --git a/docs/orm.md b/docs/orm.md index f4c9e4f..f508067 100644 --- a/docs/orm.md +++ b/docs/orm.md @@ -5,5 +5,5 @@ ::: fastsqla.Base options: heading_level: false - show_source: false - show_bases: false + show_source: true + show_bases: true diff --git a/docs/pagination.md b/docs/pagination.md index 38fcf45..7b43482 100644 --- a/docs/pagination.md +++ b/docs/pagination.md @@ -6,3 +6,73 @@ options: heading_level: false show_source: false + +### `SQLAlchemy` example + +``` py title="example.py" hl_lines="25 26 27" +from fastapi import FastAPI +from fastsqla import Base, Paginate, Page, lifespan +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.orm import Mapped, mapped_column + +app = FastAPI(lifespan=lifespan) + +class Hero(Base): + __tablename__ = "hero" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(unique=True) + secret_identity: Mapped[str] + age: Mapped[int] + + +class HeroModel(HeroBase): + model_config = ConfigDict(from_attributes=True) + id: int + name: str + secret_identity: str + age: int + + +@app.get("/heros", response_model=Page[HeroModel]) # (1)! +async def list_heros(paginate: Paginate): # (2)! + return await paginate(select(Hero)) # (3)! +``` + +1. The endpoint returns a `Page` model of `HeroModel`. +2. Just define an argument with type `Paginate` to get an async `paginate` function + injected in your endpoint function. +3. Await the `paginate` function with the `SQLAlchemy` select statement to get the + paginated result. + +To add filtering, just add whatever query parameters you need to the endpoint: + +```python +@app.get("/heros", response_model=Page[HeroModel]) +async def list_heros(paginate: Paginate, age:int | None = None): + stmt = select(Hero) + if age: + stmt = stmt.where(Hero.age == age) + return await paginate(stmt) +``` + +### `SQLModel` example + +```python +from fastapi import FastAPI +from fastsqla import Page, Paginate, Session +from sqlmodel import Field, SQLModel +from sqlalchemy import select + + +class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + secret_identity: str + age: int + + +@app.get("/heroes", response_model=Page[Hero]) +async def get_heroes(paginate: Paginate): + return await paginate(select(Hero)) +``` diff --git a/src/fastsqla.py b/src/fastsqla.py index 1b91e66..1a42e37 100644 --- a/src/fastsqla.py +++ b/src/fastsqla.py @@ -65,6 +65,10 @@ class Hero(Base): * [ORM Quick Start](https://docs.sqlalchemy.org/en/20/orm/quickstart.html) * [Declarative Mapping](https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#declarative-mapping) + + !!! note + + You don't need this if you use [`SQLModel`](http://sqlmodel.tiangolo.com/). """ __abstract__ = True @@ -151,7 +155,7 @@ async def lifespan(app:FastAPI) -> AsyncGenerator[dict, None]: @asynccontextmanager async def open_session() -> AsyncGenerator[AsyncSession, None]: - """An asynchronous context manager that opens a new `SQLAlchemy` async session. + """Async context manager that opens a new `SQLAlchemy` or `SQLModel` async session. To the contrary of the [`Session`][fastsqla.Session] dependency which can only be used in endpoints, `open_session` can be used anywhere such as in background tasks. @@ -161,6 +165,16 @@ async def open_session() -> AsyncGenerator[AsyncSession, None]: In all cases, it closes the session and returns the associated connection to the connection pool. + + Returns: + When `SQLModel` is not installed, an async generator that yields an + [`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession]. + + When `SQLModel` is installed, an async generator that yields an + [`SQLModel AsyncSession`](https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/ext/asyncio/session.py#L32) + which inherits from [`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession]. + + ```python from fastsqla import open_session @@ -200,12 +214,12 @@ async def new_session() -> AsyncGenerator[AsyncSession, None]: Session = Annotated[AsyncSession, Depends(new_session)] -"""A dependency used exclusively in endpoints to get an `SQLAlchemy` session. +"""Dependency used exclusively in endpoints to get an `SQLAlchemy` or `SQLModel` session. `Session` is a [`FastAPI` dependency](https://fastapi.tiangolo.com/tutorial/dependencies/) -that provides an asynchronous `SQLAlchemy` session. +that provides an asynchronous `SQLAlchemy` session or `SQLModel` one if it's installed. By defining an argument with type `Session` in an endpoint, `FastAPI` will automatically -inject an `SQLAlchemy` async session into the endpoint. +inject an async session into the endpoint. At the end of request handling: @@ -345,9 +359,9 @@ async def paginate(stmt: Select) -> Page: Paginate = Annotated[PaginateType[T], Depends(new_pagination())] """A dependency used in endpoints to paginate `SQLAlchemy` select queries. -It adds `offset`and `limit` query parameters to the endpoint, which are used to paginate. -The model returned by the endpoint is a `Page` model. It contains a page of data and -metadata: +It adds **`offset`** and **`limit`** query parameters to the endpoint, which are used to +paginate. The model returned by the endpoint is a `Page` model. It contains a page of +data and metadata: ```json { @@ -360,55 +374,4 @@ async def paginate(stmt: Select) -> Page: } } ``` - ------ - -Example: -``` py title="example.py" hl_lines="22 23 25" -from fastsqla import Base, Paginate, Page -from pydantic import BaseModel - - -class Hero(Base): - __tablename__ = "hero" - - -class Hero(Base): - __tablename__ = "hero" - id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(unique=True) - secret_identity: Mapped[str] - age: Mapped[int] - - -class HeroModel(HeroBase): - model_config = ConfigDict(from_attributes=True) - id: int - - -@app.get("/heros", response_model=Page[HeroModel]) # (1)! -async def list_heros(paginate: Paginate): # (2)! - stmt = select(Hero) - return await paginate(stmt) # (3)! -``` - -1. The endpoint returns a `Page` model of `HeroModel`. -2. Just define an argument with type `Paginate` to get an async `paginate` function - injected in your endpoint function. -3. Await the `paginate` function with the `SQLAlchemy` select statement to get the - paginated result. - -To add filtering, just add whatever query parameters you need to the endpoint: - -```python - -from fastsqla import Paginate, Page - -@app.get("/heros", response_model=Page[HeroModel]) -async def list_heros(paginate: Paginate, age:int | None = None): - stmt = select(Hero) - if age: - stmt = stmt.where(Hero.age == age) - return await paginate(stmt) -``` """