-
Notifications
You must be signed in to change notification settings - Fork 527
/
Copy pathtest_cat.py
172 lines (153 loc) · 6.14 KB
/
test_cat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from typing import Tuple
import pytest
import torch
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized
class TestCat(unittest.TestCase):
    """Tests lowering of ``torch.cat`` through the Arm backend.

    Covers TOSA-0.80+MI, TOSA-0.80+BI, and Ethos-U55/U85 (BI) compile specs.
    """

    class Cat(torch.nn.Module):
        """Minimal module whose forward is a single ``torch.cat`` call."""

        # Each entry is ``(operands, dim)``: a tuple of tensors to concatenate
        # and the (possibly negative) dimension to concatenate along.
        # NOTE: entry order also fixes the order of the torch.randn calls made
        # at class-definition time; keep it stable.
        test_parameters = [
            ((torch.ones(1), torch.ones(1)), 0),
            ((torch.ones(1, 2), torch.randn(1, 5), torch.randn(1, 1)), 1),
            (
                (
                    torch.ones(1, 2, 5),
                    torch.randn(1, 2, 4),
                    torch.randn(1, 2, 2),
                    torch.randn(1, 2, 1),
                ),
                -1,
            ),
            ((torch.randn(1, 2, 4, 4), torch.randn(1, 2, 4, 1)), 3),
            ((torch.randn(1, 2, 4, 4), torch.randn(1, 2, 4, 4)), 0),
            ((torch.randn(2, 2, 4, 4), torch.randn(2, 2, 4, 1)), 3),
            (
                (
                    # One operand is scaled by 10000, so the inputs span very
                    # different value ranges.
                    10000 * torch.randn(2, 3, 1, 4),
                    torch.randn(2, 7, 1, 4),
                    torch.randn(2, 1, 1, 4),
                ),
                -3,
            ),
        ]

        def __init__(self):
            super().__init__()

        def forward(self, t: tuple[torch.Tensor, ...], dim: int) -> torch.Tensor:
            return torch.cat(t, dim=dim)

    # MLETORCH-630: Cat does not work on FVP with batch > 1. The last
    # _FVP_XFAIL_COUNT entries of Cat.test_parameters are those cases, so the
    # Ethos-U tests are split here once instead of repeating the slice.
    _FVP_XFAIL_COUNT = 3
    _fvp_supported_parameters = Cat.test_parameters[:-_FVP_XFAIL_COUNT]
    _fvp_xfail_parameters = Cat.test_parameters[-_FVP_XFAIL_COUNT:]

    def _test_cat_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int]
    ):
        """Lower ``module`` with the TOSA-0.80+MI spec and run it.

        The exported graph must contain exactly one ``aten.cat`` and no
        ``quantized_decomposed`` ops. After partitioning, no edge ``cat``
        may remain and exactly one delegate call must exist. The program is
        then run via ``run_method_and_compare_outputs``.

        Args:
            module: The module under test.
            test_data: ``(operands, dim)``, used both as example inputs and as
                runtime inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.cat.default": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_cat_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int]
    ):
        """Quantize and lower ``module`` with the TOSA-0.80+BI spec and run it.

        Same graph checks as the MI pipeline, except that
        ``quantized_decomposed`` ops must be present after export. Outputs
        are compared with ``qtol=1``.

        Args:
            module: The module under test.
            test_data: ``(operands, dim)``, used both as example inputs and as
                runtime inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.cat.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_cat_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[tuple[torch.Tensor, ...], int],
    ):
        """Quantize, lower and serialize ``module`` for an Ethos-U target.

        The graph checks match the BI pipeline. The serialized program is
        only executed and compared when the ``corstone_fvp`` option is
        enabled.

        Args:
            module: The module under test.
            compile_spec: Target compile spec (e.g. U55 or U85).
            test_data: ``(operands, dim)``, used both as example inputs and as
                runtime inputs.
        """
        tester = (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.cat.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
        )
        # Running requires the Corstone FVP; otherwise stop after serialization.
        if conftest.is_option_enabled("corstone_fvp"):
            tester.run_method_and_compare_outputs(inputs=test_data)

    @parameterized.expand(Cat.test_parameters)
    def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Run every parameter set through the TOSA MI pipeline."""
        test_data = (operands, dim)
        self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)

    def test_cat_4d_tosa_MI(self):
        """Concatenate two equal 4-D tensors along every legal dim (MI)."""
        square = torch.ones((2, 2, 2, 2))
        rank = square.dim()
        # Cover every legal dim index for this rank, negative and positive:
        # [-rank, rank).
        for dim in range(-rank, rank):
            # subTest records which dim failed. Under runners that support
            # subtests, it also lets the remaining dims still run.
            with self.subTest(dim=dim):
                test_data = ((square, square.clone()), dim)
                self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)

    @parameterized.expand(Cat.test_parameters)
    def test_cat_tosa_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Run every parameter set through the TOSA BI pipeline."""
        test_data = (operands, dim)
        self._test_cat_tosa_BI_pipeline(self.Cat(), test_data)

    @parameterized.expand(_fvp_supported_parameters)
    @pytest.mark.corstone_fvp
    def test_cat_u55_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Ethos-U55 BI on the cases expected to pass on the FVP."""
        test_data = (operands, dim)
        self._test_cat_ethosu_BI_pipeline(
            self.Cat(), common.get_u55_compile_spec(), test_data
        )

    # MLETORCH-630 Cat does not work on FVP with batch>1
    @parameterized.expand(_fvp_xfail_parameters)
    @pytest.mark.corstone_fvp
    @conftest.expectedFailureOnFVP
    def test_cat_u55_BI_xfails(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Ethos-U55 BI on the MLETORCH-630 cases (expected to fail on FVP)."""
        test_data = (operands, dim)
        self._test_cat_ethosu_BI_pipeline(
            self.Cat(), common.get_u55_compile_spec(), test_data
        )

    @parameterized.expand(_fvp_supported_parameters)
    @pytest.mark.corstone_fvp
    def test_cat_u85_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Ethos-U85 BI on the cases expected to pass on the FVP."""
        test_data = (operands, dim)
        self._test_cat_ethosu_BI_pipeline(
            self.Cat(), common.get_u85_compile_spec(), test_data
        )

    # MLETORCH-630 Cat does not work on FVP with batch>1
    @parameterized.expand(_fvp_xfail_parameters)
    @pytest.mark.corstone_fvp
    @conftest.expectedFailureOnFVP
    def test_cat_u85_BI_xfails(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Ethos-U85 BI on the MLETORCH-630 cases (expected to fail on FVP)."""
        test_data = (operands, dim)
        self._test_cat_ethosu_BI_pipeline(
            self.Cat(), common.get_u85_compile_spec(), test_data
        )