Skip to content

feat: add support for SQLModel #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/tox-uv/action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: Setup tox-uv
description: Setup tox-uv tool so tox uses uv to install dependencies
runs:
using: composite
steps:
- name: ⚡️ setup uv
uses: ./.github/uv
- name: ⚙️ install tox-uv
shell: bash
run: uv tool install tox --with tox-uv
13 changes: 9 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@ on:
jobs:

Tests:
strategy:
matrix:
tox_env: [default, sqlmodel]
runs-on: ubuntu-latest
steps:
- name: 📥 checkout
uses: actions/checkout@v4
- name: 🔧 setup uv
uses: ./.github/uv
- name: 🧪 pytest
run: uv run pytest --cov fastsqla --cov-report=term-missing --cov-report=xml
- name: 🔧 setup tox-uv
uses: ./.github/tox-uv
- name: 🧪 tox -e ${{ matrix.tox_env }}
run: uv run tox -e ${{ matrix.tox_env }}
- name: "🐔 codecov: upload test coverage"
uses: codecov/[email protected]
with:
flags: ${{ matrix.tox_env }}
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

Expand Down
40 changes: 34 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ _Async SQLAlchemy 2.0+ for FastAPI — boilerplate, pagination, and seamless ses

-----------------------------------------------------------------------------------------

`FastSQLA` is an [`SQLAlchemy 2.0+`](https://docs.sqlalchemy.org/en/20/) extension for
[`FastAPI`](https://fastapi.tiangolo.com/).
`FastSQLA` is an async [`SQLAlchemy 2.0+`](https://docs.sqlalchemy.org/en/20/)
extension for [`FastAPI`](https://fastapi.tiangolo.com/) with built-in pagination,
[`SQLModel`](http://sqlmodel.tiangolo.com/) support and more.

It streamlines the configuration and asynchronous connection to relational databases by
providing boilerplate and intuitive helpers. Additionally, it offers built-in
customizable pagination and automatically manages the `SQLAlchemy` session lifecycle
Expand Down Expand Up @@ -74,13 +76,13 @@ following [`SQLAlchemy`'s best practices](https://docs.sqlalchemy.org/en/20/orm/
async def get_heros(paginate:Paginate):
return await paginate(select(Hero))
```

<center>

👇 `/heros?offset=10&limit=10` 👇

</center>

```json
{
"data": [
Expand Down Expand Up @@ -119,6 +121,32 @@ following [`SQLAlchemy`'s best practices](https://docs.sqlalchemy.org/en/20/orm/
* Session lifecycle management: session is commited on request success or rollback on
failure.

* [`SQLModel`](http://sqlmodel.tiangolo.com/) support:
```python
...
from fastsqla import Item, Page, Paginate, Session
from sqlmodel import Field, SQLModel
...

class Hero(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
name: str
secret_identity: str
age: int


@app.get("/heroes", response_model=Page[Hero])
async def get_heroes(paginate: Paginate):
return await paginate(select(Hero))


@app.get("/heroes/{hero_id}", response_model=Item[Hero])
async def get_hero(session: Session, hero_id: int):
hero = await session.get(Hero, hero_id)
if hero is None:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return {"data": hero}
```

## Installing

Expand Down
4 changes: 2 additions & 2 deletions docs/orm.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
::: fastsqla.Base
options:
heading_level: false
show_source: false
show_bases: false
show_source: true
show_bases: true
70 changes: 70 additions & 0 deletions docs/pagination.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,73 @@
options:
heading_level: false
show_source: false

### `SQLAlchemy` example

``` py title="example.py" hl_lines="25 26 27"
from fastapi import FastAPI
from fastsqla import Base, Paginate, Page, lifespan
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Mapped, mapped_column

app = FastAPI(lifespan=lifespan)

class Hero(Base):
__tablename__ = "hero"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(unique=True)
secret_identity: Mapped[str]
age: Mapped[int]


class HeroModel(HeroBase):
model_config = ConfigDict(from_attributes=True)
id: int
name: str
secret_identity: str
age: int


@app.get("/heros", response_model=Page[HeroModel]) # (1)!
async def list_heros(paginate: Paginate): # (2)!
return await paginate(select(Hero)) # (3)!
```

1. The endpoint returns a `Page` model of `HeroModel`.
2. Just define an argument with type `Paginate` to get an async `paginate` function
injected in your endpoint function.
3. Await the `paginate` function with the `SQLAlchemy` select statement to get the
paginated result.

To add filtering, just add whatever query parameters you need to the endpoint:

```python
@app.get("/heros", response_model=Page[HeroModel])
async def list_heros(paginate: Paginate, age:int | None = None):
stmt = select(Hero)
if age:
stmt = stmt.where(Hero.age == age)
return await paginate(stmt)
```

### `SQLModel` example

```python
from fastapi import FastAPI
from fastsqla import Page, Paginate, Session
from sqlmodel import Field, SQLModel
from sqlalchemy import select


class Hero(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
name: str
secret_identity: str
age: int


@app.get("/heroes", response_model=Page[Hero])
async def get_heroes(paginate: Paginate):
return await paginate(select(Hero))
```
20 changes: 19 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ docs = [
"mkdocs-material>=9.5.50",
"mkdocstrings[python]>=0.27.0",
]
sqlmodel = ["sqlmodel>=0.0.22"]

[tool.uv]
package = true
Expand All @@ -71,7 +72,10 @@ dev-dependencies = [
pytest-watch = { git = "https://github.com./styleseat/pytest-watch", rev = "0342193" }

[tool.pytest.ini_options]
asyncio_mode = 'auto'
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"

filterwarnings = ["ignore::DeprecationWarning:"]

[tool.coverage.run]
branch = true
Expand All @@ -86,3 +90,17 @@ version_toml = ["pyproject.toml:project.version"]

[tool.semantic_release.changelog.default_templates]
changelog_file = "./docs/changelog.md"

[tool.tox]
legacy_tox_ini = """
[tox]
envlist = { default, sqlmodel }

[testenv]
passenv = CI
runner = uv-venv-lock-runner
commands =
pytest --cov fastsqla --cov-report=term-missing --cov-report=xml
extras:
sqlmodel: sqlmodel
"""
90 changes: 31 additions & 59 deletions src/fastsqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
from sqlalchemy.orm import DeclarativeBase
from structlog import get_logger

logger = get_logger(__name__)

try:
from sqlmodel.ext.asyncio.session import AsyncSession

except ImportError:
pass


__all__ = [
"Base",
"Collection",
Expand All @@ -30,7 +39,7 @@
"open_session",
]

SessionFactory = async_sessionmaker(expire_on_commit=False)
SessionFactory = async_sessionmaker(expire_on_commit=False, class_=AsyncSession)

logger = get_logger(__name__)

Expand All @@ -56,6 +65,10 @@ class Hero(Base):

* [ORM Quick Start](https://docs.sqlalchemy.org/en/20/orm/quickstart.html)
* [Declarative Mapping](https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#declarative-mapping)

!!! note

You don't need this if you use [`SQLModel`](http://sqlmodel.tiangolo.com/).
"""

__abstract__ = True
Expand Down Expand Up @@ -142,7 +155,7 @@ async def lifespan(app:FastAPI) -> AsyncGenerator[dict, None]:

@asynccontextmanager
async def open_session() -> AsyncGenerator[AsyncSession, None]:
"""An asynchronous context manager that opens a new `SQLAlchemy` async session.
"""Async context manager that opens a new `SQLAlchemy` or `SQLModel` async session.

To the contrary of the [`Session`][fastsqla.Session] dependency which can only be
used in endpoints, `open_session` can be used anywhere such as in background tasks.
Expand All @@ -152,6 +165,16 @@ async def open_session() -> AsyncGenerator[AsyncSession, None]:
In all cases, it closes the session and returns the associated connection to the
connection pool.


Returns:
When `SQLModel` is not installed, an async generator that yields an
[`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession].

When `SQLModel` is installed, an async generator that yields an
[`SQLModel AsyncSession`](https://github.com./fastapi/sqlmodel/blob/main/sqlmodel/ext/asyncio/session.py#L32)
which inherits from [`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession].


```python
from fastsqla import open_session

Expand Down Expand Up @@ -191,12 +214,12 @@ async def new_session() -> AsyncGenerator[AsyncSession, None]:


Session = Annotated[AsyncSession, Depends(new_session)]
"""A dependency used exclusively in endpoints to get an `SQLAlchemy` session.
"""Dependency used exclusively in endpoints to get an `SQLAlchemy` or `SQLModel` session.

`Session` is a [`FastAPI` dependency](https://fastapi.tiangolo.com/tutorial/dependencies/)
that provides an asynchronous `SQLAlchemy` session.
that provides an asynchronous `SQLAlchemy` session or `SQLModel` one if it's installed.
By defining an argument with type `Session` in an endpoint, `FastAPI` will automatically
inject an `SQLAlchemy` async session into the endpoint.
inject an async session into the endpoint.

At the end of request handling:

Expand Down Expand Up @@ -336,9 +359,9 @@ async def paginate(stmt: Select) -> Page:
Paginate = Annotated[PaginateType[T], Depends(new_pagination())]
"""A dependency used in endpoints to paginate `SQLAlchemy` select queries.

It adds `offset`and `limit` query parameters to the endpoint, which are used to paginate.
The model returned by the endpoint is a `Page` model. It contains a page of data and
metadata:
It adds **`offset`** and **`limit`** query parameters to the endpoint, which are used to
paginate. The model returned by the endpoint is a `Page` model. It contains a page of
data and metadata:

```json
{
Expand All @@ -351,55 +374,4 @@ async def paginate(stmt: Select) -> Page:
}
}
```

-----

Example:
``` py title="example.py" hl_lines="22 23 25"
from fastsqla import Base, Paginate, Page
from pydantic import BaseModel


class Hero(Base):
__tablename__ = "hero"


class Hero(Base):
__tablename__ = "hero"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(unique=True)
secret_identity: Mapped[str]
age: Mapped[int]


class HeroModel(HeroBase):
model_config = ConfigDict(from_attributes=True)
id: int


@app.get("/heros", response_model=Page[HeroModel]) # (1)!
async def list_heros(paginate: Paginate): # (2)!
stmt = select(Hero)
return await paginate(stmt) # (3)!
```

1. The endpoint returns a `Page` model of `HeroModel`.
2. Just define an argument with type `Paginate` to get an async `paginate` function
injected in your endpoint function.
3. Await the `paginate` function with the `SQLAlchemy` select statement to get the
paginated result.

To add filtering, just add whatever query parameters you need to the endpoint:

```python

from fastsqla import Paginate, Page

@app.get("/heros", response_model=Page[HeroModel])
async def list_heros(paginate: Paginate, age:int | None = None):
stmt = select(Hero)
if age:
stmt = stmt.where(Hero.age == age)
return await paginate(stmt)
```
"""
Loading