
Commit 1a9a59b

[ExecuTorch][XNNPACK] Rename linear weight partitioning flag for clarity
Pull Request resolved: #8892
Differential Revision: [D70372220](https://our.internmc.facebook.com/intern/diff/D70372220/)
ghstack-source-id: 269599293
Co-authored-by: Digant Desai <[email protected]>
1 parent ef73540 commit 1a9a59b
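For reference, a minimal usage sketch of the renamed flag (the keyword names are taken from this diff; the import path is assumed, and the surrounding export/lowering steps are omitted):

# A sketch, not part of this commit: enabling the renamed flag on the XNNPACK partitioner.
# The import path below is an assumption; the keyword names come from the diff itself.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Before this commit the flag was spelled force_fp32_dynamic_linear:
#   partitioner = XnnpackPartitioner(force_fp32_dynamic_linear=True)
# After this commit the partitioner config reads only the new keyword:
partitioner = XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)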

4 files changed: +35, -19 lines changed

backends/xnnpack/partition/config/gemm_configs.py (+25, -13)

@@ -96,9 +96,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -107,6 +107,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                 precision = ConfigPrecisionType.FP32
                 logging.info(f"Overwriting precision, partitioning {node} as FP32")
                 return True, precision
+
         return False, precision

     def get_deps(
@@ -226,8 +227,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

@@ -305,8 +309,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, [])

@@ -412,9 +419,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])

@@ -501,11 +510,11 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users

-            # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+            # When using force_non_static_weights_for_f32_linear, we want to get_deps to overwrite the source partition nodes.
             # Else we want to be greedy.
             ret_deps = (
                 list(set(deps) & set(src_partition.nodes))
-                if self.force_fp32_dynamic_linear
+                if self.force_non_static_weights_for_f32_linear
                 else list(set(deps) | set(src_partition.nodes))
             )

@@ -531,8 +540,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])

backends/xnnpack/partition/config/xnnpack_config.py (+3, -1)

@@ -41,7 +41,9 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
+        self.force_non_static_weights_for_f32_linear = kwargs.get(
+            "force_non_static_weights_for_f32_linear", False
+        )

     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
backends/xnnpack/test/ops/test_linear.py (+2, -2)

@@ -948,7 +948,7 @@ def test_linear_qd8_as_fp32(self):
             },
         )

-    def test_linear_fp32_with_force_as_mm(self):
+    def test_linear_with_force_non_static_weights_for_f32_linear(self):
         def check_signature(
             signature: ExportGraphSignature,
             force_flag: bool,
@@ -981,7 +981,7 @@ def check_signature(
             inputs = module.get_inputs()
             tester = Tester(module, inputs).export()
             partitioner = XnnpackPartitioner(
-                force_fp32_dynamic_linear=force_flag
+                force_non_static_weights_for_f32_linear=force_flag
             )
             if legacy_mode:
                 tester.to_edge()

backends/xnnpack/test/ops/test_lstm.py (+5, -3)

@@ -43,18 +43,20 @@ def test_fp32_lstm(self):
             .run_method_and_compare_outputs()
         )

-    def test_fp32_lstm_force_dynamic_linear(self):
+    def test_lstm_with_force_non_static_weights_for_f32_linear(self):
         (
             Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))
             .export()
             .to_edge_transform_and_lower(
                 ToEdgeTransformAndLower(
-                    partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
+                    partitioners=[
+                        XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
+                    ]
                 )
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            # Biases are not owned by delegates when force_non_static_weights_for_f32_linear is set
             .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()
