diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8a0e89dd..df73dd76 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -47,7 +47,7 @@ jobs: steps: - uses: "actions/checkout@v3" - + - uses: "actions/setup-python@v4" with: cache: "pip" @@ -71,12 +71,15 @@ jobs: export TOTAL=$(python -c "import json;print(json.load(open('coverage.json'))['totals']['percent_covered_display'])") echo "total=$TOTAL" >> $GITHUB_ENV + # Report again and fail if under the threshold. + python -Im coverage report --fail-under=97 + - name: "Upload HTML report." uses: "actions/upload-artifact@v3" with: name: "html-report" path: "htmlcov" - + - name: "Make badge" if: github.ref == 'refs/heads/main' uses: "schneegans/dynamic-badges-action@v1.4.0" diff --git a/HISTORY.md b/HISTORY.md index b772db02..ebcf92ba 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,10 @@ # History +## 24.1.0 (UNRELEASED) + +- More robust support for `Annotated` and `NotRequired` in TypedDicts. + ([#450](https://github.com/python-attrs/cattrs/pull/450)) + ## 23.2.1 (2023-11-18) - Fix unnecessary `typing_extensions` import on Python 3.11. diff --git a/src/cattrs/_compat.py b/src/cattrs/_compat.py index 428734d7..4d01dd41 100644 --- a/src/cattrs/_compat.py +++ b/src/cattrs/_compat.py @@ -6,7 +6,7 @@ from dataclasses import fields as dataclass_fields from dataclasses import is_dataclass from typing import AbstractSet as TypingAbstractSet -from typing import Any, Deque, Dict, Final, FrozenSet, List +from typing import Any, Deque, Dict, Final, FrozenSet, List, Literal from typing import Mapping as TypingMapping from typing import MutableMapping as TypingMutableMapping from typing import MutableSequence as TypingMutableSequence @@ -243,6 +243,9 @@ def get_newtype_base(typ: Any) -> Optional[type]: return None def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]": + if is_annotated(type): + # Handle `Annotated[NotRequired[int]]` + type = get_args(type)[0] if get_origin(type) in (NotRequired, Required): return get_args(type)[0] return NOTHING @@ -438,8 +441,6 @@ def is_counter(type): or getattr(type, "__origin__", None) is ColCounter ) - from typing import Literal - def is_literal(type) -> bool: return type.__class__ is _GenericAlias and type.__origin__ is Literal @@ -453,6 +454,10 @@ def copy_with(type, args): return type.copy_with(args) def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]": + if is_annotated(type): + # Handle `Annotated[NotRequired[int]]` + type = get_origin(type) + if get_origin(type) in (NotRequired, Required): return get_args(type)[0] return NOTHING diff --git a/src/cattrs/converters.py b/src/cattrs/converters.py index 3ba1ecad..9b2b99cd 100644 --- a/src/cattrs/converters.py +++ b/src/cattrs/converters.py @@ -55,7 +55,13 @@ is_union_type, ) from .disambiguators import create_default_dis_func, is_supported_union -from .dispatch import HookFactory, MultiStrategyDispatch, StructureHook, UnstructureHook +from .dispatch import ( + HookFactory, + MultiStrategyDispatch, + StructureHook, + UnstructuredValue, + UnstructureHook, +) from .errors import ( IterableValidationError, IterableValidationNote, @@ -327,7 +333,7 @@ def register_structure_hook_factory( """ self._structure_func.register_func_list([(predicate, factory, True)]) - def structure(self, obj: Any, cl: Type[T]) -> T: + def structure(self, obj: UnstructuredValue, cl: Type[T]) -> T: """Convert unstructured Python data structures to structured data.""" return self._structure_func.dispatch(cl)(obj, cl) diff --git a/src/cattrs/gen/typeddicts.py b/src/cattrs/gen/typeddicts.py index ed02249d..023d625a 100644 --- a/src/cattrs/gen/typeddicts.py +++ b/src/cattrs/gen/typeddicts.py @@ -125,19 +125,20 @@ def make_dict_unstructure_fn( break handler = None t = a.type - nrb = get_notrequired_base(t) - if nrb is not NOTHING: - t = nrb if isinstance(t, TypeVar): if t.__name__ in mapping: t = mapping[t.__name__] else: + # Unbound typevars use late binding. handler = converter.unstructure elif is_generic(t) and not is_bare(t) and not is_annotated(t): t = deep_copy_with(t, mapping) if handler is None: + nrb = get_notrequired_base(t) + if nrb is not NOTHING: + t = nrb try: handler = converter._unstructure_func.dispatch(t) except RecursionError: @@ -171,9 +172,6 @@ def make_dict_unstructure_fn( handler = override.unstruct_hook else: t = a.type - nrb = get_notrequired_base(t) - if nrb is not NOTHING: - t = nrb if isinstance(t, TypeVar): if t.__name__ in mapping: @@ -184,6 +182,9 @@ def make_dict_unstructure_fn( t = deep_copy_with(t, mapping) if handler is None: + nrb = get_notrequired_base(t) + if nrb is not NOTHING: + t = nrb try: handler = converter._unstructure_func.dispatch(t) except RecursionError: @@ -282,9 +283,6 @@ def make_dict_structure_fn( mapping = generate_mapping(base, mapping) break - if isinstance(cl, TypeVar): - cl = mapping.get(cl.__name__, cl) - cl_name = cl.__name__ fn_name = "structure_" + cl_name @@ -337,6 +335,12 @@ def make_dict_structure_fn( if override.omit: continue t = a.type + + if isinstance(t, TypeVar): + t = mapping.get(t.__name__, t) + elif is_generic(t) and not is_bare(t) and not is_annotated(t): + t = deep_copy_with(t, mapping) + nrb = get_notrequired_base(t) if nrb is not NOTHING: t = nrb @@ -370,16 +374,11 @@ def make_dict_structure_fn( tn = f"__c_type_{ix}" internal_arg_parts[tn] = t - if handler: - if handler == converter._structure_call: - internal_arg_parts[struct_handler_name] = t - lines.append(f"{i}res['{an}'] = {struct_handler_name}(o['{kn}'])") - else: - lines.append( - f"{i}res['{an}'] = {struct_handler_name}(o['{kn}'], {tn})" - ) + if handler == converter._structure_call: + internal_arg_parts[struct_handler_name] = t + lines.append(f"{i}res['{an}'] = {struct_handler_name}(o['{kn}'])") else: - lines.append(f"{i}res['{an}'] = o['{kn}']") + lines.append(f"{i}res['{an}'] = {struct_handler_name}(o['{kn}'], {tn})") if override.rename is not None: lines.append(f"{i}del res['{kn}']") i = i[:-2] @@ -415,42 +414,38 @@ def make_dict_structure_fn( continue t = a.type - nrb = get_notrequired_base(t) - if nrb is not NOTHING: - t = nrb if isinstance(t, TypeVar): t = mapping.get(t.__name__, t) elif is_generic(t) and not is_bare(t) and not is_annotated(t): t = deep_copy_with(t, mapping) - # For each attribute, we try resolving the type here and now. - # If a type is manually overwritten, this function should be - # regenerated. - if t is not None: - handler = converter._structure_func.dispatch(t) + nrb = get_notrequired_base(t) + if nrb is not NOTHING: + t = nrb + + if override.struct_hook is not None: + handler = override.struct_hook else: - handler = converter.structure + # For each attribute, we try resolving the type here and now. + # If a type is manually overwritten, this function should be + # regenerated. + handler = converter._structure_func.dispatch(t) kn = an if override.rename is None else override.rename allowed_fields.add(kn) - if handler: - struct_handler_name = f"__c_structure_{ix}" - internal_arg_parts[struct_handler_name] = handler - if handler == converter._structure_call: - internal_arg_parts[struct_handler_name] = t - invocation_line = ( - f" res['{an}'] = {struct_handler_name}(o['{kn}'])" - ) - else: - tn = f"__c_type_{ix}" - internal_arg_parts[tn] = t - invocation_line = ( - f" res['{an}'] = {struct_handler_name}(o['{kn}'], {tn})" - ) + struct_handler_name = f"__c_structure_{ix}" + internal_arg_parts[struct_handler_name] = handler + if handler == converter._structure_call: + internal_arg_parts[struct_handler_name] = t + invocation_line = f" res['{an}'] = {struct_handler_name}(o['{kn}'])" else: - invocation_line = f" res['{an}'] = o['{kn}']" + tn = f"__c_type_{ix}" + internal_arg_parts[tn] = t + invocation_line = ( + f" res['{an}'] = {struct_handler_name}(o['{kn}'], {tn})" + ) lines.append(invocation_line) if override.rename is not None: @@ -472,13 +467,13 @@ def make_dict_structure_fn( elif is_generic(t) and not is_bare(t) and not is_annotated(t): t = deep_copy_with(t, mapping) - # For each attribute, we try resolving the type here and now. - # If a type is manually overwritten, this function should be - # regenerated. - if t is not None: - handler = converter._structure_func.dispatch(t) + if override.struct_hook is not None: + handler = override.struct_hook else: - handler = converter.structure + # For each attribute, we try resolving the type here and now. + # If a type is manually overwritten, this function should be + # regenerated. + handler = converter._structure_func.dispatch(t) struct_handler_name = f"__c_structure_{ix}" internal_arg_parts[struct_handler_name] = handler @@ -487,20 +482,17 @@ def make_dict_structure_fn( kn = an if override.rename is None else override.rename allowed_fields.add(kn) post_lines.append(f" if '{kn}' in o:") - if handler: - if handler == converter._structure_call: - internal_arg_parts[struct_handler_name] = t - post_lines.append( - f" res['{ian}'] = {struct_handler_name}(o['{kn}'])" - ) - else: - tn = f"__c_type_{ix}" - internal_arg_parts[tn] = t - post_lines.append( - f" res['{ian}'] = {struct_handler_name}(o['{kn}'], {tn})" - ) + if handler == converter._structure_call: + internal_arg_parts[struct_handler_name] = t + post_lines.append( + f" res['{ian}'] = {struct_handler_name}(o['{kn}'])" + ) else: - post_lines.append(f" res['{ian}'] = o['{kn}']") + tn = f"__c_type_{ix}" + internal_arg_parts[tn] = t + post_lines.append( + f" res['{ian}'] = {struct_handler_name}(o['{kn}'], {tn})" + ) if override.rename is not None: lines.append(f" res.pop('{override.rename}', None)") @@ -568,6 +560,7 @@ def _required_keys(cls: type) -> set[str]: from typing_extensions import Annotated, NotRequired, Required, get_args def _required_keys(cls: type) -> set[str]: + """Own own processor for required keys.""" if _is_extensions_typeddict(cls): return cls.__required_keys__ @@ -600,6 +593,7 @@ def _required_keys(cls: type) -> set[str]: # On 3.8, typing.TypedDicts do not have __required_keys__. def _required_keys(cls: type) -> set[str]: + """Own own processor for required keys.""" if _is_extensions_typeddict(cls): return cls.__required_keys__ @@ -613,12 +607,12 @@ def _required_keys(cls: type) -> set[str]: if key in superclass_keys: continue annotation_type = own_annotations[key] + + if is_annotated(annotation_type): + # If this is `Annotated`, we need to get the origin twice. + annotation_type = get_origin(annotation_type) + annotation_origin = get_origin(annotation_type) - if annotation_origin is Annotated: - annotation_args = get_args(annotation_type) - if annotation_args: - annotation_type = annotation_args[0] - annotation_origin = get_origin(annotation_type) if annotation_origin is Required: required_keys.add(key) diff --git a/tests/test_typeddicts.py b/tests/test_typeddicts.py index f805945b..1ffa455c 100644 --- a/tests/test_typeddicts.py +++ b/tests/test_typeddicts.py @@ -1,15 +1,20 @@ """Tests for TypedDict un/structuring.""" from datetime import datetime -from typing import Dict, Set, Tuple +from typing import Dict, Generic, Set, Tuple, TypedDict, TypeVar import pytest from hypothesis import assume, given from hypothesis.strategies import booleans from pytest import raises +from typing_extensions import NotRequired from cattrs import BaseConverter, Converter from cattrs._compat import ExtensionsTypedDict, is_generic -from cattrs.errors import ClassValidationError, ForbiddenExtraKeysError +from cattrs.errors import ( + ClassValidationError, + ForbiddenExtraKeysError, + StructureHandlerNotFoundError, +) from cattrs.gen import already_generating, override from cattrs.gen._generics import generate_mapping from cattrs.gen.typeddicts import ( @@ -155,8 +160,11 @@ def test_generics( cls, instance = cls_and_instance unstructured = c.unstructure(instance, unstructure_as=cls) + assert not any(isinstance(v, datetime) for v in unstructured.values()) - if all(a is not datetime for _, a in get_annot(cls).items()): + if all( + a not in (datetime, NotRequired[datetime]) for _, a in get_annot(cls).items() + ): assert unstructured == instance if all(a is int for _, a in get_annot(cls).items()): @@ -168,6 +176,24 @@ def test_generics( assert restructured == instance +@pytest.mark.skipif(not is_py311_plus, reason="3.11+ only") +@given(booleans()) +def test_generics_with_unbound(detailed_validation: bool): + """TypedDicts with unbound TypeVars work.""" + c = mk_converter(detailed_validation=detailed_validation) + + T = TypeVar("T") + + class GenericTypedDict(TypedDict, Generic[T]): + a: T + + assert c.unstructure({"a": 1}, GenericTypedDict) + + with pytest.raises(StructureHandlerNotFoundError): + # This doesn't work since we refuse the temptation to guess. + c.structure({"a": 1}, GenericTypedDict) + + @given(simple_typeddicts(total=True, not_required=True), booleans()) def test_not_required( cls_and_instance: Tuple[type, Dict], detailed_validation: bool @@ -415,3 +441,33 @@ class A(ExtensionsTypedDict): else: with pytest.raises(ValueError): converter.structure({"a": "a"}, A) + + +def test_override_entire_hooks(converter: BaseConverter): + """Overriding entire hooks works.""" + + class A(ExtensionsTypedDict): + a: int + b: NotRequired[int] + + converter.register_structure_hook( + A, + make_dict_structure_fn( + A, + converter, + a=override(struct_hook=lambda v, _: 1), + b=override(struct_hook=lambda v, _: 2), + ), + ) + converter.register_unstructure_hook( + A, + make_dict_unstructure_fn( + A, + converter, + a=override(unstruct_hook=lambda v: 1), + b=override(unstruct_hook=lambda v: 2), + ), + ) + + assert converter.unstructure({"a": 10, "b": 10}, A) == {"a": 1, "b": 2} + assert converter.structure({"a": 10, "b": 10}, A) == {"a": 1, "b": 2} diff --git a/tests/typeddicts.py b/tests/typeddicts.py index 4f7804d4..18453d70 100644 --- a/tests/typeddicts.py +++ b/tests/typeddicts.py @@ -3,7 +3,7 @@ from string import ascii_lowercase from typing import Any, Dict, Generic, List, Optional, Set, Tuple, TypeVar -from attr import NOTHING +from attrs import NOTHING from hypothesis.strategies import ( DrawFn, SearchStrategy, @@ -17,7 +17,13 @@ text, ) -from cattrs._compat import ExtensionsTypedDict, NotRequired, Required, TypedDict +from cattrs._compat import ( + Annotated, + ExtensionsTypedDict, + NotRequired, + Required, + TypedDict, +) from .untyped import gen_attr_names @@ -55,6 +61,34 @@ def int_attributes( return int, integers() | just(NOTHING), text(ascii_lowercase) +@composite +def annotated_int_attributes( + draw: DrawFn, total: bool = True, not_required: bool = False +) -> Tuple[int, SearchStrategy, SearchStrategy]: + """Generate combinations of Annotated types.""" + if total: + if not_required and draw(booleans()): + return ( + NotRequired[Annotated[int, "test"]] + if draw(booleans()) + else Annotated[NotRequired[int], "test"], + integers() | just(NOTHING), + text(ascii_lowercase), + ) + return Annotated[int, "test"], integers(), text(ascii_lowercase) + + if not_required and draw(booleans()): + return ( + Required[Annotated[int, "test"]] + if draw(booleans()) + else Annotated[Required[int], "test"], + integers(), + text(ascii_lowercase), + ) + + return Annotated[int, "test"], integers() | just(NOTHING), text(ascii_lowercase) + + @composite def datetime_attributes( draw: DrawFn, total: bool = True, not_required: bool = False @@ -120,6 +154,7 @@ def simple_typeddicts( attrs = draw( lists( int_attributes(total, not_required) + | annotated_int_attributes(total, not_required) | list_of_int_attributes(total, not_required) | datetime_attributes(total, not_required), min_size=min_attrs, @@ -201,7 +236,11 @@ def generic_typeddicts( if ix in generic_attrs: typevar = TypeVar(f"T{ix+1}") generics.append(typevar) - actual_types.append(attr_type) + if total and draw(booleans()): + # We might decide to make these NotRequired + actual_types.append(NotRequired[attr_type]) + else: + actual_types.append(attr_type) attrs_dict[attr_name] = typevar cls = make_typeddict( @@ -227,13 +266,13 @@ def make_typeddict( globs = {"TypedDict": TypedDict} lines = [] - bases_snippet = ",".join(f"_base{ix}" for ix in range(len(bases))) + bases_snippet = ", ".join(f"_base{ix}" for ix in range(len(bases))) for ix, base in enumerate(bases): globs[f"_base{ix}"] = base if bases_snippet: bases_snippet = f", {bases_snippet}" - lines.append(f"class {cls_name}(TypedDict{bases_snippet},total={total}):") + lines.append(f"class {cls_name}(TypedDict{bases_snippet}, total={total}):") for n, t in attrs.items(): # Strip the initial underscore if present, to prevent mangling. trimmed = n[1:] if n.startswith("_") else n