Skip to content

Commit edf2310

Browse files
authored
feat: add support for SQLModel (#18)
1 parent 8e58a8b commit edf2310

File tree

10 files changed

+368
-75
lines changed

10 files changed

+368
-75
lines changed

.github/tox-uv/action.yml

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
name: Setup tox-uv
2+
description: Setup tox-uv tool so tox uses uv to install dependencies
3+
runs:
4+
using: composite
5+
steps:
6+
- name: ⚡️ setup uv
7+
uses: ./.github/uv
8+
- name: ⚙️ install tox-uv
9+
shell: bash
10+
run: uv tool install tox --with tox-uv

.github/workflows/ci.yml

+9-4
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,21 @@ on:
66
jobs:
77

88
Tests:
9+
strategy:
10+
matrix:
11+
tox_env: [default, sqlmodel]
912
runs-on: ubuntu-latest
1013
steps:
1114
- name: 📥 checkout
1215
uses: actions/checkout@v4
13-
- name: 🔧 setup uv
14-
uses: ./.github/uv
15-
- name: 🧪 pytest
16-
run: uv run pytest --cov fastsqla --cov-report=term-missing --cov-report=xml
16+
- name: 🔧 setup tox-uv
17+
uses: ./.github/tox-uv
18+
- name: 🧪 tox -e ${{ matrix.tox_env }}
19+
run: uv run tox -e ${{ matrix.tox_env }}
1720
- name: "🐔 codecov: upload test coverage"
1821
uses: codecov/[email protected]
22+
with:
23+
flags: ${{ matrix.tox_env }}
1924
env:
2025
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
2126

README.md

+34-6
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ _Async SQLAlchemy 2.0+ for FastAPI — boilerplate, pagination, and seamless ses
1515

1616
-----------------------------------------------------------------------------------------
1717

18-
`FastSQLA` is an [`SQLAlchemy 2.0+`](https://docs.sqlalchemy.org/en/20/) extension for
19-
[`FastAPI`](https://fastapi.tiangolo.com/).
18+
`FastSQLA` is an async [`SQLAlchemy 2.0+`](https://docs.sqlalchemy.org/en/20/)
19+
extension for [`FastAPI`](https://fastapi.tiangolo.com/) with built-in pagination,
20+
[`SQLModel`](http://sqlmodel.tiangolo.com/) support and more.
21+
2022
It streamlines the configuration and asynchronous connection to relational databases by
2123
providing boilerplate and intuitive helpers. Additionally, it offers built-in
2224
customizable pagination and automatically manages the `SQLAlchemy` session lifecycle
@@ -74,13 +76,13 @@ following [`SQLAlchemy`'s best practices](https://docs.sqlalchemy.org/en/20/orm/
7476
async def get_heros(paginate:Paginate):
7577
return await paginate(select(Hero))
7678
```
77-
79+
7880
<center>
79-
81+
8082
👇 `/heros?offset=10&limit=10` 👇
81-
83+
8284
</center>
83-
85+
8486
```json
8587
{
8688
"data": [
@@ -119,6 +121,32 @@ following [`SQLAlchemy`'s best practices](https://docs.sqlalchemy.org/en/20/orm/
119121
* Session lifecycle management: session is commited on request success or rollback on
120122
failure.
121123

124+
* [`SQLModel`](http://sqlmodel.tiangolo.com/) support:
125+
```python
126+
...
127+
from fastsqla import Item, Page, Paginate, Session
128+
from sqlmodel import Field, SQLModel
129+
...
130+
131+
class Hero(SQLModel, table=True):
132+
id: int | None = Field(default=None, primary_key=True)
133+
name: str
134+
secret_identity: str
135+
age: int
136+
137+
138+
@app.get("/heroes", response_model=Page[Hero])
139+
async def get_heroes(paginate: Paginate):
140+
return await paginate(select(Hero))
141+
142+
143+
@app.get("/heroes/{hero_id}", response_model=Item[Hero])
144+
async def get_hero(session: Session, hero_id: int):
145+
hero = await session.get(Hero, hero_id)
146+
if hero is None:
147+
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
148+
return {"data": hero}
149+
```
122150

123151
## Installing
124152

docs/orm.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
::: fastsqla.Base
66
options:
77
heading_level: false
8-
show_source: false
9-
show_bases: false
8+
show_source: true
9+
show_bases: true

docs/pagination.md

+70
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,73 @@
66
options:
77
heading_level: false
88
show_source: false
9+
10+
### `SQLAlchemy` example
11+
12+
``` py title="example.py" hl_lines="25 26 27"
13+
from fastapi import FastAPI
14+
from fastsqla import Base, Paginate, Page, lifespan
15+
from pydantic import BaseModel
16+
from sqlalchemy import select
17+
from sqlalchemy.orm import Mapped, mapped_column
18+
19+
app = FastAPI(lifespan=lifespan)
20+
21+
class Hero(Base):
22+
__tablename__ = "hero"
23+
id: Mapped[int] = mapped_column(primary_key=True)
24+
name: Mapped[str] = mapped_column(unique=True)
25+
secret_identity: Mapped[str]
26+
age: Mapped[int]
27+
28+
29+
class HeroModel(HeroBase):
30+
model_config = ConfigDict(from_attributes=True)
31+
id: int
32+
name: str
33+
secret_identity: str
34+
age: int
35+
36+
37+
@app.get("/heros", response_model=Page[HeroModel]) # (1)!
38+
async def list_heros(paginate: Paginate): # (2)!
39+
return await paginate(select(Hero)) # (3)!
40+
```
41+
42+
1. The endpoint returns a `Page` model of `HeroModel`.
43+
2. Just define an argument with type `Paginate` to get an async `paginate` function
44+
injected in your endpoint function.
45+
3. Await the `paginate` function with the `SQLAlchemy` select statement to get the
46+
paginated result.
47+
48+
To add filtering, just add whatever query parameters you need to the endpoint:
49+
50+
```python
51+
@app.get("/heros", response_model=Page[HeroModel])
52+
async def list_heros(paginate: Paginate, age:int | None = None):
53+
stmt = select(Hero)
54+
if age:
55+
stmt = stmt.where(Hero.age == age)
56+
return await paginate(stmt)
57+
```
58+
59+
### `SQLModel` example
60+
61+
```python
62+
from fastapi import FastAPI
63+
from fastsqla import Page, Paginate, Session
64+
from sqlmodel import Field, SQLModel
65+
from sqlalchemy import select
66+
67+
68+
class Hero(SQLModel, table=True):
69+
id: int | None = Field(default=None, primary_key=True)
70+
name: str
71+
secret_identity: str
72+
age: int
73+
74+
75+
@app.get("/heroes", response_model=Page[Hero])
76+
async def get_heroes(paginate: Paginate):
77+
return await paginate(select(Hero))
78+
```

pyproject.toml

+19-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ docs = [
4848
"mkdocs-material>=9.5.50",
4949
"mkdocstrings[python]>=0.27.0",
5050
]
51+
sqlmodel = ["sqlmodel>=0.0.22"]
5152

5253
[tool.uv]
5354
package = true
@@ -71,7 +72,10 @@ dev-dependencies = [
7172
pytest-watch = { git = "https://github.com./styleseat/pytest-watch", rev = "0342193" }
7273

7374
[tool.pytest.ini_options]
74-
asyncio_mode = 'auto'
75+
asyncio_mode = "auto"
76+
asyncio_default_fixture_loop_scope = "function"
77+
78+
filterwarnings = ["ignore::DeprecationWarning:"]
7579

7680
[tool.coverage.run]
7781
branch = true
@@ -86,3 +90,17 @@ version_toml = ["pyproject.toml:project.version"]
8690

8791
[tool.semantic_release.changelog.default_templates]
8892
changelog_file = "./docs/changelog.md"
93+
94+
[tool.tox]
95+
legacy_tox_ini = """
96+
[tox]
97+
envlist = { default, sqlmodel }
98+
99+
[testenv]
100+
passenv = CI
101+
runner = uv-venv-lock-runner
102+
commands =
103+
pytest --cov fastsqla --cov-report=term-missing --cov-report=xml
104+
extras:
105+
sqlmodel: sqlmodel
106+
"""

src/fastsqla.py

+31-59
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717
from sqlalchemy.orm import DeclarativeBase
1818
from structlog import get_logger
1919

20+
logger = get_logger(__name__)
21+
22+
try:
23+
from sqlmodel.ext.asyncio.session import AsyncSession
24+
25+
except ImportError:
26+
pass
27+
28+
2029
__all__ = [
2130
"Base",
2231
"Collection",
@@ -30,7 +39,7 @@
3039
"open_session",
3140
]
3241

33-
SessionFactory = async_sessionmaker(expire_on_commit=False)
42+
SessionFactory = async_sessionmaker(expire_on_commit=False, class_=AsyncSession)
3443

3544
logger = get_logger(__name__)
3645

@@ -56,6 +65,10 @@ class Hero(Base):
5665
5766
* [ORM Quick Start](https://docs.sqlalchemy.org/en/20/orm/quickstart.html)
5867
* [Declarative Mapping](https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#declarative-mapping)
68+
69+
!!! note
70+
71+
You don't need this if you use [`SQLModel`](http://sqlmodel.tiangolo.com/).
5972
"""
6073

6174
__abstract__ = True
@@ -142,7 +155,7 @@ async def lifespan(app:FastAPI) -> AsyncGenerator[dict, None]:
142155

143156
@asynccontextmanager
144157
async def open_session() -> AsyncGenerator[AsyncSession, None]:
145-
"""An asynchronous context manager that opens a new `SQLAlchemy` async session.
158+
"""Async context manager that opens a new `SQLAlchemy` or `SQLModel` async session.
146159
147160
To the contrary of the [`Session`][fastsqla.Session] dependency which can only be
148161
used in endpoints, `open_session` can be used anywhere such as in background tasks.
@@ -152,6 +165,16 @@ async def open_session() -> AsyncGenerator[AsyncSession, None]:
152165
In all cases, it closes the session and returns the associated connection to the
153166
connection pool.
154167
168+
169+
Returns:
170+
When `SQLModel` is not installed, an async generator that yields an
171+
[`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession].
172+
173+
When `SQLModel` is installed, an async generator that yields an
174+
[`SQLModel AsyncSession`](https://github.com./fastapi/sqlmodel/blob/main/sqlmodel/ext/asyncio/session.py#L32)
175+
which inherits from [`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession].
176+
177+
155178
```python
156179
from fastsqla import open_session
157180
@@ -191,12 +214,12 @@ async def new_session() -> AsyncGenerator[AsyncSession, None]:
191214

192215

193216
Session = Annotated[AsyncSession, Depends(new_session)]
194-
"""A dependency used exclusively in endpoints to get an `SQLAlchemy` session.
217+
"""Dependency used exclusively in endpoints to get an `SQLAlchemy` or `SQLModel` session.
195218
196219
`Session` is a [`FastAPI` dependency](https://fastapi.tiangolo.com/tutorial/dependencies/)
197-
that provides an asynchronous `SQLAlchemy` session.
220+
that provides an asynchronous `SQLAlchemy` session or `SQLModel` one if it's installed.
198221
By defining an argument with type `Session` in an endpoint, `FastAPI` will automatically
199-
inject an `SQLAlchemy` async session into the endpoint.
222+
inject an async session into the endpoint.
200223
201224
At the end of request handling:
202225
@@ -336,9 +359,9 @@ async def paginate(stmt: Select) -> Page:
336359
Paginate = Annotated[PaginateType[T], Depends(new_pagination())]
337360
"""A dependency used in endpoints to paginate `SQLAlchemy` select queries.
338361
339-
It adds `offset`and `limit` query parameters to the endpoint, which are used to paginate.
340-
The model returned by the endpoint is a `Page` model. It contains a page of data and
341-
metadata:
362+
It adds **`offset`** and **`limit`** query parameters to the endpoint, which are used to
363+
paginate. The model returned by the endpoint is a `Page` model. It contains a page of
364+
data and metadata:
342365
343366
```json
344367
{
@@ -351,55 +374,4 @@ async def paginate(stmt: Select) -> Page:
351374
}
352375
}
353376
```
354-
355-
-----
356-
357-
Example:
358-
``` py title="example.py" hl_lines="22 23 25"
359-
from fastsqla import Base, Paginate, Page
360-
from pydantic import BaseModel
361-
362-
363-
class Hero(Base):
364-
__tablename__ = "hero"
365-
366-
367-
class Hero(Base):
368-
__tablename__ = "hero"
369-
id: Mapped[int] = mapped_column(primary_key=True)
370-
name: Mapped[str] = mapped_column(unique=True)
371-
secret_identity: Mapped[str]
372-
age: Mapped[int]
373-
374-
375-
class HeroModel(HeroBase):
376-
model_config = ConfigDict(from_attributes=True)
377-
id: int
378-
379-
380-
@app.get("/heros", response_model=Page[HeroModel]) # (1)!
381-
async def list_heros(paginate: Paginate): # (2)!
382-
stmt = select(Hero)
383-
return await paginate(stmt) # (3)!
384-
```
385-
386-
1. The endpoint returns a `Page` model of `HeroModel`.
387-
2. Just define an argument with type `Paginate` to get an async `paginate` function
388-
injected in your endpoint function.
389-
3. Await the `paginate` function with the `SQLAlchemy` select statement to get the
390-
paginated result.
391-
392-
To add filtering, just add whatever query parameters you need to the endpoint:
393-
394-
```python
395-
396-
from fastsqla import Paginate, Page
397-
398-
@app.get("/heros", response_model=Page[HeroModel])
399-
async def list_heros(paginate: Paginate, age:int | None = None):
400-
stmt = select(Hero)
401-
if age:
402-
stmt = stmt.where(Hero.age == age)
403-
return await paginate(stmt)
404-
```
405377
"""

0 commit comments

Comments
 (0)