Skip to content

Commit a3842bc

Browse files
committed
Add async_api to generated clients
Closes #16
1 parent f09900b commit a3842bc

File tree

7 files changed

+153
-8
lines changed

7 files changed

+153
-8
lines changed

openapi_python_client/__init__.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,21 @@ def _build_api(self) -> None:
155155
api_dir = self.package_dir / "api"
156156
api_dir.mkdir()
157157
api_init = api_dir / "__init__.py"
158-
api_init.write_text('""" Contains all methods for accessing the API """')
158+
api_init.write_text('""" Contains synchronous methods for accessing the API """')
159159

160-
api_errors = api_dir / "errors.py"
160+
async_api_dir = self.package_dir / "async_api"
161+
async_api_dir.mkdir()
162+
async_api_init = async_api_dir / "__init__.py"
163+
async_api_init.write_text('""" Contains async methods for accessing the API """')
164+
165+
api_errors = self.package_dir / "errors.py"
161166
errors_template = self.env.get_template("errors.pyi")
162167
api_errors.write_text(errors_template.render())
163168

164169
endpoint_template = self.env.get_template("endpoint_module.pyi")
170+
async_endpoint_template = self.env.get_template("async_endpoint_module.pyi")
165171
for tag, collection in self.openapi.endpoint_collections_by_tag.items():
166172
module_path = api_dir / f"{tag}.py"
167173
module_path.write_text(endpoint_template.render(collection=collection))
174+
async_module_path = async_api_dir / f"{tag}.py"
175+
async_module_path.write_text(async_endpoint_template.render(collection=collection))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from dataclasses import asdict
2+
from typing import Dict, List, Optional, Union
3+
4+
import httpx
5+
6+
from ..client import AuthenticatedClient, Client
7+
from .errors import ApiResponseError
8+
9+
{% for relative in collection.relative_imports %}
10+
{{ relative }}
11+
{% endfor %}
12+
{% for endpoint in collection.endpoints %}
13+
14+
15+
async def {{ endpoint.name }}(
16+
*,
17+
{# Proper client based on whether or not the endpoint requires authentication #}
18+
{% if endpoint.requires_security %}
19+
client: AuthenticatedClient,
20+
{% else %}
21+
client: Client,
22+
{% endif %}
23+
{# path parameters #}
24+
{% for parameter in endpoint.path_parameters %}
25+
{{ parameter.to_string() }},
26+
{% endfor %}
27+
{# Form data if any #}
28+
{% if endpoint.form_body_reference %}
29+
form_data: {{ endpoint.form_body_reference.class_name }},
30+
{% endif %}
31+
{# JSON body if any #}
32+
{% if endpoint.json_body %}
33+
json_body: {{ endpoint.json_body.get_type_string() }},
34+
{% endif %}
35+
{# query parameters #}
36+
{% for parameter in endpoint.query_parameters %}
37+
{{ parameter.to_string() }},
38+
{% endfor %}
39+
) -> Union[
40+
{% for response in endpoint.responses %}
41+
{{ response.return_string() }},
42+
{% endfor %}
43+
]:
44+
""" {{ endpoint.description }} """
45+
url = f"{client.base_url}{{ endpoint.path }}"
46+
47+
{% if endpoint.query_parameters %}
48+
params = {
49+
{% for parameter in endpoint.query_parameters %}
50+
{% if parameter.required %}
51+
"{{ parameter.name }}": {{ parameter.transform() }},
52+
{% endif %}
53+
{% endfor %}
54+
}
55+
{% for parameter in endpoint.query_parameters %}
56+
{% if not parameter.required %}
57+
if {{ parameter.name }} is not None:
58+
params["{{ parameter.name }}"] = {{ parameter.transform() }}
59+
{% endif %}
60+
{% endfor %}
61+
{% endif %}
62+
63+
with httpx.AsyncClient() as client:
64+
response = await client.{{ endpoint.method }}(
65+
url=url,
66+
headers=client.get_headers(),
67+
{% if endpoint.form_body_reference %}
68+
data=asdict(form_data),
69+
{% endif %}
70+
{% if endpoint.json_body %}
71+
json={{ endpoint.json_body.transform() }},
72+
{% endif %}
73+
{% if endpoint.query_parameters %}
74+
params=params,
75+
{% endif %}
76+
)
77+
78+
{% for response in endpoint.responses %}
79+
if response.status_code == {{ response.status_code }}:
80+
return {{ response.constructor() }}
81+
{% endfor %}
82+
else:
83+
raise ApiResponseError(response=response)
84+
{% endfor %}

tests/test___init__.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -343,55 +343,82 @@ def test__build_api(self, mocker):
343343
openapi.endpoint_collections_by_tag = {tag_1: collection_1, tag_2: collection_2}
344344
project = _Project(openapi=openapi)
345345
project.package_dir = mocker.MagicMock()
346+
api_errors = mocker.MagicMock(autospec=pathlib.Path)
346347
client_path = mocker.MagicMock()
347348
api_init = mocker.MagicMock(autospec=pathlib.Path)
348-
api_errors = mocker.MagicMock(autospec=pathlib.Path)
349349
collection_1_path = mocker.MagicMock(autospec=pathlib.Path)
350350
collection_2_path = mocker.MagicMock(autospec=pathlib.Path)
351+
async_api_init = mocker.MagicMock(autospec=pathlib.Path)
352+
async_collection_1_path = mocker.MagicMock(autospec=pathlib.Path)
353+
async_collection_2_path = mocker.MagicMock(autospec=pathlib.Path)
351354
api_paths = {
352355
"__init__.py": api_init,
353-
"errors.py": api_errors,
354356
f"{tag_1}.py": collection_1_path,
355357
f"{tag_2}.py": collection_2_path,
356358
}
359+
async_api_paths = {
360+
"__init__.py": async_api_init,
361+
f"{tag_1}.py": async_collection_1_path,
362+
f"{tag_2}.py": async_collection_2_path,
363+
}
357364
api_dir = mocker.MagicMock(autospec=pathlib.Path)
358365
api_dir.__truediv__.side_effect = lambda x: api_paths[x]
366+
async_api_dir = mocker.MagicMock(autospec=pathlib.Path)
367+
async_api_dir.__truediv__.side_effect = lambda x: async_api_paths[x]
368+
359369
package_paths = {
360370
"client.py": client_path,
361371
"api": api_dir,
372+
"async_api": async_api_dir,
373+
"errors.py": api_errors,
362374
}
363375
project.package_dir.__truediv__.side_effect = lambda x: package_paths[x]
364376
client_template = mocker.MagicMock(autospec=Template)
365377
errors_template = mocker.MagicMock(autospec=Template)
366378
endpoint_template = mocker.MagicMock(autospec=Template)
379+
async_endpoint_template = mocker.MagicMock(autospec=Template)
367380
templates = {
368381
"client.pyi": client_template,
369382
"errors.pyi": errors_template,
370383
"endpoint_module.pyi": endpoint_template,
384+
"async_endpoint_module.pyi": async_endpoint_template,
371385
}
372386
mocker.patch.object(project.env, "get_template", autospec=True, side_effect=lambda x: templates[x])
373387
endpoint_renders = {
374388
collection_1: mocker.MagicMock(),
375389
collection_2: mocker.MagicMock(),
376390
}
377391
endpoint_template.render.side_effect = lambda collection: endpoint_renders[collection]
392+
async_endpoint_renders = {
393+
collection_1: mocker.MagicMock(),
394+
collection_2: mocker.MagicMock(),
395+
}
396+
async_endpoint_template.render.side_effect = lambda collection: async_endpoint_renders[collection]
378397

379398
project._build_api()
380399

381400
project.package_dir.__truediv__.assert_has_calls([mocker.call(key) for key in package_paths])
382401
project.env.get_template.assert_has_calls([mocker.call(key) for key in templates])
383402
client_template.render.assert_called_once()
384403
client_path.write_text.assert_called_once_with(client_template.render())
385-
api_dir.mkdir.assert_called_once()
386-
api_dir.__truediv__.assert_has_calls([mocker.call(key) for key in api_paths])
387-
api_init.write_text.assert_called_once_with('""" Contains all methods for accessing the API """')
388404
errors_template.render.assert_called_once()
389405
api_errors.write_text.assert_called_once_with(errors_template.render())
406+
api_dir.mkdir.assert_called_once()
407+
api_dir.__truediv__.assert_has_calls([mocker.call(key) for key in api_paths])
408+
api_init.write_text.assert_called_once_with('""" Contains synchronous methods for accessing the API """')
390409
endpoint_template.render.assert_has_calls(
391410
[mocker.call(collection=collection_1), mocker.call(collection=collection_2)]
392411
)
393412
collection_1_path.write_text.assert_called_once_with(endpoint_renders[collection_1])
394413
collection_2_path.write_text.assert_called_once_with(endpoint_renders[collection_2])
414+
async_api_dir.mkdir.assert_called_once()
415+
async_api_dir.__truediv__.assert_has_calls([mocker.call(key) for key in async_api_paths])
416+
async_api_init.write_text.assert_called_once_with('""" Contains async methods for accessing the API """')
417+
async_endpoint_template.render.assert_has_calls(
418+
[mocker.call(collection=collection_1), mocker.call(collection=collection_2)]
419+
)
420+
async_collection_1_path.write_text.assert_called_once_with(async_endpoint_renders[collection_1])
421+
async_collection_2_path.write_text.assert_called_once_with(async_endpoint_renders[collection_2])
395422

396423

397424
def test__reformat(mocker):
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
""" Contains all methods for accessing the API """
1+
""" Contains synchronous methods for accessing the API """
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
""" Contains async methods for accessing the API """
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from dataclasses import asdict
2+
from typing import Dict, List, Optional, Union
3+
4+
import httpx
5+
6+
from ..client import AuthenticatedClient, Client
7+
from ..models.ping_response import PingResponse
8+
from .errors import ApiResponseError
9+
10+
11+
async def ping_ping_get(
12+
*, client: Client,
13+
) -> Union[
14+
PingResponse,
15+
]:
16+
""" A quick check to see if the system is running """
17+
url = f"{client.base_url}/ping"
18+
19+
with httpx.AsyncClient() as client:
20+
response = await client.get(url=url, headers=client.get_headers(),)
21+
22+
if response.status_code == 200:
23+
return PingResponse.from_dict(response.json())
24+
else:
25+
raise ApiResponseError(response=response)

0 commit comments

Comments
 (0)