diff --git a/backends/xnnpack/operators/op_dynamic_dequantize_ops.py b/backends/xnnpack/operators/op_dynamic_dequantize_ops.py
index f8f0c54ee68..82a35236294 100644
--- a/backends/xnnpack/operators/op_dynamic_dequantize_ops.py
+++ b/backends/xnnpack/operators/op_dynamic_dequantize_ops.py
@@ -13,6 +13,7 @@
 )
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph
 from executorch.backends.xnnpack.utils.quant_utils import (
+    is_dynamic_qdq,
     is_per_channel_group,
     is_per_token,
 )
@@ -92,7 +93,8 @@ def define_node(
         """
         We always define dequantize affine nodes because they are always explicit
         """
-        if is_per_channel_group(node):
+        is_dynamic = is_dynamic_qdq(node)
+        if is_per_channel_group(node) and not is_dynamic:
             check_or_raise(
                 is_param_node(self._exported_program, node.all_input_nodes[0]),
                 f"Expected dequantize affine node with per-channel-group semantics to be used "
@@ -103,7 +105,7 @@
             return
 
         check_or_raise(
-            is_per_token(node),
+            is_per_token(node) and is_dynamic,
             "Expecting Affine Dequantized Op to have dynamic per-token semantics",
         )
         # This must be a per-token affine dequantized node, so let us serialize as such
diff --git a/backends/xnnpack/operators/op_dynamic_quantize_ops.py b/backends/xnnpack/operators/op_dynamic_quantize_ops.py
index 23047e731f7..9369f025216 100644
--- a/backends/xnnpack/operators/op_dynamic_quantize_ops.py
+++ b/backends/xnnpack/operators/op_dynamic_quantize_ops.py
@@ -18,6 +18,7 @@
     XNode,
 )
 from executorch.backends.xnnpack.utils.quant_utils import (
+    is_dynamic_qdq,
     is_per_channel_group,
     is_per_token,
 )
@@ -138,13 +139,14 @@ def define_node(
         """
         We always define quantize affine nodes because they are always explicit
         """
-        if is_per_channel_group(node):
+        is_dynamic = is_dynamic_qdq(node)
+        if is_per_channel_group(node) and not is_dynamic:
             # The affine quantized op was recognized as per-channel-group, which
             # means it is used in front of a weight node and should be skipped
             return
 
         check_or_raise(
-            is_per_token(node),
+            is_per_token(node) and is_dynamic,
             "Encountered affine quantized op which does not have dynamic per-token semantics",
         )
         # Treat this node as dynamic per-token quantization
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index 690a1109a17..fec6005d706 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -645,31 +645,32 @@ def _test_qd8_per_token_weight_per_channel_group_int4(
         bl_sizes = [32, 32, 32, 64]
         N_sizes = [2, 17, 92, 128]
 
-        for use_bias in [True, False]:
-            for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
-                lin_mod = BaseLinear(
-                    in_size=M,
-                    input_channels=K,
-                    output_channels=N,
-                    dtype=dtype,
-                    use_bias=use_bias,
-                )
+        for input_rank in range(2, 4):
+            for use_bias in [True, False]:
+                for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
+                    lin_mod = BaseLinear(
+                        in_size=M,
+                        input_channels=K,
+                        output_channels=N,
+                        dtype=dtype,
+                        use_bias=use_bias,
+                    )
 
-                inputs = lin_mod.get_inputs()
-                # Half requires slightly higher atol, but if you look at error it is not that bad:
-                # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
-                # -- Model vs. Reference --
-                # Numel: 4, 4
-                # Median: -0.05023193359375, -0.0516357421875
-                # Mean: 0.2373046875, 0.237060546875
-                # Max: 1.0078125, 1.0078125
-                # Min: -0.08465576171875, -0.08441162109375
-                atol = (
-                    1e-2 if dtype == torch.half else 5e-3
-                )  # TODO(T212995726): Investigate right atol for rand[n] inputs
-                self._test_groupwise_dq_linear(
-                    lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=atol
-                )
+                    inputs = lin_mod.get_inputs(rank=input_rank)
+                    # Half requires a slightly higher atol, but if you look at the error it is not that bad:
+                    # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
+                    # -- Model vs. Reference --
+                    # Numel: 4, 4
+                    # Median: -0.05023193359375, -0.0516357421875
+                    # Mean: 0.2373046875, 0.237060546875
+                    # Max: 1.0078125, 1.0078125
+                    # Min: -0.08465576171875, -0.08441162109375
+                    atol = (
+                        1e-2 if dtype == torch.half else 5e-3
+                    )  # TODO(T212995726): Investigate right atol for rand[n] inputs
+                    self._test_groupwise_dq_linear(
+                        lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=atol
+                    )
 
     def test_fp16_linear(self):
         for use_bias in (True, False):
diff --git a/backends/xnnpack/utils/quant_utils.py b/backends/xnnpack/utils/quant_utils.py
index 49c5a963161..7e3df9097cc 100644
--- a/backends/xnnpack/utils/quant_utils.py
+++ b/backends/xnnpack/utils/quant_utils.py
@@ -47,12 +47,30 @@
 
 def is_dynamic_qdq(node: torch.fx.Node) -> bool:
-    if node.op != "call_function":
+    # the node must be a quantize or dequantize op
+    if not (is_quant(node) or is_dequant(node)):
+        return False
+
+    # its scale and zp must be chosen dynamically, i.e. produced at runtime
+    # by a choose_qparams node and read out through operator.getitem
+    node_input_args = node.args
+    if is_affine_qdq(node):
+        node_input_args = extract_qdq_affine_op_args_for_decomposed_ops(node)
+
+    scale = node_input_args[1]
+    zp = node_input_args[2]
+    if not (isinstance(scale, torch.fx.Node) and isinstance(zp, torch.fx.Node)):
+        return False
+
+    if not (scale.target == operator.getitem and zp.target == operator.getitem):
+        return False
+
+    scale_choose_qparam = scale.all_input_nodes[0]
+    zp_choose_qparam = zp.all_input_nodes[0]
+
+    if not (is_qparam(scale_choose_qparam) and is_qparam(zp_choose_qparam)):
         return False
 
-    node_name = format_target_name(node.target.__name__)  # pyre-ignore
-    is_dynamic_affine = is_per_token(node) and not is_per_channel_group(node)
-
-    return node_name in _DYNAMIC_OPS or is_dynamic_affine
+    return True
 
 
 def is_qparam(node: torch.fx.Node) -> bool:
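
For context, the rewritten `is_dynamic_qdq` keys off the graph pattern that dynamic quantization leaves behind: the q/dq node's scale and zero point are not compile-time constants but `operator.getitem` projections of a `choose_qparams`-style node. Below is a minimal sketch of that pattern and of the same walk the helper performs. The `choose_qparams`/`quantize` callables are illustrative stand-ins for the real `quantized_decomposed` ops, and the producer check is simplified to an identity comparison where the helper calls `is_qparam`.

```python
import operator

import torch
import torch.fx as fx


def choose_qparams(x):
    # stand-in: scale/zero-point depend on the runtime input
    return x.abs().max() / 127.0, torch.tensor(0)


def quantize(x, scale, zp):
    # stand-in for a quantize-style op
    return torch.clamp(torch.round(x / scale) + zp, -128, 127)


g = fx.Graph()
x = g.placeholder("x")
qp = g.call_function(choose_qparams, (x,))          # qparams computed at runtime
scale = g.call_function(operator.getitem, (qp, 0))  # scale = qp[0]
zp = g.call_function(operator.getitem, (qp, 1))     # zp = qp[1]
q = g.call_function(quantize, (x, scale, zp))
g.output(q)

# The same walk the helper does: a q/dq node is "dynamic" iff its scale and
# zp args are getitem nodes whose producer is a choose_qparams-style node.
scale_arg, zp_arg = q.args[1], q.args[2]
is_dynamic = (
    isinstance(scale_arg, fx.Node)
    and isinstance(zp_arg, fx.Node)
    and scale_arg.target is operator.getitem
    and zp_arg.target is operator.getitem
    and scale_arg.all_input_nodes[0] is qp  # helper checks is_qparam(...) here
    and zp_arg.all_input_nodes[0] is qp
)
print(is_dynamic)  # True

# Static quantization fails the very first test: scale is a float constant,
# not a graph node.
q_static = g.call_function(quantize, (x, 0.05, 0))
print(isinstance(q_static.args[1], fx.Node))  # False
```

The per-channel-group (weight) path fails the same tests because its scales arrive as constant parameters rather than as `getitem` reads of a qparam node, which is consistent with the `is_per_channel_group(node) and not is_dynamic` guards above.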