Skip to content

Type checking improvements #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions openapi_python_client/openapi_parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,16 @@ def _iterate_properties() -> Generator[Property, None, None]:

@staticmethod
def from_dict(d: Dict[str, Dict[str, Any]], /) -> OpenAPI:
""" Create an OpenAPI from dict """
""" Create an OpenAPI from dict
:rtype: object
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this necessary? Does mypy have a problem with -> OpenAPI for some reason?

"""
schemas = Schema.dict(d["components"]["schemas"])
endpoint_collections_by_tag = EndpointCollection.from_dict(d["paths"])
enums = OpenAPI._check_enums(schemas.values(), endpoint_collections_by_tag.values())

return OpenAPI(
title=d["info"]["title"],
description=d["info"]["description"],
description=d["info"].get("description"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to allow description to be empty right? So then we'd need to make OpenAPI.description an Optiona[str] instead of a str to be correct.

version=d["info"]["version"],
endpoint_collections_by_tag=endpoint_collections_by_tag,
schemas=schemas,
Expand Down
9 changes: 6 additions & 3 deletions openapi_python_client/openapi_parser/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ def transform(self) -> str:

def constructor_from_dict(self, dict_name: str) -> str:
""" How to load this property from a dict (used in generated model from_dict function """
return f'{self.reference.class_name}({dict_name}["{self.name}"]) if "{self.name}" in {dict_name} else None'
constructor = f'{self.reference.class_name}({dict_name}["{self.name}"])'
if not self.required:
constructor += f' if "{self.name}" in {dict_name} else None'
return constructor

@staticmethod
def values_from_list(l: List[str], /) -> Dict[str, str]:
Expand Down Expand Up @@ -208,15 +211,15 @@ def transform(self) -> str:
class DictProperty(Property):
""" Property that is a general Dict """

_type_string: ClassVar[str] = "Dict"
_type_string: ClassVar[str] = "Dict[Any, Any]"


_openapi_types_to_python_type_strings = {
"string": "str",
"number": "float",
"integer": "int",
"boolean": "bool",
"object": "Dict",
"object": "Dict[Any, Any]",
}


Expand Down
4 changes: 2 additions & 2 deletions openapi_python_client/openapi_parser/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def return_string(self) -> str:

def constructor(self) -> str:
""" How the return value of this response should be constructed """
return f"[{self.reference.class_name}.from_dict(item) for item in response.json()]"
return f"[{self.reference.class_name}.from_dict(item) for item in cast(List[Dict[str, Any]], response.json())]"


@dataclass
Expand All @@ -48,7 +48,7 @@ def return_string(self) -> str:

def constructor(self) -> str:
""" How the return value of this response should be constructed """
return f"{self.reference.class_name}.from_dict(response.json())"
return f"{self.reference.class_name}.from_dict(cast(Dict[str, Any], response.json()))"


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions openapi_python_client/templates/async_endpoint_module.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Any, cast

import httpx

Expand Down Expand Up @@ -60,8 +60,8 @@ async def {{ endpoint.name }}(
{% endfor %}
{% endif %}

with httpx.AsyncClient() as client:
response = await client.{{ endpoint.method }}(
async with httpx.AsyncClient() as _client:
response = await _client.{{ endpoint.method }}(
url=url,
headers=client.get_headers(),
{% if endpoint.form_body_reference %}
Expand Down
2 changes: 1 addition & 1 deletion openapi_python_client/templates/endpoint_module.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Any, cast

import httpx

Expand Down
6 changes: 3 additions & 3 deletions openapi_python_client/templates/model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast

{% for relative in schema.relative_imports %}
{{ relative }}
Expand All @@ -16,7 +16,7 @@ class {{ schema.reference.class_name }}:
{{ property.to_string() }}
{% endfor %}

def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
return {
{% for property in schema.required_properties %}
"{{ property.name }}": self.{{ property.transform() }},
Expand All @@ -27,7 +27,7 @@ class {{ schema.reference.class_name }}:
}

@staticmethod
def from_dict(d: Dict) -> {{ schema.reference.class_name }}:
def from_dict(d: Dict[str, Any]) -> {{ schema.reference.class_name }}:
{% for property in schema.required_properties + schema.optional_properties %}

{% if property.constructor_template %}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

import httpx

Expand All @@ -19,6 +19,6 @@ def ping_ping_get(
response = httpx.get(url=url, headers=client.get_headers(),)

if response.status_code == 200:
return ABCResponse.from_dict(response.json())
return ABCResponse.from_dict(cast(Dict[str, Any], response.json()))
else:
raise ApiResponseError(response=response)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

import httpx

Expand All @@ -25,8 +25,8 @@ def get_list_tests__get(
response = httpx.get(url=url, headers=client.get_headers(), params=params,)

if response.status_code == 200:
return [AModel.from_dict(item) for item in response.json()]
return [AModel.from_dict(item) for item in cast(List[Dict[str, Any]], response.json())]
if response.status_code == 422:
return HTTPValidationError.from_dict(response.json())
return HTTPValidationError.from_dict(cast(Dict[str, Any], response.json()))
else:
raise ApiResponseError(response=response)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

import httpx

Expand All @@ -16,10 +16,10 @@ async def ping_ping_get(
""" A quick check to see if the system is running """
url = f"{client.base_url}/ping"

with httpx.AsyncClient() as client:
response = await client.get(url=url, headers=client.get_headers(),)
async with httpx.AsyncClient() as _client:
response = await _client.get(url=url, headers=client.get_headers(),)

if response.status_code == 200:
return ABCResponse.from_dict(response.json())
return ABCResponse.from_dict(cast(Dict[str, Any], response.json()))
else:
raise ApiResponseError(response=response)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

import httpx

Expand All @@ -22,12 +22,12 @@ async def get_list_tests__get(
"statuses": statuses,
}

with httpx.AsyncClient() as client:
response = await client.get(url=url, headers=client.get_headers(), params=params,)
async with httpx.AsyncClient() as _client:
response = await _client.get(url=url, headers=client.get_headers(), params=params,)

if response.status_code == 200:
return [AModel.from_dict(item) for item in response.json()]
return [AModel.from_dict(item) for item in cast(List[Dict[str, Any]], response.json())]
if response.status_code == 422:
return HTTPValidationError.from_dict(response.json())
return HTTPValidationError.from_dict(cast(Dict[str, Any], response.json()))
else:
raise ApiResponseError(response=response)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast

from .a_list_of_enums import AListOfEnums
from .an_enum_value import AnEnumValue
Expand All @@ -18,7 +18,7 @@ class AModel:
a_list_of_strings: List[str]
a_list_of_objects: List[OtherModel]

def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
return {
"an_enum_value": self.an_enum_value.value,
"a_list_of_enums": self.a_list_of_enums,
Expand All @@ -27,9 +27,9 @@ def to_dict(self) -> Dict:
}

@staticmethod
def from_dict(d: Dict) -> AModel:
def from_dict(d: Dict[str, Any]) -> AModel:

an_enum_value = AnEnumValue(d["an_enum_value"]) if "an_enum_value" in d else None
an_enum_value = AnEnumValue(d["an_enum_value"])

a_list_of_enums = []
for a_list_of_enums_item in d.get("a_list_of_enums", []):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast


@dataclass
Expand All @@ -11,13 +11,13 @@ class ABCResponse:

success: bool

def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
return {
"success": self.success,
}

@staticmethod
def from_dict(d: Dict) -> ABCResponse:
def from_dict(d: Dict[str, Any]) -> ABCResponse:

success = d["success"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast

from .validation_error import ValidationError

Expand All @@ -13,13 +13,13 @@ class HTTPValidationError:

detail: Optional[List[ValidationError]] = None

def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
return {
"detail": self.detail if self.detail is not None else None,
}

@staticmethod
def from_dict(d: Dict) -> HTTPValidationError:
def from_dict(d: Dict[str, Any]) -> HTTPValidationError:

detail = []
for detail_item in d.get("detail", []):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast


@dataclass
Expand All @@ -11,13 +11,13 @@ class OtherModel:

a_value: str

def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
return {
"a_value": self.a_value,
}

@staticmethod
def from_dict(d: Dict) -> OtherModel:
def from_dict(d: Dict[str, Any]) -> OtherModel:

a_value = d["a_value"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast


@dataclass
Expand All @@ -13,15 +13,15 @@ class ValidationError:
msg: str
type: str

def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
return {
"loc": self.loc,
"msg": self.msg,
"type": self.type,
}

@staticmethod
def from_dict(d: Dict) -> ValidationError:
def from_dict(d: Dict[str, Any]) -> ValidationError:

loc = d.get("loc", [])

Expand Down
5 changes: 5 additions & 0 deletions tests/test_end_to_end/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,9 @@ def test_end_to_end(capsys):
if result.exit_code != 0:
raise result.exception
_compare_directories(gm_path, output_path)

import mypy.api
out, err, status = mypy.api.run([str(output_path), "--strict"])
assert status == 0, f"Hello Type checking client failed: {err}"

shutil.rmtree(output_path)
14 changes: 12 additions & 2 deletions tests/test_openapi_parser/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,19 @@ def test_constructor_from_dict(self, mocker):

assert (
enum_property.constructor_from_dict("my_dict")
== 'MyTestEnum(my_dict["test_enum"]) if "test_enum" in my_dict else None'
== 'MyTestEnum(my_dict["test_enum"])'
)

enum_property = EnumProperty(name="test_enum", required=False,
default=None, values={})

assert (
enum_property.constructor_from_dict("my_dict")
== 'MyTestEnum(my_dict["test_enum"]) if "test_enum" in my_dict else None'
)



def test_values_from_list(self):
from openapi_python_client.openapi_parser.properties import EnumProperty

Expand Down Expand Up @@ -392,7 +402,7 @@ def test_property_from_dict_enum_array(self, mocker):

@pytest.mark.parametrize(
"openapi_type,python_type",
[("string", "str"), ("number", "float"), ("integer", "int"), ("boolean", "bool"), ("object", "Dict"),],
[("string", "str"), ("number", "float"), ("integer", "int"), ("boolean", "bool"), ("object", "Dict[Any, Any]"),],
)
def test_property_from_dict_simple_array(self, mocker, openapi_type, python_type):
name = mocker.MagicMock()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_openapi_parser/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_constructor(self, mocker):

r = ListRefResponse(200, reference=mocker.MagicMock(class_name="SuperCoolClass"))

assert r.constructor() == "[SuperCoolClass.from_dict(item) for item in response.json()]"
assert r.constructor() == "[SuperCoolClass.from_dict(item) for item in cast(List[Dict[str, Any]], response.json())]"


class TestRefResponse:
Expand All @@ -48,7 +48,7 @@ def test_constructor(self, mocker):

r = RefResponse(200, reference=mocker.MagicMock(class_name="SuperCoolClass"))

assert r.constructor() == "SuperCoolClass.from_dict(response.json())"
assert r.constructor() == "SuperCoolClass.from_dict(cast(Dict[str, Any], response.json()))"


class TestStringResponse:
Expand Down