Skip to content

Commit ddf03d1

Browse files
sixoletJukkaL
authored andcommitted
Better callable: Callable[[Arg('x', int), VarArg(str)], int] now a thing you can do (#2607)
Implements an experimental feature to allow Callable to have any kind of signature an actual function definition does. This should enable better typing of callbacks &c. Initial discussion: python/typing#239 Proposal, v. similar to this impl: python/typing#264 Relevant typeshed PR: python/typeshed#793
1 parent 058a8a6 commit ddf03d1

19 files changed

+660
-81
lines changed

extensions/mypy_extensions.py

+36
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from mypy_extensions import TypedDict
66
"""
77

8+
from typing import Any
9+
810
# NOTE: This module must support Python 2.7 in addition to Python 3.x
911

1012
import sys
@@ -92,6 +94,40 @@ class Point2D(TypedDict):
9294
syntax forms work for Python 2.7 and 3.2+
9395
"""
9496

97+
# Argument constructors for making more-detailed Callables. These all just
98+
# return their type argument, to make them complete noops in terms of the
99+
# `typing` module.
100+
101+
102+
def Arg(type=Any, name=None):
103+
"""A normal positional argument"""
104+
return type
105+
106+
107+
def DefaultArg(type=Any, name=None):
108+
"""A positional argument with a default value"""
109+
return type
110+
111+
112+
def NamedArg(type=Any, name=None):
113+
"""A keyword-only argument"""
114+
return type
115+
116+
117+
def DefaultNamedArg(type=Any, name=None):
118+
"""A keyword-only argument with a default value"""
119+
return type
120+
121+
122+
def VarArg(type=Any):
123+
"""A *args-style variadic positional argument"""
124+
return type
125+
126+
127+
def KwArg(type=Any):
128+
"""A **kwargs-style variadic keyword argument"""
129+
return type
130+
95131

96132
# Return type that indicates a function does not return
97133
class NoReturn: pass

mypy/exprtotype.py

+65-7
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,38 @@
22

33
from mypy.nodes import (
44
Expression, NameExpr, MemberExpr, IndexExpr, TupleExpr,
5-
ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr,
6-
get_member_expr_fullname
5+
ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr, CallExpr,
6+
ARG_POS, ARG_NAMED, get_member_expr_fullname
77
)
88
from mypy.fastparse import parse_type_comment
9-
from mypy.types import Type, UnboundType, TypeList, EllipsisType
9+
from mypy.types import (
10+
Type, UnboundType, TypeList, EllipsisType, AnyType, Optional, CallableArgument,
11+
)
1012

1113

1214
class TypeTranslationError(Exception):
1315
"""Exception raised when an expression is not valid as a type."""
1416

1517

16-
def expr_to_unanalyzed_type(expr: Expression) -> Type:
18+
def _extract_argument_name(expr: Expression) -> Optional[str]:
19+
if isinstance(expr, NameExpr) and expr.name == 'None':
20+
return None
21+
elif isinstance(expr, StrExpr):
22+
return expr.value
23+
elif isinstance(expr, UnicodeExpr):
24+
return expr.value
25+
else:
26+
raise TypeTranslationError()
27+
28+
29+
def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = None) -> Type:
1730
"""Translate an expression to the corresponding type.
1831
1932
The result is not semantically analyzed. It can be UnboundType or TypeList.
2033
Raise TypeTranslationError if the expression cannot represent a type.
2134
"""
35+
# The `parent` paremeter is used in recursive calls to provide context for
36+
# understanding whether an CallableArgument is ok.
2237
if isinstance(expr, NameExpr):
2338
name = expr.name
2439
return UnboundType(name, line=expr.line, column=expr.column)
@@ -29,22 +44,65 @@ def expr_to_unanalyzed_type(expr: Expression) -> Type:
2944
else:
3045
raise TypeTranslationError()
3146
elif isinstance(expr, IndexExpr):
32-
base = expr_to_unanalyzed_type(expr.base)
47+
base = expr_to_unanalyzed_type(expr.base, expr)
3348
if isinstance(base, UnboundType):
3449
if base.args:
3550
raise TypeTranslationError()
3651
if isinstance(expr.index, TupleExpr):
3752
args = expr.index.items
3853
else:
3954
args = [expr.index]
40-
base.args = [expr_to_unanalyzed_type(arg) for arg in args]
55+
base.args = [expr_to_unanalyzed_type(arg, expr) for arg in args]
4156
if not base.args:
4257
base.empty_tuple_index = True
4358
return base
4459
else:
4560
raise TypeTranslationError()
61+
elif isinstance(expr, CallExpr) and isinstance(_parent, ListExpr):
62+
c = expr.callee
63+
names = []
64+
# Go through the dotted member expr chain to get the full arg
65+
# constructor name to look up
66+
while True:
67+
if isinstance(c, NameExpr):
68+
names.append(c.name)
69+
break
70+
elif isinstance(c, MemberExpr):
71+
names.append(c.name)
72+
c = c.expr
73+
else:
74+
raise TypeTranslationError()
75+
arg_const = '.'.join(reversed(names))
76+
77+
# Go through the constructor args to get its name and type.
78+
name = None
79+
default_type = AnyType(implicit=True)
80+
typ = default_type # type: Type
81+
for i, arg in enumerate(expr.args):
82+
if expr.arg_names[i] is not None:
83+
if expr.arg_names[i] == "name":
84+
if name is not None:
85+
# Two names
86+
raise TypeTranslationError()
87+
name = _extract_argument_name(arg)
88+
continue
89+
elif expr.arg_names[i] == "type":
90+
if typ is not default_type:
91+
# Two types
92+
raise TypeTranslationError()
93+
typ = expr_to_unanalyzed_type(arg, expr)
94+
continue
95+
else:
96+
raise TypeTranslationError()
97+
elif i == 0:
98+
typ = expr_to_unanalyzed_type(arg, expr)
99+
elif i == 1:
100+
name = _extract_argument_name(arg)
101+
else:
102+
raise TypeTranslationError()
103+
return CallableArgument(typ, name, arg_const, expr.line, expr.column)
46104
elif isinstance(expr, ListExpr):
47-
return TypeList([expr_to_unanalyzed_type(t) for t in expr.items],
105+
return TypeList([expr_to_unanalyzed_type(t, expr) for t in expr.items],
48106
line=expr.line, column=expr.column)
49107
elif isinstance(expr, (StrExpr, BytesExpr, UnicodeExpr)):
50108
# Parse string literal type.

mypy/fastparse.py

+83-19
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension,
2020
SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument,
2121
AwaitExpr, TempNode, Expression, Statement,
22-
ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR2
22+
ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR2,
23+
check_arg_names,
2324
)
2425
from mypy.types import (
2526
Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType,
27+
CallableArgument,
2628
)
2729
from mypy import defaults
2830
from mypy import experiments
@@ -444,24 +446,12 @@ def make_argument(arg: ast3.arg, default: Optional[ast3.expr], kind: int) -> Arg
444446
new_args.append(make_argument(args.kwarg, None, ARG_STAR2))
445447
names.append(args.kwarg)
446448

447-
seen_names = set() # type: Set[str]
448-
for name in names:
449-
if name.arg in seen_names:
450-
self.fail("duplicate argument '{}' in function definition".format(name.arg),
451-
name.lineno, name.col_offset)
452-
break
453-
seen_names.add(name.arg)
449+
def fail_arg(msg: str, arg: ast3.arg) -> None:
450+
self.fail(msg, arg.lineno, arg.col_offset)
454451

455-
return new_args
452+
check_arg_names([name.arg for name in names], names, fail_arg)
456453

457-
def stringify_name(self, n: ast3.AST) -> str:
458-
if isinstance(n, ast3.Name):
459-
return n.id
460-
elif isinstance(n, ast3.Attribute):
461-
sv = self.stringify_name(n.value)
462-
if sv is not None:
463-
return "{}.{}".format(sv, n.attr)
464-
return None # Can't do it.
454+
return new_args
465455

466456
# ClassDef(identifier name,
467457
# expr* bases,
@@ -474,7 +464,7 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
474464
metaclass_arg = find(lambda x: x.arg == 'metaclass', n.keywords)
475465
metaclass = None
476466
if metaclass_arg:
477-
metaclass = self.stringify_name(metaclass_arg.value)
467+
metaclass = stringify_name(metaclass_arg.value)
478468
if metaclass is None:
479469
metaclass = '<error>' # To be reported later
480470

@@ -965,6 +955,21 @@ class TypeConverter(ast3.NodeTransformer): # type: ignore # typeshed PR #931
965955
def __init__(self, errors: Errors, line: int = -1) -> None:
966956
self.errors = errors
967957
self.line = line
958+
self.node_stack = [] # type: List[ast3.AST]
959+
960+
def visit(self, node: ast3.AST) -> Type:
961+
"""Modified visit -- keep track of the stack of nodes"""
962+
self.node_stack.append(node)
963+
try:
964+
return super().visit(node)
965+
finally:
966+
self.node_stack.pop()
967+
968+
def parent(self) -> ast3.AST:
969+
"""Return the AST node above the one we are processing"""
970+
if len(self.node_stack) < 2:
971+
return None
972+
return self.node_stack[-2]
968973

969974
def fail(self, msg: str, line: int, column: int) -> None:
970975
self.errors.report(line, column, msg)
@@ -985,6 +990,55 @@ def visit_NoneType(self, n: Any) -> Type:
985990
def translate_expr_list(self, l: Sequence[ast3.AST]) -> List[Type]:
986991
return [self.visit(e) for e in l]
987992

993+
def visit_Call(self, e: ast3.Call) -> Type:
994+
# Parse the arg constructor
995+
if not isinstance(self.parent(), ast3.List):
996+
return self.generic_visit(e)
997+
f = e.func
998+
constructor = stringify_name(f)
999+
if not constructor:
1000+
self.fail("Expected arg constructor name", e.lineno, e.col_offset)
1001+
name = None # type: Optional[str]
1002+
default_type = AnyType(implicit=True)
1003+
typ = default_type # type: Type
1004+
for i, arg in enumerate(e.args):
1005+
if i == 0:
1006+
typ = self.visit(arg)
1007+
elif i == 1:
1008+
name = self._extract_argument_name(arg)
1009+
else:
1010+
self.fail("Too many arguments for argument constructor",
1011+
f.lineno, f.col_offset)
1012+
for k in e.keywords:
1013+
value = k.value
1014+
if k.arg == "name":
1015+
if name is not None:
1016+
self.fail('"{}" gets multiple values for keyword argument "name"'.format(
1017+
constructor), f.lineno, f.col_offset)
1018+
name = self._extract_argument_name(value)
1019+
elif k.arg == "type":
1020+
if typ is not default_type:
1021+
self.fail('"{}" gets multiple values for keyword argument "type"'.format(
1022+
constructor), f.lineno, f.col_offset)
1023+
typ = self.visit(value)
1024+
else:
1025+
self.fail(
1026+
'Unexpected argument "{}" for argument constructor'.format(k.arg),
1027+
value.lineno, value.col_offset)
1028+
return CallableArgument(typ, name, constructor, e.lineno, e.col_offset)
1029+
1030+
def translate_argument_list(self, l: Sequence[ast3.AST]) -> TypeList:
1031+
return TypeList([self.visit(e) for e in l], line=self.line)
1032+
1033+
def _extract_argument_name(self, n: ast3.expr) -> str:
1034+
if isinstance(n, ast3.Str):
1035+
return n.s.strip()
1036+
elif isinstance(n, ast3.NameConstant) and str(n.value) == 'None':
1037+
return None
1038+
self.fail('Expected string literal for argument name, got {}'.format(
1039+
type(n).__name__), self.line, 0)
1040+
return None
1041+
9881042
def visit_Name(self, n: ast3.Name) -> Type:
9891043
return UnboundType(n.id, line=self.line)
9901044

@@ -1036,4 +1090,14 @@ def visit_Ellipsis(self, n: ast3.Ellipsis) -> Type:
10361090

10371091
# List(expr* elts, expr_context ctx)
10381092
def visit_List(self, n: ast3.List) -> Type:
1039-
return TypeList(self.translate_expr_list(n.elts), line=self.line)
1093+
return self.translate_argument_list(n.elts)
1094+
1095+
1096+
def stringify_name(n: ast3.AST) -> Optional[str]:
1097+
if isinstance(n, ast3.Name):
1098+
return n.id
1099+
elif isinstance(n, ast3.Attribute):
1100+
sv = stringify_name(n.value)
1101+
if sv is not None:
1102+
return "{}.{}".format(sv, n.attr)
1103+
return None # Can't do it.

mypy/fastparse2.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
UnaryExpr, LambdaExpr, ComparisonExpr, DictionaryComprehension,
3434
SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument,
3535
Expression, Statement, BackquoteExpr, PrintStmt, ExecStmt,
36-
ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2, OverloadPart,
36+
ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2, OverloadPart, check_arg_names,
3737
)
3838
from mypy.types import (
3939
Type, CallableType, AnyType, UnboundType, EllipsisType
@@ -439,12 +439,10 @@ def get_type(i: int) -> Optional[Type]:
439439
new_args.append(Argument(Var(n.kwarg), typ, None, ARG_STAR2))
440440
names.append(n.kwarg)
441441

442-
seen_names = set() # type: Set[str]
443-
for name in names:
444-
if name in seen_names:
445-
self.fail("duplicate argument '{}' in function definition".format(name), line, 0)
446-
break
447-
seen_names.add(name)
442+
# We don't have any context object to give, but we have closed around the line num
443+
def fail_arg(msg: str, arg: None) -> None:
444+
self.fail(msg, line, 0)
445+
check_arg_names(names, [None] * len(names), fail_arg)
448446

449447
return new_args, decompose_stmts
450448

mypy/indirection.py

+3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ def visit_unbound_type(self, t: types.UnboundType) -> Set[str]:
4545
def visit_type_list(self, t: types.TypeList) -> Set[str]:
4646
return self._visit(*t.items)
4747

48+
def visit_callable_argument(self, t: types.CallableArgument) -> Set[str]:
49+
return self._visit(t.typ)
50+
4851
def visit_any(self, t: types.AnyType) -> Set[str]:
4952
return set()
5053

mypy/messages.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
ARG_OPT: "DefaultArg",
9494
ARG_NAMED: "NamedArg",
9595
ARG_NAMED_OPT: "DefaultNamedArg",
96-
ARG_STAR: "StarArg",
96+
ARG_STAR: "VarArg",
9797
ARG_STAR2: "KwArg",
9898
}
9999

@@ -214,15 +214,15 @@ def format(self, typ: Type, verbosity: int = 0) -> str:
214214
verbosity = max(verbosity - 1, 0))))
215215
else:
216216
constructor = ARG_CONSTRUCTOR_NAMES[arg_kind]
217-
if arg_kind in (ARG_STAR, ARG_STAR2):
217+
if arg_kind in (ARG_STAR, ARG_STAR2) or arg_name is None:
218218
arg_strings.append("{}({})".format(
219219
constructor,
220220
strip_quotes(self.format(arg_type))))
221221
else:
222-
arg_strings.append("{}('{}', {})".format(
222+
arg_strings.append("{}({}, {})".format(
223223
constructor,
224-
arg_name,
225-
strip_quotes(self.format(arg_type))))
224+
strip_quotes(self.format(arg_type)),
225+
repr(arg_name)))
226226

227227
return 'Callable[[{}], {}]'.format(", ".join(arg_strings), return_type)
228228
else:

0 commit comments

Comments
 (0)