17
17
from sqlalchemy .orm import DeclarativeBase
18
18
from structlog import get_logger
19
19
20
+ logger = get_logger (__name__ )
21
+
22
+ try :
23
+ from sqlmodel .ext .asyncio .session import AsyncSession
24
+
25
+ except ImportError :
26
+ pass
27
+
28
+
20
29
__all__ = [
21
30
"Base" ,
22
31
"Collection" ,
30
39
"open_session" ,
31
40
]
32
41
33
- SessionFactory = async_sessionmaker (expire_on_commit = False )
42
+ SessionFactory = async_sessionmaker (expire_on_commit = False , class_ = AsyncSession )
34
43
35
44
logger = get_logger (__name__ )
36
45
@@ -56,6 +65,10 @@ class Hero(Base):
56
65
57
66
* [ORM Quick Start](https://docs.sqlalchemy.org/en/20/orm/quickstart.html)
58
67
* [Declarative Mapping](https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#declarative-mapping)
68
+
69
+ !!! note
70
+
71
+ You don't need this if you use [`SQLModel`](http://sqlmodel.tiangolo.com/).
59
72
"""
60
73
61
74
__abstract__ = True
@@ -142,7 +155,7 @@ async def lifespan(app:FastAPI) -> AsyncGenerator[dict, None]:
142
155
143
156
@asynccontextmanager
144
157
async def open_session () -> AsyncGenerator [AsyncSession , None ]:
145
- """An asynchronous context manager that opens a new `SQLAlchemy` async session.
158
+ """Async context manager that opens a new `SQLAlchemy` or `SQLModel ` async session.
146
159
147
160
To the contrary of the [`Session`][fastsqla.Session] dependency which can only be
148
161
used in endpoints, `open_session` can be used anywhere such as in background tasks.
@@ -152,6 +165,16 @@ async def open_session() -> AsyncGenerator[AsyncSession, None]:
152
165
In all cases, it closes the session and returns the associated connection to the
153
166
connection pool.
154
167
168
+
169
+ Returns:
170
+ When `SQLModel` is not installed, an async generator that yields an
171
+ [`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession].
172
+
173
+ When `SQLModel` is installed, an async generator that yields an
174
+ [`SQLModel AsyncSession`](https://github.com./fastapi/sqlmodel/blob/main/sqlmodel/ext/asyncio/session.py#L32)
175
+ which inherits from [`SQLAlchemy AsyncSession`][sqlalchemy.ext.asyncio.AsyncSession].
176
+
177
+
155
178
```python
156
179
from fastsqla import open_session
157
180
@@ -191,12 +214,12 @@ async def new_session() -> AsyncGenerator[AsyncSession, None]:
191
214
192
215
193
216
Session = Annotated [AsyncSession , Depends (new_session )]
194
- """A dependency used exclusively in endpoints to get an `SQLAlchemy` session.
217
+ """Dependency used exclusively in endpoints to get an `SQLAlchemy` or `SQLModel ` session.
195
218
196
219
`Session` is a [`FastAPI` dependency](https://fastapi.tiangolo.com/tutorial/dependencies/)
197
- that provides an asynchronous `SQLAlchemy` session.
220
+ that provides an asynchronous `SQLAlchemy` session or `SQLModel` one if it's installed .
198
221
By defining an argument with type `Session` in an endpoint, `FastAPI` will automatically
199
- inject an `SQLAlchemy` async session into the endpoint.
222
+ inject an async session into the endpoint.
200
223
201
224
At the end of request handling:
202
225
@@ -336,9 +359,9 @@ async def paginate(stmt: Select) -> Page:
336
359
Paginate = Annotated [PaginateType [T ], Depends (new_pagination ())]
337
360
"""A dependency used in endpoints to paginate `SQLAlchemy` select queries.
338
361
339
- It adds `offset`and `limit` query parameters to the endpoint, which are used to paginate.
340
- The model returned by the endpoint is a `Page` model. It contains a page of data and
341
- metadata:
362
+ It adds ** `offset`** and ** `limit`** query parameters to the endpoint, which are used to
363
+ paginate. The model returned by the endpoint is a `Page` model. It contains a page of
364
+ data and metadata:
342
365
343
366
```json
344
367
{
@@ -351,55 +374,4 @@ async def paginate(stmt: Select) -> Page:
351
374
}
352
375
}
353
376
```
354
-
355
- -----
356
-
357
- Example:
358
- ``` py title="example.py" hl_lines="22 23 25"
359
- from fastsqla import Base, Paginate, Page
360
- from pydantic import BaseModel
361
-
362
-
363
- class Hero(Base):
364
- __tablename__ = "hero"
365
-
366
-
367
- class Hero(Base):
368
- __tablename__ = "hero"
369
- id: Mapped[int] = mapped_column(primary_key=True)
370
- name: Mapped[str] = mapped_column(unique=True)
371
- secret_identity: Mapped[str]
372
- age: Mapped[int]
373
-
374
-
375
- class HeroModel(HeroBase):
376
- model_config = ConfigDict(from_attributes=True)
377
- id: int
378
-
379
-
380
- @app.get("/heros", response_model=Page[HeroModel]) # (1)!
381
- async def list_heros(paginate: Paginate): # (2)!
382
- stmt = select(Hero)
383
- return await paginate(stmt) # (3)!
384
- ```
385
-
386
- 1. The endpoint returns a `Page` model of `HeroModel`.
387
- 2. Just define an argument with type `Paginate` to get an async `paginate` function
388
- injected in your endpoint function.
389
- 3. Await the `paginate` function with the `SQLAlchemy` select statement to get the
390
- paginated result.
391
-
392
- To add filtering, just add whatever query parameters you need to the endpoint:
393
-
394
- ```python
395
-
396
- from fastsqla import Paginate, Page
397
-
398
- @app.get("/heros", response_model=Page[HeroModel])
399
- async def list_heros(paginate: Paginate, age:int | None = None):
400
- stmt = select(Hero)
401
- if age:
402
- stmt = stmt.where(Hero.age == age)
403
- return await paginate(stmt)
404
- ```
405
377
"""
0 commit comments