-
Notifications
You must be signed in to change notification settings - Fork 528
/
Copy pathview_copy_to_squeeze_unsqueeze.py
135 lines (120 loc) · 4.64 KB
/
view_copy_to_squeeze_unsqueeze.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import List, Optional, Union
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
class ViewCopyToSqueezeUnsqueezePass(ExportPass):
    """
    Replaces view_copy nodes with squeeze_copy.dims nodes if the view node
    reduces dims of size 1.

    Replaces view_copy nodes with unsqueeze_copy.default nodes if the view
    node adds a single dim of size 1.
    """

    def __init__(self) -> None:
        super().__init__()
        # Edge-dialect overloads: the op this pass matches (view_copy) and
        # the two ops it may emit in its place.
        self.view_copy_op: torch._ops.OpOverload = exir_ops.edge.aten.view_copy.default
        self.squeeze_op: torch._ops.OpOverload = exir_ops.edge.aten.squeeze_copy.dims
        self.unsqueeze_op: torch._ops.OpOverload = (
            exir_ops.edge.aten.unsqueeze_copy.default
        )

    def is_node_target(
        self, node: torch.fx.Node, target: torch._ops.OperatorBase
    ) -> bool:
        """Return True if `node` is a call_function node calling `target`."""
        return node.op == "call_function" and node.target == target

    def find_squeeze_dims(
        self,
        input_shape: List[int],
        view_shape: List[int],
    ) -> Optional[List[int]]:
        """
        Check whether `view_shape` equals `input_shape` with some size-1
        dims removed.

        Returns:
            The indices (into `input_shape`) of the removed size-1 dims, or
            None if the shapes are not related by a squeeze.
        """
        # A squeeze must strictly reduce the rank.
        if len(input_shape) <= len(view_shape):
            return None
        i = 0  # cursor into input_shape
        j = 0  # cursor into view_shape
        removed: List[int] = []
        while i < len(input_shape) and j < len(view_shape):
            if input_shape[i] == view_shape[j]:
                i += 1
                j += 1
            elif input_shape[i] == 1:
                # Size-1 dim absent from view_shape: squeeze it out.
                removed.append(i)
                i += 1
            else:
                # Mismatch that removing a size-1 dim cannot explain.
                return None
        # Trailing size-1 input dims can also be squeezed (e.g. [2, 1] -> [2]);
        # the previous implementation missed these and skipped the rewrite.
        while i < len(input_shape):
            if input_shape[i] != 1:
                return None
            removed.append(i)
            i += 1
        # Leftover view dims mean the shapes do not correspond.
        if j < len(view_shape):
            return None
        return removed

    def find_unsqueeze_dim(
        self,
        input_shape: List[int],
        view_shape: List[int],
    ) -> Optional[int]:
        """
        Check whether `view_shape` equals `input_shape` with exactly one
        size-1 dim inserted.

        Returns:
            The index of the inserted dim (-1 denotes insertion after the
            last dim, which matches unsqueeze's negative-dim semantics), or
            None if the shapes are not related by a single unsqueeze.
        """
        # An unsqueeze increases the rank by exactly one.
        if len(view_shape) - len(input_shape) != 1:
            return None
        i = 0  # cursor into input_shape
        j = 0  # cursor into view_shape
        idx = -1  # -1 == "append as last dim"; a valid dim arg for unsqueeze
        found = False  # whether the inserted size-1 dim has been located
        while j < len(view_shape):
            if i == j and j == len(input_shape):
                # All input dims matched; the final view dim must be the
                # inserted 1.
                if view_shape[j] != 1:
                    return None
            elif input_shape[i] != view_shape[j]:
                if view_shape[j] == 1 and not found:
                    # The (single allowed) inserted dim: record it and hold
                    # the input cursor in place for the next comparison.
                    found = True
                    idx = j
                    i -= 1
                else:
                    # Either a genuine size mismatch, or a second insertion.
                    # Enforcing a single insertion fixes a miscompile: e.g.
                    # [2, 1] -> [1, 1, 2] previously returned 1, but
                    # unsqueeze([2, 1], 1) yields [2, 1, 1], not [1, 1, 2].
                    return None
            i += 1
            j += 1
        return idx

    def replace_view_copy_node(
        self,
        graph_module: torch.fx.GraphModule,
        view_node: torch.fx.Node,
        op: torch._ops.OpOverload,
        arg: Union[List[int], int],
    ) -> None:
        """
        Replace `view_node` with a call to `op` on the same input tensor,
        passing `arg` (squeeze dims list or unsqueeze dim) as the second
        argument, then erase the old node.
        """
        with graph_module.graph.inserting_before(view_node):
            new_node = graph_module.graph.create_node(
                "call_function",
                op,
                (view_node.args[0], arg),
            )
            # Carry over metadata (e.g. "val") so downstream passes still
            # see the same shape/dtype info on the replacement node.
            new_node.meta = view_node.meta
            view_node.replace_all_uses_with(new_node)
            graph_module.graph.erase_node(view_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite every eligible view_copy node in `graph_module`.

        Returns:
            A PassResult whose `modified` flag reflects whether any rewrite
            happened.
        """
        modified = False
        for node in graph_module.graph.nodes:
            if self.is_node_target(node, self.view_copy_op):
                input_node = node.args[0]
                input_shape = input_node.meta["val"].shape
                # NOTE(review): assumes node.args[1] holds concrete sizes; a
                # -1 placeholder would simply fail to match and be skipped.
                view_shape = node.args[1]
                squeeze_dims = self.find_squeeze_dims(input_shape, view_shape)
                if squeeze_dims is not None:
                    self.replace_view_copy_node(
                        graph_module, node, self.squeeze_op, squeeze_dims
                    )
                    modified = True
                    continue
                unsqueeze_dim = self.find_unsqueeze_dim(input_shape, view_shape)
                # Compare against None: 0 is a valid unsqueeze dim (e.g.
                # [3] -> [1, 3]) but is falsy, so a bare truthiness test
                # would silently skip that rewrite.
                if unsqueeze_dim is not None:
                    self.replace_view_copy_node(
                        graph_module, node, self.unsqueeze_op, unsqueeze_dim
                    )
                    modified = True
                    continue
        if modified:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)