33
33
from executorch .backends .cadence .aot .utils import get_edge_overload_packet
34
34
from executorch .backends .transforms .remove_clone_ops import RemoveCloneOpsTransform
35
35
from executorch .exir .dialects ._ops import ops as exir_ops
36
- from executorch .exir .dialects .edge ._ops import EdgeOpOverload
36
+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload , EdgeOpOverloadPacket
37
37
from executorch .exir .pass_base import ExportPass , NodeMetadata , PassResult , ProxyValue
38
38
from executorch .exir .pass_manager import PassManager , PassType
39
39
from executorch .exir .passes import dead_code_elimination_pass
@@ -745,6 +745,68 @@ def permute_shape(
745
745
return [shape [p ] for p in permute_dims ]
746
746
747
747
748
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveBranchedQuantDequant(ExportPass):
    """
    Bypass the second op of a quant/dequant pair whose first op is branched.

    This pass looks for adjacent quant and dequant nodes with identical
    parameters, where the first node of the pair has other users in addition
    to the second. The pair would be removed by the
    FuseQuantDequantToRequantizePass if not for the multiple users. This pass
    removes just the second node by rerouting its users to the first node's
    input; the first node stays in place for its remaining users.

    Both directions are handled: quantize followed by dequantize, and
    dequantize followed by quantize.
    """

    # Overload packets treated as "quantize" ops.
    quantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.quantize_per_tensor,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor,
    }
    # Overload packets treated as "dequantize" ops.
    dequantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.dequantize_per_tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
    }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the rewrite in both directions, drop dead nodes, then defer to
        ExportPass.call on the modified graph module.
        """
        # quant -> dequant, with the quant node branched
        self.remove_branched(
            graph_module, self.quantize_op_packets, self.dequantize_op_packets
        )
        # dequant -> quant, with the dequant node branched
        self.remove_branched(
            graph_module, self.dequantize_op_packets, self.quantize_op_packets
        )

        # The bypassed consumer nodes no longer have users; remove them.
        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)

    def remove_branched(
        self,
        graph_module: torch.fx.GraphModule,
        producer_pkts: set[EdgeOpOverloadPacket],
        consumer_pkts: set[EdgeOpOverloadPacket],
    ) -> None:
        """
        Reroute users of matching consumer nodes to the producer's input.

        For every call_function node whose overload packet is in
        ``producer_pkts`` and which has at least two users, each user whose
        overload packet is in ``consumer_pkts`` and whose trailing arguments
        (everything after the first) equal the producer's is bypassed: all
        uses of that user are replaced with the producer's first argument.
        The graph is modified in place; bypassed nodes are not erased here.
        """
        for node in graph_module.graph.nodes:
            if (
                node.op != "call_function"
                or not isinstance(node.target, EdgeOpOverload)
                or get_edge_overload_packet(node.target) not in producer_pkts
            ):
                continue

            # Only branched producers are handled here; see the class
            # docstring for the single-user case.
            if len(node.users) < 2:
                continue

            # Iterate over a snapshot: replace_all_uses_with() rewrites the
            # args of downstream nodes, which re-registers them in the
            # `users` dict of each of their inputs. If one of those nodes
            # also consumes `node`, the live `node.users` dict would be
            # mutated while it is being iterated.
            for user in list(node.users):
                if (
                    not isinstance(user.target, EdgeOpOverload)
                    or get_edge_overload_packet(user.target) not in consumer_pkts
                ):
                    continue

                # check qparams match
                if node.args[1:] != user.args[1:]:
                    continue

                user.replace_all_uses_with(node.args[0])
808
+
809
+
748
810
# The following class consolidates functions to remove ops that are redundant
749
811
# in Jarvis. Currently, each function in this class iterates over each node of
750
812
# the graph module once. In future, we could consolidate them into a monolithic
@@ -765,4 +827,5 @@ class CadenceRemoveNops:
765
827
RemoveNopMulOpPass ,
766
828
RemoveNopAddOpPass ,
767
829
RemoveNopLinalgVectorNormOpPass ,
830
+ RemoveBranchedQuantDequant ,
768
831
]
0 commit comments