Skip to content

Commit 03638dd

Browse files
authored
Support recursive named tuples (#13371)
This is a continuation of #13297. The main change here is that although named tuples are still stored in symbol tables as `TypeInfo`s, when the type analyzer sees them, it creates a `TypeAliasType` targeting what it would have returned before (a `TupleType` with a fallback to an instance of that `TypeInfo`). Although it is a significant change, IMO this is the simplest but still clean way to support recursive named tuples. It is also very simple to extend this to TypedDicts, but I wanted to do the latter in a separate PR, to minimize the scope of changes. It would be great if someone could take a look at this PR soon. Most of the code changes are to make semantic analysis of named tuples idempotent: previously they were analyzed "for real" only once, when all types were ready. This is no longer possible if we want them to be recursive, so I pass in `existing_info` everywhere, and update it instead of creating a new one every time.
1 parent 27c5a9e commit 03638dd

16 files changed

+611
-67
lines changed

mypy/checkmember.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,7 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type:
563563
The return type of the appropriate ``__get__`` overload for the descriptor.
564564
"""
565565
instance_type = get_proper_type(mx.original_type)
566+
orig_descriptor_type = descriptor_type
566567
descriptor_type = get_proper_type(descriptor_type)
567568

568569
if isinstance(descriptor_type, UnionType):
@@ -571,10 +572,10 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type:
571572
[analyze_descriptor_access(typ, mx) for typ in descriptor_type.items]
572573
)
573574
elif not isinstance(descriptor_type, Instance):
574-
return descriptor_type
575+
return orig_descriptor_type
575576

576577
if not descriptor_type.type.has_readable_member("__get__"):
577-
return descriptor_type
578+
return orig_descriptor_type
578579

579580
dunder_get = descriptor_type.type.get_method("__get__")
580581
if dunder_get is None:

mypy/fixup.py

+7
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def visit_type_info(self, info: TypeInfo) -> None:
7575
p.accept(self.type_fixer)
7676
if info.tuple_type:
7777
info.tuple_type.accept(self.type_fixer)
78+
info.update_tuple_type(info.tuple_type)
7879
if info.typeddict_type:
7980
info.typeddict_type.accept(self.type_fixer)
8081
if info.declared_metaclass:
@@ -337,6 +338,12 @@ def lookup_fully_qualified_alias(
337338
node = stnode.node if stnode else None
338339
if isinstance(node, TypeAlias):
339340
return node
341+
elif isinstance(node, TypeInfo):
342+
if node.tuple_alias:
343+
return node.tuple_alias
344+
alias = TypeAlias.from_tuple_type(node)
345+
node.tuple_alias = alias
346+
return alias
340347
else:
341348
# Looks like a missing TypeAlias during an initial daemon load, put something there
342349
assert (

mypy/messages.py

+5
Original file line numberDiff line numberDiff line change
@@ -2292,6 +2292,11 @@ def visit_instance(self, t: Instance) -> None:
22922292
self.instances.append(t)
22932293
super().visit_instance(t)
22942294

2295+
def visit_type_alias_type(self, t: TypeAliasType) -> None:
2296+
if t.alias and not t.is_recursive:
2297+
t.alias.target.accept(self)
2298+
super().visit_type_alias_type(t)
2299+
22952300

22962301
def find_type_overlaps(*types: Type) -> Set[str]:
22972302
"""Return a set of fullnames that share a short name and appear in either type.

mypy/nodes.py

+25
Original file line numberDiff line numberDiff line change
@@ -2656,6 +2656,7 @@ class is generic then it will be a type constructor of higher kind.
26562656
"bases",
26572657
"_promote",
26582658
"tuple_type",
2659+
"tuple_alias",
26592660
"is_named_tuple",
26602661
"typeddict_type",
26612662
"is_newtype",
@@ -2794,6 +2795,9 @@ class is generic then it will be a type constructor of higher kind.
27942795
# It is useful for plugins to add their data to save in the cache.
27952796
metadata: Dict[str, JsonDict]
27962797

2798+
# Store type alias representing this type (for named tuples).
2799+
tuple_alias: Optional["TypeAlias"]
2800+
27972801
FLAGS: Final = [
27982802
"is_abstract",
27992803
"is_enum",
@@ -2840,6 +2844,7 @@ def __init__(self, names: "SymbolTable", defn: ClassDef, module_name: str) -> No
28402844
self._promote = []
28412845
self.alt_promote = None
28422846
self.tuple_type = None
2847+
self.tuple_alias = None
28432848
self.is_named_tuple = False
28442849
self.typeddict_type = None
28452850
self.is_newtype = False
@@ -2970,6 +2975,15 @@ def direct_base_classes(self) -> "List[TypeInfo]":
29702975
"""
29712976
return [base.type for base in self.bases]
29722977

2978+
def update_tuple_type(self, typ: "mypy.types.TupleType") -> None:
2979+
"""Update tuple_type and tuple_alias as needed."""
2980+
self.tuple_type = typ
2981+
alias = TypeAlias.from_tuple_type(self)
2982+
if not self.tuple_alias:
2983+
self.tuple_alias = alias
2984+
else:
2985+
self.tuple_alias.target = alias.target
2986+
29732987
def __str__(self) -> str:
29742988
"""Return a string representation of the type.
29752989
@@ -3258,6 +3272,17 @@ def __init__(
32583272
self.eager = eager
32593273
super().__init__(line, column)
32603274

3275+
@classmethod
3276+
def from_tuple_type(cls, info: TypeInfo) -> "TypeAlias":
3277+
"""Generate an alias to the tuple type described by a given TypeInfo."""
3278+
assert info.tuple_type
3279+
return TypeAlias(
3280+
info.tuple_type.copy_modified(fallback=mypy.types.Instance(info, [])),
3281+
info.fullname,
3282+
info.line,
3283+
info.column,
3284+
)
3285+
32613286
@property
32623287
def name(self) -> str:
32633288
return self._fullname.split(".")[-1]

mypy/semanal.py

+25-31
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,11 @@
221221
PRIORITY_FALLBACKS,
222222
SemanticAnalyzerInterface,
223223
calculate_tuple_fallback,
224+
has_placeholder,
224225
set_callable_name as set_callable_name,
225226
)
226227
from mypy.semanal_typeddict import TypedDictAnalyzer
227228
from mypy.tvar_scope import TypeVarLikeScope
228-
from mypy.type_visitor import TypeQuery
229229
from mypy.typeanal import (
230230
TypeAnalyser,
231231
TypeVarLikeList,
@@ -1425,7 +1425,12 @@ def analyze_class_body_common(self, defn: ClassDef) -> None:
14251425

14261426
def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
14271427
"""Check if this class can define a named tuple."""
1428-
if defn.info and defn.info.is_named_tuple:
1428+
if (
1429+
defn.info
1430+
and defn.info.is_named_tuple
1431+
and defn.info.tuple_type
1432+
and not has_placeholder(defn.info.tuple_type)
1433+
):
14291434
# Don't reprocess everything. We just need to process methods defined
14301435
# in the named tuple class body.
14311436
is_named_tuple, info = True, defn.info # type: bool, Optional[TypeInfo]
@@ -1782,10 +1787,9 @@ def configure_base_classes(
17821787
base_types: List[Instance] = []
17831788
info = defn.info
17841789

1785-
info.tuple_type = None
17861790
for base, base_expr in bases:
17871791
if isinstance(base, TupleType):
1788-
actual_base = self.configure_tuple_base_class(defn, base, base_expr)
1792+
actual_base = self.configure_tuple_base_class(defn, base)
17891793
base_types.append(actual_base)
17901794
elif isinstance(base, Instance):
17911795
if base.type.is_newtype:
@@ -1828,23 +1832,19 @@ def configure_base_classes(
18281832
return
18291833
self.calculate_class_mro(defn, self.object_type)
18301834

1831-
def configure_tuple_base_class(
1832-
self, defn: ClassDef, base: TupleType, base_expr: Expression
1833-
) -> Instance:
1835+
def configure_tuple_base_class(self, defn: ClassDef, base: TupleType) -> Instance:
18341836
info = defn.info
18351837

18361838
# There may be an existing valid tuple type from previous semanal iterations.
18371839
# Use equality to check if it is the case.
1838-
if info.tuple_type and info.tuple_type != base:
1840+
if info.tuple_type and info.tuple_type != base and not has_placeholder(info.tuple_type):
18391841
self.fail("Class has two incompatible bases derived from tuple", defn)
18401842
defn.has_incompatible_baseclass = True
1841-
info.tuple_type = base
1842-
if isinstance(base_expr, CallExpr):
1843-
defn.analyzed = NamedTupleExpr(base.partial_fallback.type)
1844-
defn.analyzed.line = defn.line
1845-
defn.analyzed.column = defn.column
1843+
if info.tuple_alias and has_placeholder(info.tuple_alias.target):
1844+
self.defer(force_progress=True)
1845+
info.update_tuple_type(base)
18461846

1847-
if base.partial_fallback.type.fullname == "builtins.tuple":
1847+
if base.partial_fallback.type.fullname == "builtins.tuple" and not has_placeholder(base):
18481848
# Fallback can only be safely calculated after semantic analysis, since base
18491849
# classes may be incomplete. Postpone the calculation.
18501850
self.schedule_patch(PRIORITY_FALLBACKS, lambda: calculate_tuple_fallback(base))
@@ -2627,7 +2627,10 @@ def analyze_enum_assign(self, s: AssignmentStmt) -> bool:
26272627
def analyze_namedtuple_assign(self, s: AssignmentStmt) -> bool:
26282628
"""Check if s defines a namedtuple."""
26292629
if isinstance(s.rvalue, CallExpr) and isinstance(s.rvalue.analyzed, NamedTupleExpr):
2630-
return True # This is a valid and analyzed named tuple definition, nothing to do here.
2630+
if s.rvalue.analyzed.info.tuple_type and not has_placeholder(
2631+
s.rvalue.analyzed.info.tuple_type
2632+
):
2633+
return True # This is a valid and analyzed named tuple definition, nothing to do here.
26312634
if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], (NameExpr, MemberExpr)):
26322635
return False
26332636
lvalue = s.lvalues[0]
@@ -3028,6 +3031,9 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
30283031
# unless using PEP 613 `cls: TypeAlias = A`
30293032
return False
30303033

3034+
if isinstance(s.rvalue, CallExpr) and s.rvalue.analyzed:
3035+
return False
3036+
30313037
existing = self.current_symbol_table().get(lvalue.name)
30323038
# Third rule: type aliases can't be re-defined. For example:
30333039
# A: Type[float] = int
@@ -3157,9 +3163,8 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
31573163
self.cannot_resolve_name(lvalue.name, "name", s)
31583164
return True
31593165
else:
3160-
self.progress = True
31613166
# We need to defer so that this change can get propagated to base classes.
3162-
self.defer(s)
3167+
self.defer(s, force_progress=True)
31633168
else:
31643169
self.add_symbol(lvalue.name, alias_node, s)
31653170
if isinstance(rvalue, RefExpr) and isinstance(rvalue.node, TypeAlias):
@@ -5484,7 +5489,7 @@ def tvar_scope_frame(self, frame: TypeVarLikeScope) -> Iterator[None]:
54845489
yield
54855490
self.tvar_scope = old_scope
54865491

5487-
def defer(self, debug_context: Optional[Context] = None) -> None:
5492+
def defer(self, debug_context: Optional[Context] = None, force_progress: bool = False) -> None:
54885493
"""Defer current analysis target to be analyzed again.
54895494
54905495
This must be called if something in the current target is
@@ -5498,6 +5503,8 @@ def defer(self, debug_context: Optional[Context] = None) -> None:
54985503
They are usually preferable to a direct defer() call.
54995504
"""
55005505
assert not self.final_iteration, "Must not defer during final iteration"
5506+
if force_progress:
5507+
self.progress = True
55015508
self.deferred = True
55025509
# Store debug info for this deferral.
55035510
line = (
@@ -5999,19 +6006,6 @@ def is_future_flag_set(self, flag: str) -> bool:
59996006
return self.modules[self.cur_mod_id].is_future_flag_set(flag)
60006007

60016008

6002-
class HasPlaceholders(TypeQuery[bool]):
6003-
def __init__(self) -> None:
6004-
super().__init__(any)
6005-
6006-
def visit_placeholder_type(self, t: PlaceholderType) -> bool:
6007-
return True
6008-
6009-
6010-
def has_placeholder(typ: Type) -> bool:
6011-
"""Check if a type contains any placeholder types (recursively)."""
6012-
return typ.accept(HasPlaceholders())
6013-
6014-
60156009
def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
60166010
if isinstance(sig, CallableType):
60176011
if len(sig.arg_types) == 0:

mypy/semanal_namedtuple.py

+41-8
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
PRIORITY_FALLBACKS,
4545
SemanticAnalyzerInterface,
4646
calculate_tuple_fallback,
47+
has_placeholder,
4748
set_callable_name,
4849
)
4950
from mypy.types import (
@@ -109,8 +110,11 @@ def analyze_namedtuple_classdef(
109110
items, types, default_items = result
110111
if is_func_scope and "@" not in defn.name:
111112
defn.name += "@" + str(defn.line)
113+
existing_info = None
114+
if isinstance(defn.analyzed, NamedTupleExpr):
115+
existing_info = defn.analyzed.info
112116
info = self.build_namedtuple_typeinfo(
113-
defn.name, items, types, default_items, defn.line
117+
defn.name, items, types, default_items, defn.line, existing_info
114118
)
115119
defn.info = info
116120
defn.analyzed = NamedTupleExpr(info, is_typed=True)
@@ -164,7 +168,14 @@ def check_namedtuple_classdef(
164168
if stmt.type is None:
165169
types.append(AnyType(TypeOfAny.unannotated))
166170
else:
167-
analyzed = self.api.anal_type(stmt.type)
171+
# We never allow recursive types at function scope. Although it is
172+
# possible to support this for named tuples, it is still tricky, and
173+
# it would be inconsistent with type aliases.
174+
analyzed = self.api.anal_type(
175+
stmt.type,
176+
allow_placeholder=self.options.enable_recursive_aliases
177+
and not self.api.is_func_scope(),
178+
)
168179
if analyzed is None:
169180
# Something is incomplete. We need to defer this named tuple.
170181
return None
@@ -226,7 +237,7 @@ def check_namedtuple(
226237
name += "@" + str(call.line)
227238
else:
228239
name = var_name = "namedtuple@" + str(call.line)
229-
info = self.build_namedtuple_typeinfo(name, [], [], {}, node.line)
240+
info = self.build_namedtuple_typeinfo(name, [], [], {}, node.line, None)
230241
self.store_namedtuple_info(info, var_name, call, is_typed)
231242
if name != var_name or is_func_scope:
232243
# NOTE: we skip local namespaces since they are not serialized.
@@ -262,12 +273,22 @@ def check_namedtuple(
262273
}
263274
else:
264275
default_items = {}
265-
info = self.build_namedtuple_typeinfo(name, items, types, default_items, node.line)
276+
277+
existing_info = None
278+
if isinstance(node.analyzed, NamedTupleExpr):
279+
existing_info = node.analyzed.info
280+
info = self.build_namedtuple_typeinfo(
281+
name, items, types, default_items, node.line, existing_info
282+
)
283+
266284
# If var_name is not None (i.e. this is not a base class expression), we always
267285
# store the generated TypeInfo under var_name in the current scope, so that
268286
# other definitions can use it.
269287
if var_name:
270288
self.store_namedtuple_info(info, var_name, call, is_typed)
289+
else:
290+
call.analyzed = NamedTupleExpr(info, is_typed=is_typed)
291+
call.analyzed.set_line(call)
271292
# There are three cases where we need to store the generated TypeInfo
272293
# second time (for the purpose of serialization):
273294
# * If there is a name mismatch like One = NamedTuple('Other', [...])
@@ -408,7 +429,12 @@ def parse_namedtuple_fields_with_types(
408429
except TypeTranslationError:
409430
self.fail("Invalid field type", type_node)
410431
return None
411-
analyzed = self.api.anal_type(type)
432+
# We never allow recursive types at function scope.
433+
analyzed = self.api.anal_type(
434+
type,
435+
allow_placeholder=self.options.enable_recursive_aliases
436+
and not self.api.is_func_scope(),
437+
)
412438
# Workaround #4987 and avoid introducing a bogus UnboundType
413439
if isinstance(analyzed, UnboundType):
414440
analyzed = AnyType(TypeOfAny.from_error)
@@ -428,6 +454,7 @@ def build_namedtuple_typeinfo(
428454
types: List[Type],
429455
default_items: Mapping[str, Expression],
430456
line: int,
457+
existing_info: Optional[TypeInfo],
431458
) -> TypeInfo:
432459
strtype = self.api.named_type("builtins.str")
433460
implicit_any = AnyType(TypeOfAny.special_form)
@@ -448,18 +475,23 @@ def build_namedtuple_typeinfo(
448475
literals: List[Type] = [LiteralType(item, strtype) for item in items]
449476
match_args_type = TupleType(literals, basetuple_type)
450477

451-
info = self.api.basic_new_typeinfo(name, fallback, line)
478+
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
452479
info.is_named_tuple = True
453480
tuple_base = TupleType(types, fallback)
454-
info.tuple_type = tuple_base
481+
if info.tuple_alias and has_placeholder(info.tuple_alias.target):
482+
self.api.defer(force_progress=True)
483+
info.update_tuple_type(tuple_base)
455484
info.line = line
456485
# For use by mypyc.
457486
info.metadata["namedtuple"] = {"fields": items.copy()}
458487

459488
# We can't calculate the complete fallback type until after semantic
460489
# analysis, since otherwise base classes might be incomplete. Postpone a
461490
# callback function that patches the fallback.
462-
self.api.schedule_patch(PRIORITY_FALLBACKS, lambda: calculate_tuple_fallback(tuple_base))
491+
if not has_placeholder(tuple_base):
492+
self.api.schedule_patch(
493+
PRIORITY_FALLBACKS, lambda: calculate_tuple_fallback(tuple_base)
494+
)
463495

464496
def add_field(
465497
var: Var, is_initialized_in_class: bool = False, is_property: bool = False
@@ -489,6 +521,7 @@ def add_field(
489521
if self.options.python_version >= (3, 10):
490522
add_field(Var("__match_args__", match_args_type), is_initialized_in_class=True)
491523

524+
assert info.tuple_type is not None # Set by update_tuple_type() above.
492525
tvd = TypeVarType(
493526
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], info.tuple_type
494527
)

0 commit comments

Comments
 (0)