@@ -96,9 +96,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -107,6 +107,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision

     def get_deps(
@@ -226,8 +227,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

@@ -305,8 +309,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, [])

@@ -412,9 +419,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])

@@ -501,11 +510,11 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users

-        # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+        # When using force_non_static_weights_for_f32_linear, we want to get_deps to overwrite the source partition nodes.
         # Else we want to be greedy.
         ret_deps = (
             list(set(deps) & set(src_partition.nodes))
-            if self.force_fp32_dynamic_linear
+            if self.force_non_static_weights_for_f32_linear
             else list(set(deps) | set(src_partition.nodes))
         )

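For readers skimming the `ret_deps` hunk above, here is a minimal, self-contained sketch of the intersection-vs-union behavior the renamed flag selects between; the node names are made up for illustration and stand in for the real `torch.fx.Node` objects:

```python
# Illustration only: strings stand in for torch.fx.Node objects coming from
# get_deps and from the source partition.
deps = {"linear", "weight", "weight_dequant"}
src_partition_nodes = {"linear", "weight_dequant", "input_activation"}

# force_non_static_weights_for_f32_linear enabled: restrict deps to the source
# partition's nodes (intersection), so get_deps overwrites the partition nodes.
restricted = deps & src_partition_nodes
# {'linear', 'weight_dequant'}

# Flag disabled: be greedy and keep everything from both sets (union).
greedy = deps | src_partition_nodes
# {'linear', 'weight', 'weight_dequant', 'input_activation'}
```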
@@ -531,8 +540,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])

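Assuming the renamed option is still forwarded to the GEMM configs through the partitioner's keyword arguments (as `force_fp32_dynamic_linear` was before this change), caller-side usage would look roughly like the sketch below; treat the exact constructor signature as an assumption, not something shown in this diff:

```python
# Sketch under the assumption that XnnpackPartitioner forwards extra keyword
# arguments down to the per-op GEMM configs.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Old spelling (removed by this change):
#   XnnpackPartitioner(force_fp32_dynamic_linear=True)
partitioner = XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
```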