Skip to content

Commit ef73540

Browse files
authored
Add a pass to remove certain redundant branched quant/dequant nodes
Differential Revision: D69947096 Pull Request resolved: #8896
1 parent 6e09ea2 commit ef73540

File tree

4 files changed

+109
-5
lines changed

4 files changed

+109
-5
lines changed

backends/cadence/aot/pass_utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,16 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target)
104104
return total
105105

106106

107+
def op_counts_match(
    graph_module: torch.fx.GraphModule,
    expected_op_counts: dict[EdgeOpOverload, int],
) -> bool:
    """Check whether the graph contains each listed op the expected number of times.

    Args:
        graph_module: Graph module whose nodes are counted via `count_node`.
        expected_op_counts: Mapping from an op to the number of occurrences
            expected for it in `graph_module`.

    Returns:
        True if every op in `expected_op_counts` has exactly the expected
        count (vacuously True for an empty mapping); False as soon as one
        count differs. Ops that are not keys of the mapping are not checked.
    """
    return all(
        count_node(graph_module, op) == expected
        for op, expected in expected_op_counts.items()
    )
115+
116+
107117
# Testing utils
108118
# Return the compute/function nodes in the graph
109119
def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:

backends/cadence/aot/remove_ops.py

+64-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
3434
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
3535
from executorch.exir.dialects._ops import ops as exir_ops
36-
from executorch.exir.dialects.edge._ops import EdgeOpOverload
36+
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
3737
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
3838
from executorch.exir.pass_manager import PassManager, PassType
3939
from executorch.exir.passes import dead_code_elimination_pass
@@ -745,6 +745,68 @@ def permute_shape(
745745
return [shape[p] for p in permute_dims]
746746

747747

748+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveBranchedQuantDequant(ExportPass):
    """
    Bypass a quant/dequant op that directly follows its counterpart when the
    first op of the pair fans out to several users.

    A quantize node immediately followed by a dequantize node with identical
    parameters (or the reverse order) would be removed by the
    FuseQuantDequantToRequantizePass if the first node did not have other
    users. This pass keeps the first node for those other users and rewires
    every use of the second node to the first node's input, leaving the
    second node dead so that it can be eliminated.
    """

    # Overload packets treated as "quantize" ops.
    quantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.quantize_per_tensor,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor,
    }
    # Overload packets treated as "dequantize" ops.
    dequantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.dequantize_per_tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
    }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run the rewrite in both directions, then clean up the graph.

        First quantize -> dequantize pairs are handled, then
        dequantize -> quantize pairs. Dead code is eliminated before the
        graph module is handed to the base class `call`.
        """
        quant_pkts = self.quantize_op_packets
        dequant_pkts = self.dequantize_op_packets
        for producers, consumers in (
            (quant_pkts, dequant_pkts),
            (dequant_pkts, quant_pkts),
        ):
            self.remove_branched(graph_module, producers, consumers)

        # The bypassed nodes no longer have users; drop them from the graph.
        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)

    def remove_branched(
        self,
        graph_module: torch.fx.GraphModule,
        producer_pkts: set[EdgeOpOverloadPacket],
        consumer_pkts: set[EdgeOpOverloadPacket],
    ) -> None:
        """Rewire users of matching consumer nodes to the producer's input.

        Args:
            graph_module: Graph module that is modified in place.
            producer_pkts: Overload packets identifying the first op of a pair.
            consumer_pkts: Overload packets identifying the second op of a
                pair, i.e. the one that gets bypassed.
        """
        for producer in graph_module.graph.nodes:
            is_producer = (
                producer.op == "call_function"
                and isinstance(producer.target, EdgeOpOverload)
                and get_edge_overload_packet(producer.target) in producer_pkts
            )
            # Only producers that fan out to at least two users are handled.
            if not is_producer or len(producer.users) < 2:
                continue

            for consumer in producer.users:
                is_consumer = (
                    isinstance(consumer.target, EdgeOpOverload)
                    and get_edge_overload_packet(consumer.target) in consumer_pkts
                )
                # Everything after the input tensor must agree between the
                # two nodes (the qparams) before the consumer is bypassed.
                if is_consumer and producer.args[1:] == consumer.args[1:]:
                    consumer.replace_all_uses_with(producer.args[0])
808+
809+
748810
# The following class consolidates functions to remove ops that are redundant
749811
# in Jarvis. Currently, each function in this class iterates over each node of
750812
# the graph module once. In future, we could consolidate them into a monolithic
@@ -765,4 +827,5 @@ class CadenceRemoveNops:
765827
RemoveNopMulOpPass,
766828
RemoveNopAddOpPass,
767829
RemoveNopLinalgVectorNormOpPass,
830+
RemoveBranchedQuantDequant,
768831
]

backends/cadence/aot/tests/test_fusion_ops_passes.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
FuseTransposeOpPairsPass,
2121
)
2222
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
23-
from executorch.backends.cadence.aot.pass_utils import count_node
23+
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
2424
from executorch.exir.dialects._ops import ops as exir_ops
2525
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2626
from torch import nn
@@ -32,8 +32,7 @@ def check_op_counts(
3232
graph_module: torch.fx.GraphModule,
3333
expected_op_counts: dict[EdgeOpOverload, int],
3434
) -> None:
35-
for op, count in expected_op_counts.items():
36-
self.assertEqual(count_node(graph_module, op), count)
35+
self.assertTrue(op_counts_match(graph_module, expected_op_counts))
3736

3837

3938
class TestFusionPasses(TestFusionPassesBase):

backends/cadence/aot/tests/test_remove_ops_passes.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717
from executorch.backends.cadence.aot import compiler
1818
from executorch.backends.cadence.aot.compiler import export_to_edge
1919

20-
from executorch.backends.cadence.aot.pass_utils import count_node
20+
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
2121
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
2222
from executorch.backends.cadence.aot.remove_ops import (
2323
RemoveAliasCopyOpPass,
24+
RemoveBranchedQuantDequant,
2425
RemoveCloneOpPass,
2526
RemoveContiguousOpPass,
2627
RemoveDetachCopyPass,
@@ -709,3 +710,34 @@ def forward(self, x):
709710
self.assertEqual(
710711
count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
711712
)
713+
714+
def test_remove_dequant_on_branch(self):
    """A dequantize hanging off a multi-user quantize node must be removed."""

    class M(torch.nn.Module):
        def forward(self, x):
            x = torch.abs(x)
            quantized = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, 1.2, 3, 0, 127, torch.int8
            )
            # First branch: a second user that keeps the quantize node alive.
            branch_a = torch.abs(quantized)
            # Second branch: a dequantize with the same qparams as the quantize.
            dequantized = torch.ops.quantized_decomposed.dequantize_per_tensor(
                quantized, 1.2, 3, 0, 127, torch.int8
            )
            branch_b = dequantized.view(-1)
            return branch_a, branch_b

    inputs = torch.rand(1, 8, 4, 6)
    edge_program = export_to_edge(M(), (inputs,))
    graph_module = edge_program.exported_program().graph_module

    pass_result = RemoveBranchedQuantDequant()(graph_module)
    graph_module = pass_result.graph_module

    expected_op_counts = {
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
        # we expect the pass to remove the dequantize node
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
        exir_ops.edge.aten.abs.default: 2,
    }
    self.assertTrue(
        op_counts_match(graph_module, expected_op_counts=expected_op_counts)
    )

0 commit comments

Comments
 (0)