Skip to content

Enable recursive type aliases behind a flag #13297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3662,20 +3662,23 @@ def check_simple_assignment(
# '...' is always a valid initializer in a stub.
return AnyType(TypeOfAny.special_form)
else:
orig_lvalue = lvalue_type
lvalue_type = get_proper_type(lvalue_type)
always_allow_any = lvalue_type is not None and not isinstance(lvalue_type, AnyType)
rvalue_type = self.expr_checker.accept(
rvalue, lvalue_type, always_allow_any=always_allow_any
)
orig_rvalue = rvalue_type
rvalue_type = get_proper_type(rvalue_type)
if isinstance(rvalue_type, DeletedType):
self.msg.deleted_as_rvalue(rvalue_type, context)
if isinstance(lvalue_type, DeletedType):
self.msg.deleted_as_lvalue(lvalue_type, context)
elif lvalue_type:
self.check_subtype(
rvalue_type,
lvalue_type,
# Preserve original aliases for error messages when possible.
orig_rvalue,
orig_lvalue or lvalue_type,
context,
msg,
f"{rvalue_name} has type",
Expand Down Expand Up @@ -5568,7 +5571,9 @@ def check_subtype(
code = msg.code
else:
msg_text = msg
orig_subtype = subtype
subtype = get_proper_type(subtype)
orig_supertype = supertype
supertype = get_proper_type(supertype)
if self.msg.try_report_long_tuple_assignment_error(
subtype, supertype, context, msg_text, subtype_label, supertype_label, code=code
Expand All @@ -5580,7 +5585,7 @@ def check_subtype(
note_msg = ""
notes: List[str] = []
if subtype_label is not None or supertype_label is not None:
subtype_str, supertype_str = format_type_distinctly(subtype, supertype)
subtype_str, supertype_str = format_type_distinctly(orig_subtype, orig_supertype)
if subtype_label is not None:
extra_info.append(subtype_label + " " + subtype_str)
if supertype_label is not None:
Expand Down
13 changes: 13 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
flatten_nested_unions,
get_proper_type,
get_proper_types,
has_recursive_types,
is_generic_instance,
is_named_instance,
is_optional,
Expand Down Expand Up @@ -1534,13 +1535,25 @@ def infer_function_type_arguments(
else:
pass1_args.append(arg)

# This is a hack to better support inference for recursive types.
# When the outer context for a function call is known to be recursive,
# we solve type constraints inferred from arguments using unions instead
# of joins. This is a bit arbitrary, but in practice it works for most
# cases. A cleaner alternative would be to switch to single bin type
# inference, but this is a lot of work.
ctx = self.type_context[-1]
if ctx and has_recursive_types(ctx):
infer_unions = True
else:
infer_unions = False
inferred_args = infer_function_type_arguments(
callee_type,
pass1_args,
arg_kinds,
formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function(),
infer_unions=infer_unions,
)

if 2 in arg_pass_nums:
Expand Down
3 changes: 2 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def expand_types_with_unpack(
else:
items.extend(unpacked_items)
else:
items.append(proper_item.accept(self))
# Must preserve original aliases when possible.
items.append(item.accept(self))
return items

def visit_tuple_type(self, t: TupleType) -> Type:
Expand Down
3 changes: 2 additions & 1 deletion mypy/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def infer_function_type_arguments(
formal_to_actual: List[List[int]],
context: ArgumentInferContext,
strict: bool = True,
infer_unions: bool = False,
) -> List[Optional[Type]]:
"""Infer the type arguments of a generic function.

Expand All @@ -55,7 +56,7 @@ def infer_function_type_arguments(

# Solve constraints.
type_vars = callee_type.type_var_ids()
return solve_constraints(type_vars, constraints, strict)
return solve_constraints(type_vars, constraints, strict, infer_unions=infer_unions)


def infer_type_arguments(
Expand Down
5 changes: 5 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,11 @@ def add_invertible_flag(
dest="custom_typing_module",
help="Use a custom typing module",
)
internals_group.add_argument(
"--enable-recursive-aliases",
action="store_true",
help="Experimental support for recursive type aliases",
)
internals_group.add_argument(
"--custom-typeshed-dir", metavar="DIR", help="Use the custom typeshed in DIR"
)
Expand Down
13 changes: 12 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
ProperType,
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeType,
Expand Down Expand Up @@ -2128,7 +2129,17 @@ def format_literal_value(typ: LiteralType) -> str:
else:
return typ.value_repr()

# TODO: show type alias names in errors.
if isinstance(typ, TypeAliasType) and typ.is_recursive:
# TODO: find balance here, str(typ) doesn't support custom verbosity, and may be
# too verbose for user messages, OTOH it nicely shows structure of recursive types.
if verbosity < 2:
type_str = typ.alias.name if typ.alias else "<alias (unfixed)>"
if typ.args:
type_str += f"[{format_list(typ.args)}]"
return type_str
return str(typ)

# TODO: always mention type alias names in errors.
typ = get_proper_type(typ)

if isinstance(typ, Instance):
Expand Down
2 changes: 2 additions & 0 deletions mypy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ def __init__(self) -> None:
# skip most errors after this many messages have been reported.
# -1 means unlimited.
self.many_errors_threshold = defaults.MANY_ERRORS_THRESHOLD
# Enable recursive type aliases (currently experimental)
self.enable_recursive_aliases = False

# To avoid breaking plugin compatibility, keep providing new_semantic_analyzer
@property
Expand Down
4 changes: 4 additions & 0 deletions mypy/sametypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
def is_same_type(left: Type, right: Type) -> bool:
"""Is 'left' the same type as 'right'?"""

if isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType):
if left.is_recursive and right.is_recursive:
return left.alias == right.alias and left.args == right.args

left = get_proper_type(left)
right = get_proper_type(right)

Expand Down
80 changes: 71 additions & 9 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,14 @@ def __init__(
# current SCC or top-level function.
self.deferral_debug_context: List[Tuple[str, int]] = []

# This is needed to properly support recursive type aliases. The problem is that
# Foo[Bar] could mean three things depending on context: a target for type alias,
# a normal index expression (including enum index), or a type application.
# The latter is particularly problematic as it can falsely create incomplete
# refs while analysing rvalues of type aliases. To avoid this we first analyse
# rvalues while temporarily setting this to True.
self.basic_type_applications = False

# mypyc doesn't properly handle implementing an abstractproperty
# with a regular attribute so we make them properties
@property
Expand Down Expand Up @@ -2286,14 +2294,25 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
return

tag = self.track_incomplete_refs()
s.rvalue.accept(self)

# Here we have a chicken and egg problem: at this stage we can't call
# can_be_type_alias(), because we do not yet have enough information about the rvalue.
# But we can't use a full visit because it may emit extra incomplete refs (namely
# when analysing any type applications there) thus preventing the further analysis.
# To break the tie, we first analyse rvalue partially, if it can be a type alias.
with self.basic_type_applications_set(s):
s.rvalue.accept(self)
if self.found_incomplete_ref(tag) or self.should_wait_rhs(s.rvalue):
# Initializer couldn't be fully analyzed. Defer the current node and give up.
# Make sure that if we skip the definition of some local names, they can't be
# added later in this scope, since an earlier definition should take precedence.
for expr in names_modified_by_assignment(s):
self.mark_incomplete(expr.name, expr)
return
if self.can_possibly_be_index_alias(s):
# Now re-visit those rvalues where we skipped type applications above.
# This should be safe, as the semantic analyzer is generally idempotent.
s.rvalue.accept(self)

# The r.h.s. is now ready to be classified, first check if it is a special form:
special_form = False
Expand Down Expand Up @@ -2432,6 +2451,36 @@ def can_be_type_alias(self, rv: Expression, allow_none: bool = False) -> bool:
return True
return False

def can_possibly_be_index_alias(self, s: AssignmentStmt) -> bool:
"""Like can_be_type_alias(), but simpler and doesn't require analyzed rvalue.

Instead, use lvalues/annotations structure to figure out whether this can
potentially be a type alias definition. Another difference from above function
is that we are only interested in IndexExpr and OpExpr rvalues, since only those
can be potentially recursive (things like `A = A` are never valid).
"""
if len(s.lvalues) > 1:
return False
if not isinstance(s.lvalues[0], NameExpr):
return False
if s.unanalyzed_type is not None and not self.is_pep_613(s):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we always return True if is_pep_613(s) returns True?

It may also be reasonable to require explicit TypeAlias for recursive aliases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we always return True if is_pep_613(s) returns True?

We are interested only in aliases that can be recursive, so as a micro-optimization we only do this for r.h.s that can be recursive. I will add a comment on this.

It may also be reasonable to require explicit TypeAlias for recursive aliases.

Hm, yeah this will simplify logic a bit. OTOH it may be hard to reliably detect cases where it was intended and missing, to give a good error message. I will keep the current logic as is, it is easy to change later, as this part is decoupled from everything else.

return False
if not isinstance(s.rvalue, (IndexExpr, OpExpr)):
return False
# Something that looks like Foo = Bar[Baz, ...]
return True

@contextmanager
def basic_type_applications_set(self, s: AssignmentStmt) -> Iterator[None]:
    """Temporarily override ``self.basic_type_applications`` for assignment ``s``.

    While the ``with`` body runs, the flag holds the result of
    ``self.can_possibly_be_index_alias(s)``. The value the flag had on entry
    is put back on exit, including when the body raises.
    """
    saved_flag = self.basic_type_applications
    # The double-visit logic is only worth paying for when the statement
    # could define a recursive type alias, so the flag mirrors that check.
    self.basic_type_applications = self.can_possibly_be_index_alias(s)
    try:
        yield
    finally:
        # Always restore, so nested or failed analyses do not leak the flag.
        self.basic_type_applications = saved_flag

def is_type_ref(self, rv: Expression, bare: bool = False) -> bool:
"""Does this expression refer to a type?

Expand Down Expand Up @@ -2908,6 +2957,13 @@ def analyze_alias(
qualified_tvars = []
return typ, alias_tvars, depends_on, qualified_tvars

def is_pep_613(self, s: AssignmentStmt) -> bool:
    """Tell whether assignment ``s`` carries an explicit type-alias annotation.

    Returns True only when the statement's unanalyzed annotation is an
    ``UnboundType`` whose name resolves, via ``lookup_qualified`` with errors
    suppressed, to a symbol whose fullname is listed in ``TYPE_ALIAS_NAMES``
    (the PEP 613 ``X: TypeAlias = ...`` form). Returns False otherwise,
    including when there is no annotation at all.
    """
    annotation = s.unanalyzed_type
    # isinstance() is already False for None, so this also covers the
    # "no annotation" case.
    if not isinstance(annotation, UnboundType):
        return False
    symbol = self.lookup_qualified(annotation.name, s, suppress_errors=True)
    if not symbol:
        return False
    return symbol.fullname in TYPE_ALIAS_NAMES

def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
"""Check if assignment creates a type alias and set it up as needed.

Expand All @@ -2922,11 +2978,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# First rule: Only simple assignments like Alias = ... create aliases.
return False

pep_613 = False
if s.unanalyzed_type is not None and isinstance(s.unanalyzed_type, UnboundType):
lookup = self.lookup_qualified(s.unanalyzed_type.name, s, suppress_errors=True)
if lookup and lookup.fullname in TYPE_ALIAS_NAMES:
pep_613 = True
pep_613 = self.is_pep_613(s)
if not pep_613 and s.unanalyzed_type is not None:
# Second rule: Explicit type (cls: Type[A] = A) always creates variable, not alias.
# unless using PEP 613 `cls: TypeAlias = A`
Expand Down Expand Up @@ -2990,9 +3042,16 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
)
if not res:
return False
# TODO: Maybe we only need to reject top-level placeholders, similar
# to base classes.
if self.found_incomplete_ref(tag) or has_placeholder(res):
if self.options.enable_recursive_aliases:
# Only marking incomplete for top-level placeholders makes recursive aliases like
# `A = Sequence[str | A]` valid here, similar to how we treat base classes in class
# definitions, allowing `class str(Sequence[str]): ...`
incomplete_target = isinstance(res, ProperType) and isinstance(
res, PlaceholderType
)
else:
incomplete_target = has_placeholder(res)
if self.found_incomplete_ref(tag) or incomplete_target:
# Since we have got here, we know this must be a type alias (incomplete refs
# may appear in nested positions), therefore use becomes_typeinfo=True.
self.mark_incomplete(lvalue.name, rvalue, becomes_typeinfo=True)
Expand Down Expand Up @@ -4499,6 +4558,9 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]]
self.analyze_type_expr(index)
if self.found_incomplete_ref(tag):
return None
if self.basic_type_applications:
# Postpone the rest until we have more information (for r.h.s. of an assignment)
return None
types: List[Type] = []
if isinstance(index, TupleExpr):
items = index.items
Expand Down
22 changes: 19 additions & 3 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,22 @@
from mypy.join import join_types
from mypy.meet import meet_types
from mypy.subtypes import is_subtype
from mypy.types import AnyType, Type, TypeOfAny, TypeVarId, UninhabitedType, get_proper_type
from mypy.types import (
AnyType,
Type,
TypeOfAny,
TypeVarId,
UninhabitedType,
UnionType,
get_proper_type,
)


def solve_constraints(
vars: List[TypeVarId], constraints: List[Constraint], strict: bool = True
vars: List[TypeVarId],
constraints: List[Constraint],
strict: bool = True,
infer_unions: bool = False,
) -> List[Optional[Type]]:
"""Solve type constraints.

Expand Down Expand Up @@ -43,7 +54,12 @@ def solve_constraints(
if bottom is None:
bottom = c.target
else:
bottom = join_types(bottom, c.target)
if infer_unions:
# This deviates from the general mypy semantics because
# recursive types are union-heavy in 95% of cases.
bottom = UnionType.make_union([bottom, c.target])
else:
bottom = join_types(bottom, c.target)
else:
if top is None:
top = c.target
Expand Down
4 changes: 3 additions & 1 deletion mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,15 @@ def is_subtype(
if TypeState.is_assumed_subtype(left, right):
return True
if (
# TODO: recursive instances like `class str(Sequence[str])` can also cause
# issues, so we also need to include them in the assumptions stack
isinstance(left, TypeAliasType)
and isinstance(right, TypeAliasType)
and left.is_recursive
and right.is_recursive
):
# This case requires special care because it may cause infinite recursion.
# Our view on recursive types is known under a fancy name of equirecursive mu-types.
# Our view on recursive types is known under a fancy name of iso-recursive mu-types.
# Roughly this means that a recursive type is defined as an alias where right hand side
# can refer to the type as a whole, for example:
# A = Union[int, Tuple[A, ...]]
Expand Down
Loading