-
Notifications
You must be signed in to change notification settings - Fork 528
/
Copy pathreplace_aten_with_edge_pass.py
59 lines (47 loc) · 1.79 KB
/
replace_aten_with_edge_pass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import torch
from executorch.exir.dialects._ops import ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
from torch.fx.node import Target
DISALLOW_LIST = [
torch.ops.aten._assert_scalar.default,
torch.ops.aten._assert_async.msg,
torch.ops.aten.scalar_tensor.default,
]
def aten_to_edge(aten_op: torch._ops.OpOverload) -> EdgeOpOverload:
# Assume qualified op name: aten::add.Tensor
op_namespace, op_name, op_overload_name = (
aten_op.namespace,
aten_op._schema.name.split("::")[1],
aten_op._overloadname,
)
edge_op = getattr(
getattr(getattr(ops.edge, op_namespace), op_name), op_overload_name
)
return edge_op
def should_lower_to_edge(op: Target) -> bool:
"""Returns true if the given operator should be lowered to edge op."""
return (
isinstance(op, torch._ops.OpOverload)
and op not in _EXECUTORCH_SYM_OPS
and op not in DISALLOW_LIST
)
class OpReplacePass(ExportPass):
"""
Goes through all ops and replaces torch (aten + custom) ops with edge ops.
Exclude those ops that don't care about input dtypes and out variants.
"""
def __init__(self) -> None:
super().__init__()
def call_operator(self, op, args, kwargs, meta):
if should_lower_to_edge(op):
return super().call_operator(aten_to_edge(op), args, kwargs, meta)
else:
return super().call_operator(op, args, kwargs, meta)