-
Notifications
You must be signed in to change notification settings - Fork 528
/
Copy pathtosa_specification.py
213 lines (167 loc) · 7.08 KB
/
tosa_specification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
#
# Main implementation of AoT flow to partition and preprocess for Arm target
# backends. Converts via TOSA as an intermediate form supported by AoT and
# JIT compiler flows.
#
import re
from typing import List
from packaging.version import Version
class TosaSpecification:
    """
    Representation of a TOSA specification
    (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile
    (with extensions) and a level (8k).

    For 0.80 releases the profile is BI or MI, with u55 handled as an
    unofficial extension.
    For 1.00 releases the profile is INT or FP, and the extensions are for
        INT: int16, int4, var, cf
        FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf

    The TOSA specification is encoded in the string representation
        TOSA-major.minor.patch+profile[+level][+extensions]

    For 0.80 MI implies BI, while for 1.0 the profiles have to be specified
    explicitly.
    Profiles are uppercase letters; extensions and level are lowercase.
    """

    version: Version

    def support_integer(self) -> bool:
        """
        Returns true if any integer operations are supported for the specification.

        Must be overridden by the version-specific subclasses.
        """
        raise NotImplementedError

    def support_float(self) -> bool:
        """
        Returns true if any float operations are supported for the specification.

        Must be overridden by the version-specific subclasses.
        """
        raise NotImplementedError

    def __init__(self, version: Version):
        """Store the parsed specification version."""
        self.version = version

    @staticmethod
    def create_from_string(repr: str) -> "TosaSpecification":
        """
        Creates a TOSA specification object from a string representation.

        Examples of accepted strings:
            TOSA-0.80+MI
            TOSA-0.80+BI+8k
            TOSA-0.80+BI+u55   # Ethos-U55 extension to handle TOSA subset
            TOSA-1.00.0+INT+FP+int4+cf

        Only 0.80.x and 1.0.x versions are handled; every other version
        (for example TOSA-0.90.0+MI) is rejected with a ValueError.

        Args:
            repr: String of the form "TOSA-<version>+<extra>[+<extra>...]".
                The parameter name shadows the builtin but is kept so that
                keyword callers keep working.

        Returns:
            A Tosa_0_80 instance for 0.80.x versions, a Tosa_1_00 instance
            for 1.0.x versions.

        Raises:
            ValueError: If the string does not have the expected form, if the
                version is not handled, or if the selected subclass rejects
                the extras. Version() itself may also raise on a malformed
                version number such as "1..0" (packaging's InvalidVersion).
        """
        # The literal "TOSA" prefix is enforced by the pattern itself, so no
        # separate check of the captured name is required.
        parsed = re.match(r"^(TOSA)-([\d.]+)\+(.+)$", repr)
        if parsed is None:
            raise ValueError(
                f"Failed to parse TOSA specification representation: {repr}"
            )

        version = Version(parsed.group(2))
        # Everything after the version: profile(s), level and extensions.
        extras = parsed.group(3).split("+")

        if version.major == 0 and version.minor == 80:
            return Tosa_0_80(version, extras)
        if version.major == 1 and version.minor == 0:
            return Tosa_1_00(version, extras)
        raise ValueError(f"Wrong TOSA version: {version} from {repr}")
class Tosa_0_80(TosaSpecification):
    """
    TOSA 0.80.x specification: exactly one profile (BI or MI), an optional
    "8k" level and an optional "u55" subset marker.
    """

    profile: str
    level_8k: bool
    is_U55_subset: bool
    available_profiles = ["BI", "MI"]  # MT is not defined

    def __init__(self, version: Version, extras: List[str]):
        """
        Args:
            version: Specification version; must satisfy 0.80 <= version < 0.90.
            extras: The "+"-separated parts following the version: exactly one
                profile, optionally "8k" and/or "u55". The list is not modified.

        Raises:
            ValueError: If no profile or more than one profile is given, or if
                any extra other than the profile, "8k" and "u55" remains.
        """
        super().__init__(version)
        # Invariant: create_from_string only dispatches 0.80.x versions here.
        assert version >= Version("0.80") and version < Version("0.90")

        # Work on a copy so the caller's list is left untouched.
        remaining = list(extras)

        # Exactly one profile must be present in the extras.
        found = [e for e in remaining if e in Tosa_0_80.available_profiles]
        if not found:
            raise ValueError(
                f"Bad combination of extras: {extras}, none of {Tosa_0_80.available_profiles} found."
            )
        if len(found) > 1:
            raise ValueError(
                f"Bad combination of extras: {extras}, more than one of {Tosa_0_80.available_profiles} found."
            )
        self.profile = found[0]
        remaining.remove(self.profile)

        self.level_8k = "8k" in remaining
        if self.level_8k:
            remaining.remove("8k")

        self.is_U55_subset = "u55" in remaining
        if self.is_U55_subset:
            remaining.remove("u55")

        # Anything still left was not recognised.
        if remaining:
            raise ValueError(f"Unhandled extras found: {remaining}")

    def __repr__(self) -> str:
        """Return the string form TOSA-<version>+<profile>[+8k][+u55]."""
        extensions = ""
        if self.level_8k:
            extensions += "+8k"
        if self.is_U55_subset:
            extensions += "+u55"
        return f"TOSA-{str(self.version)}+{self.profile}{extensions}"

    def __hash__(self) -> int:
        # Hash the Version object rather than its string form: Version("0.80")
        # equals Version("0.80.0") although their strings differ, and objects
        # that compare equal must produce the same hash.
        return hash((self.version, self.profile))

    def __eq__(self, other: object) -> bool:
        """
        Two 0.80 specifications are equal when version and profile match.

        NOTE: level_8k and is_U55_subset do not take part in equality or
        hashing.
        """
        if isinstance(other, Tosa_0_80):
            return (self.version == other.version) and (self.profile == other.profile)
        return False

    def support_integer(self) -> bool:
        """Integer operations are available in both BI and MI."""
        return True

    def support_float(self) -> bool:
        """Float operations are only available in the MI profile."""
        return self.profile == "MI"
class Tosa_1_00(TosaSpecification):
    """
    TOSA 1.0.x specification: one or more profiles (INT, FP), an optional
    "8k" level and extensions that must be valid for at least one of the
    selected profiles (see valid_extensions).
    """

    profiles: List[str]
    level_8k: bool
    extensions: List[str]
    available_profiles = ["INT", "FP"]
    valid_extensions = {
        "INT": ["int16", "int4", "var", "cf", "u55"],
        "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
    }

    def __init__(self, version: Version, extras: List[str]):
        """
        Args:
            version: Specification version.
            extras: The "+"-separated parts following the version: at least
                one profile, optionally "8k", and any extensions. The list is
                not modified.

        Raises:
            ValueError: If no profile is given, if more profile entries than
                available profiles are given, or if an extension is not valid
                for any of the selected profiles.
        """
        super().__init__(version)

        # Work on a copy so the caller's list is neither mutated nor aliased
        # by self.extensions.
        remaining = list(extras)

        selected = [e for e in remaining if e in Tosa_1_00.available_profiles]
        # At least one profile is required ...
        if not selected:
            raise ValueError(
                f"No profile ({Tosa_1_00.available_profiles}) found in: {extras}."
            )
        # ... and not more entries than there are available profiles.
        if len(selected) > len(Tosa_1_00.available_profiles):
            raise ValueError(
                f"Too many profiles ({Tosa_1_00.available_profiles}) found in: {extras}."
            )

        self.profiles = selected
        for p in selected:
            remaining.remove(p)

        self.level_8k = "8k" in remaining
        if self.level_8k:
            remaining.remove("8k")

        # An extension is accepted if any of the selected profiles allows it.
        allowed = [
            ext for p in self.profiles for ext in Tosa_1_00.valid_extensions[p]
        ]
        if not all(e in allowed for e in remaining):
            raise ValueError(
                f"Bad extensions for TOSA-{version}{self._get_profiles_string()}: {remaining}"
            )

        # All the rest of the extras are handled extensions.
        self.extensions = remaining

    def _get_profiles_string(self) -> str:
        """Return the profiles as "+P1+P2...", in the order they were given."""
        return "".join(["+" + p for p in self.profiles])

    def _get_extensions_string(self) -> str:
        """Return the extensions as "+e1+e2...", in the order they were given."""
        return "".join(["+" + e for e in self.extensions])

    def __repr__(self) -> str:
        """Return the string form TOSA-<version><profiles><extensions>[+8k]."""
        extensions = self._get_extensions_string()
        if self.level_8k:
            extensions += "+8k"
        return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"

    def __hash__(self) -> int:
        # Hash the Version object rather than its string form: Version("1.0")
        # equals Version("1.0.0") although their strings differ, and objects
        # that compare equal must produce the same hash. The profiles are
        # hashed as a set so that "+INT+FP" and "+FP+INT" agree.
        return hash((self.version, frozenset(self.profiles)))

    def __eq__(self, other: object) -> bool:
        """
        Two 1.0 specifications are equal when the version and the set of
        profiles match, regardless of the order the profiles were written in.

        NOTE: level_8k and extensions do not take part in equality or hashing.
        """
        if isinstance(other, Tosa_1_00):
            return (self.version == other.version) and (
                frozenset(self.profiles) == frozenset(other.profiles)
            )
        return False

    def support_integer(self) -> bool:
        """Integer operations are supported when the INT profile is selected."""
        return "INT" in self.profiles

    def support_float(self) -> bool:
        """Float operations are supported when the FP profile is selected."""
        return "FP" in self.profiles