# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from typing import Any, Dict, Optional, Sequence, Tuple

import torch
from executorch.exir.dialects.edge._ops import EdgeDialectFunctionSchema, EdgeOpOverload
from executorch.exir.emit._emitter import _Argument, _Target
from executorch.exir.error import ExportError, InternalError
from torch._ops import HigherOrderOperator


class RunHigherOrderOperatorError(Exception):
    """
    Raised when we try to run a delegate or other HigherOrderOperator in a graph module.
    E.g., %executorch_call_delegate : [#users=1] = call_function[
        target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %arg0_1), kwargs = {})
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)


# pyre-ignore[13]: Attribute `node` is never initialized.
class EdgeOpArgValidator(torch.fx.Interpreter):
    """
    Validates that every Tensor argument passed to an operator has an allowed dtype.
    Expects every operator to be an EdgeOpOverload, which carries the allowed-dtype
    information. Violating operators are recorded in self.violating_ops.
    """

    node: torch.fx.Node

    def __init__(self, graph_module: torch.fx.GraphModule) -> None:
        super().__init__(graph_module)
        self.violating_ops: Dict[
            EdgeOpOverload, Tuple[Dict[str, Optional[torch.dtype]], torch.fx.Node]
        ] = defaultdict(dict)
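
    # Interpreter hook: record the node currently being executed so that
    # call_function can read its meta, and wrap unexpected failures in
    # InternalError while letting known verifier errors propagate unchanged.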
    def run_node(self, n: torch.fx.Node) -> Any:
        self.node = n
        try:
            ret = super().run_node(n)
        except Exception as e:
            if isinstance(e, (InternalError, ExportError, RunHigherOrderOperatorError)):
                raise e
            else:
                raise InternalError(str(e)) from e
        return ret
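
    # Resolve the runtime value for a schema argument: an explicit keyword
    # argument wins, then a positional argument at the schema index, then the
    # schema-declared default value.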
    def _get_kernel_arg(self, schema_arg, schema_arg_idx, args, kwargs):
        if schema_arg.name in kwargs:
            kernel_arg = kwargs[schema_arg.name]
        elif not schema_arg.kwarg_only and schema_arg_idx < len(args):
            kernel_arg = args[schema_arg_idx]
        else:
            kernel_arg = schema_arg.default_value
        return kernel_arg
    def call_function(  # noqa: C901  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> Any:
        """
        Go through each node target and validate that its Tensor arguments have
        allowed dtypes.
        """
        if not isinstance(target, EdgeOpOverload) or not isinstance(
            target._schema, EdgeDialectFunctionSchema
        ):
            if isinstance(target, HigherOrderOperator):
                raise RunHigherOrderOperatorError("Can't run delegate")
            return super().call_function(target, args, kwargs)  # pyre-fixme[6]

        # TODO(gasoonjia): Update Optional[torch.dtype] to a concrete class to support mixed dtypes in tensorlist.
        tensor_arg_types: Dict[str, Optional[torch.dtype]] = {}
        for i, schema_arg in enumerate(target._schema.arguments):
            if (
                isinstance(schema_arg.type, torch.TensorType)
                or schema_arg.type == torch.OptionalType.ofTensor()
            ):
                kernel_arg = self._get_kernel_arg(schema_arg, i, args, kwargs)
                if not isinstance(kernel_arg, torch.Tensor):
                    continue
                tensor_arg_types[schema_arg.name] = kernel_arg.dtype
            elif schema_arg.type == torch.ListType.ofTensors():
                kernel_arg = self._get_kernel_arg(schema_arg, i, args, kwargs)
                if not isinstance(kernel_arg, list) or not all(
                    isinstance(kernel_arg[i], torch.Tensor)
                    for i in range(len(kernel_arg))
                ):
                    continue
                if len(kernel_arg):
                    tensor_arg_types[schema_arg.name] = kernel_arg[0].dtype
                else:
                    # If kernel_arg is an empty list, treat its type as None.
                    # FunctionDtypeConstraint.validate will take None as any legal dtype.
                    tensor_arg_types[schema_arg.name] = None
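
        # Return dtypes participate in the constraint check as well: record the
        # dtype of each Tensor (or Tensor-list) return, keyed by the schema
        # return name or a synthesized __ret_{i} name.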
        ret_index = 0
        kernel_rets = self.node.meta["val"]
        ret_iter = iter(
            kernel_rets if isinstance(kernel_rets, Sequence) else [kernel_rets]
        )
        for schema_ret in target._schema.returns:
            name = schema_ret.name if schema_ret.name else f"__ret_{ret_index}"
            kernel_ret = next(ret_iter)
            if isinstance(schema_ret.type, torch.TensorType):
                if isinstance(kernel_ret, torch.Tensor):
                    tensor_arg_types[name] = kernel_ret.dtype
                    ret_index += 1
                # Exceptionally rarely (basically only backward ops) you might see an
                # optional Tensor returned. The schema of these ops, though, is typically
                # -> (Tensor, Tensor, ...), so the actual value returned in C++ is an
                # empty/undefined tensor. There is no analog to this in Python, so it
                # gets crudely mapped to None. To properly fix this, core PyTorch would
                # have to change the schema to (Tensor?, ...), which is just never going
                # to happen. So we have to handle this case here in the verifier and in
                # memory planning as well.
                elif kernel_ret is None:
                    tensor_arg_types[name] = schema_ret.default_value
                    ret_index += 1
                else:
                    raise InternalError(
                        f"encountered return with type Tensor but value wasn't a tensor or None. schema:{target._schema}, output:{ret_index}"
                    )
            elif schema_ret.type == torch.ListType.ofTensors() and all(
                isinstance(kernel_ret[i], torch.Tensor) for i in range(len(kernel_ret))
            ):
                if len(kernel_ret):
                    tensor_arg_types[name] = kernel_ret[0].dtype
                else:
                    tensor_arg_types[name] = None
                ret_index += 1

        valid = target._schema.dtype_constraint.validate(tensor_arg_types)
        if not valid:
            self.violating_ops[target] = (tensor_arg_types, self.node)
        return super().call_function(target, args, kwargs)  # pyre-fixme[6]
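

# A minimal usage sketch, not part of the original module: export a toy model
# through the (assumed) standard `to_edge` flow, interpret its graph with
# EdgeOpArgValidator, and report any dtype violations. The `Add` module and
# the example inputs below are illustrative assumptions; in practice this
# validator is invoked internally during edge-program verification.
if __name__ == "__main__":
    from executorch.exir import to_edge

    class Add(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

    example_inputs = (torch.ones(2), torch.ones(2))
    edge = to_edge(torch.export.export(Add(), example_inputs))
    gm = edge.exported_program().graph_module

    validator = EdgeOpArgValidator(gm)
    validator.run(*example_inputs)  # populates validator.violating_ops
    for op, (arg_dtypes, node) in validator.violating_ops.items():
        print(f"dtype violation in {op}: {arg_dtypes}")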