The Arm Ethos-U backend is the ExecuTorch solution for executing quantized models on Ethos-U55, Ethos-U65, and Ethos-U85 NPUs. It leverages the TOSA operator set, which can be compiled by the ethos-u-vela graph compiler. Key features:
- Wide operator support for delegating large parts of models to highly optimized, low-power Ethos-U NPUs.
- A quantizer that optimizes quantization for the NPU target.
The target system must include an Ethos-U NPU.
To compile for the NPUs, the Ethos-U Vela compiler is needed. A target-specific toolchain is also needed to build the runtime. Finally, to test models, Arm provides freely available Fixed Virtual Platforms (FVPs). These emulate reference designs, so code can run on the Ethos-U without a physical development board. For Ethos-U55 the reference design is Corstone-300, and for Ethos-U85 it is Corstone-320.
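As a small illustration of the NPU-to-FVP pairing described above, a test script could keep the mapping in a lookup table to decide which FVP to launch. The helper `fvp_for_npu` is hypothetical. The table covers only the two pairings stated in this section, so no FVP is listed for Ethos-U65.

```python
# Hypothetical helper: pick the Corstone reference design (FVP) for an NPU.
# Only the two pairings stated in this section are included.
CORSTONE_FVP = {
    "ethos-u55": "Corstone-300",
    "ethos-u85": "Corstone-320",
}


def fvp_for_npu(npu: str) -> str:
    """Return the Corstone FVP that emulates a design with the given NPU."""
    key = npu.strip().lower()
    if key not in CORSTONE_FVP:
        raise ValueError(f"No FVP listed for {npu!r}")
    return CORSTONE_FVP[key]


print(fvp_for_npu("Ethos-U55"))  # Corstone-300
```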
These dependencies can easily be downloaded using the script examples/arm/setup.sh.
To work with quantized models, build the quantized_ops_aot_lib library, which contains kernels for quantization and dequantization. Use the script backends/arm/scripts/build_quantized_ops_aot_lib.sh to do this.
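The library is loaded from Python with `torch.ops.load_library`, and its file suffix depends on the host platform: `.so` on Linux and `.dylib` on macOS. The sketch below picks the right path. It assumes the output directory `cmake-out-aot-lib/kernels/quantized` used in the lowering example, which may differ in your build, and the helper name `quantized_ops_lib_path` is hypothetical.

```python
import platform
from pathlib import Path
from typing import Optional

# Assumed output directory, mirroring the path used in the lowering example.
AOT_LIB_DIR = Path("cmake-out-aot-lib/kernels/quantized")


def quantized_ops_lib_path(system: Optional[str] = None) -> Path:
    """Return the expected path of the quantized ops AOT library for a host OS."""
    system = system or platform.system()
    suffix = ".dylib" if system == "Darwin" else ".so"
    return AOT_LIB_DIR / f"libquantized_ops_aot_lib{suffix}"


lib_path = quantized_ops_lib_path()
print(lib_path)
# With PyTorch installed, the library would then be loaded with:
# torch.ops.load_library(str(lib_path))
```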
The example below demonstrates the lowering process of a MobileNet V2 model from torchvision for an Ethos-U55 target. Since the model is a floating-point model, first quantize it using the EthosUQuantizer. Then pass an instance of the EthosUPartitioner to to_edge_transform_and_lower. Both the quantizer and the partitioner need a compilation specification created using ArmCompileSpecBuilder.
import torch
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.ethosu_partitioner import EthosUPartitioner
from executorch.backends.arm.quantizer.arm_quantizer import (
    EthosUQuantizer,
    get_symmetric_quantization_config,
)
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchvision.models import mobilenetv2

mobilenet_v2 = mobilenetv2.mobilenet_v2(
    weights=mobilenetv2.MobileNet_V2_Weights.DEFAULT
).eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# Load the quantized ops AOT library; the .so suffix is .dylib on macOS.
torch.ops.load_library(
    "cmake-out-aot-lib/kernels/quantized/libquantized_ops_aot_lib.so"
)

# Compile spec for the ethos-u55-128 target, shared by quantizer and partitioner.
compile_spec = ArmCompileSpecBuilder().ethosu_compile_spec(
    "ethos-u55-128",
    system_config="Ethos_U55_High_End_Embedded",
    memory_mode="Shared_Sram",
    extra_flags="--output-format=raw --debug-force-regor",
).build()

# Post-training quantization
graph_module = torch.export.export_for_training(mobilenet_v2, example_inputs).module()
quantizer = EthosUQuantizer(compile_spec)
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
graph_module = prepare_pt2e(graph_module, quantizer)
graph_module(*example_inputs)  # Calibrate the inserted observers.
graph_module = convert_pt2e(graph_module)
exported_program = torch.export.export_for_training(graph_module, example_inputs)

# Lower the exported program to the Ethos-U backend and save the .pte file.
edge_program_manager = to_edge_transform_and_lower(
    exported_program,
    partitioner=[EthosUPartitioner(compile_spec)],
    compile_config=EdgeCompileConfig(
        _check_ir_validity=False,
    ),
).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))

with open("mv2_arm_ethos_u55.pte", "wb") as file:
    edge_program_manager.write_to_file(file)
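The first argument to ethosu_compile_spec, "ethos-u55-128", is the same target string that is later passed as --target to the build and run scripts. It follows the Vela accelerator-configuration naming. To our understanding, the trailing number there is the NPU's MAC configuration. The sketch below splits such a string into its parts. `parse_target` and `EthosUTarget` are hypothetical helpers, not part of ExecuTorch.

```python
import re
from typing import NamedTuple

# Matches target strings such as "ethos-u55-128" or "ethos-u85-256".
_TARGET_RE = re.compile(r"^(ethos-u(?:55|65|85))-(\d+)$")


class EthosUTarget(NamedTuple):
    npu: str   # e.g. "ethos-u55"
    macs: int  # trailing number, e.g. 128


def parse_target(target: str) -> EthosUTarget:
    """Split an Ethos-U target string into NPU name and MAC configuration."""
    match = _TARGET_RE.match(target)
    if match is None:
        raise ValueError(f"Unrecognized Ethos-U target: {target!r}")
    return EthosUTarget(npu=match.group(1), macs=int(match.group(2)))


print(parse_target("ethos-u55-128"))  # EthosUTarget(npu='ethos-u55', macs=128)
```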
EthosUPartitioner tries to partition as much of the model as possible. It never delegates unsupported operators, but a user can pass additional checks to its constructor to keep further operators out of the partition. To do this, subclass OperatorSupportBase and implement the function is_node_supported. A few such checks exist in executorch.exir.backend.operator_support:
- DontPartition: don't partition operators based on operator type.
- DontPartitionModule: don't partition operators based on which Python module the operator comes from.
- DontPartitionName: don't partition operators based on the operator name.
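To make the combination of checks concrete, the sketch below models it in plain Python. An operator is delegated only if the backend supports it and every additional check accepts it. The `Node` class, both check classes, and `should_delegate` are hypothetical stand-ins written for this illustration. Real checks subclass `OperatorSupportBase` from `torch.fx.passes.operator_support` and receive `torch.fx.Node` objects in `is_node_supported(submodules, node)`.

```python
from dataclasses import dataclass
from typing import Mapping, Sequence, Set


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for a graph node: operator target plus node name."""
    target: str
    name: str


class DontPartitionTargets:
    """Reject nodes by operator type (similar in spirit to DontPartition)."""

    def __init__(self, *targets: str) -> None:
        self.targets = set(targets)

    def is_node_supported(self, submodules: Mapping[str, object], node: Node) -> bool:
        return node.target not in self.targets


class DontPartitionNames:
    """Reject nodes by name (similar in spirit to DontPartitionName)."""

    def __init__(self, *names: str) -> None:
        self.names = set(names)

    def is_node_supported(self, submodules: Mapping[str, object], node: Node) -> bool:
        return node.name not in self.names


def should_delegate(node: Node, backend_supported: Set[str], checks: Sequence) -> bool:
    """Delegate only if the backend supports the op and every extra check agrees."""
    return node.target in backend_supported and all(
        check.is_node_supported({}, node) for check in checks
    )


supported = {"aten.conv2d", "aten.relu", "aten.add"}
checks = [DontPartitionTargets("aten.add"), DontPartitionNames("relu_3")]
nodes = [
    Node("aten.conv2d", "conv2d"),
    Node("aten.add", "add"),          # blocked by operator type
    Node("aten.relu", "relu_3"),      # blocked by name
    Node("aten.relu", "relu_4"),
    Node("aten.softmax", "softmax"),  # not supported by the backend
]
delegated = [n.name for n in nodes if should_delegate(n, supported, checks)]
print(delegated)  # ['conv2d', 'relu_4']
```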
A fully integer model is required for using the Arm Ethos-U backend. As discussed above, you can quantize floating-point models with the EthosUQuantizer. Quantizers are backend-specific, which means the EthosUQuantizer is configured to quantize models correctly for the target.
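An integer-only model represents each floating-point tensor as integers plus a scale. The sketch below shows the general idea of symmetric per-tensor int8 quantization, the kind of scheme requested with get_symmetric_quantization_config(is_per_channel=False) in the example above. It is a simplified illustration, not the EthosUQuantizer implementation, which derives scales from observers during calibration. It assumes a zero point of 0, an int8 range of [-128, 127], and a scale taken from the largest absolute value.

```python
from typing import List, Sequence

# Simplified symmetric per-tensor int8 quantization (zero point = 0).
QMIN, QMAX = -128, 127


def symmetric_scale(values: Sequence[float]) -> float:
    """Scale chosen so that the largest magnitude maps to QMAX."""
    max_abs = max(abs(v) for v in values)
    return max_abs / QMAX if max_abs > 0 else 1.0


def quantize(values: Sequence[float], scale: float) -> List[int]:
    """Map floats to int8: q = clamp(round(x / scale), QMIN, QMAX)."""
    return [max(QMIN, min(QMAX, round(v / scale))) for v in values]


def dequantize(qvalues: Sequence[int], scale: float) -> List[float]:
    """Recover approximate floats: x is roughly q * scale."""
    return [q * scale for q in qvalues]


weights = [0.6, -1.0, 0.2]
scale = symmetric_scale(weights)
q = quantize(weights, scale)
print(q)  # [76, -127, 25]
print(dequantize(q, scale))
```

Each reconstructed value differs from the original by at most half a quantization step (scale / 2). This rounding error is what the quantizer's choice of scales tries to keep small.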
To run the model on-device, build the ExecuTorch library and the EthosUDelegate using the script executorch/backends/arm/scripts/build_executorch.sh.
Then build the Arm executor runner, passing the .pte file and the target:
executorch/backends/arm/scripts/build_executorch_runner.sh --pte=mv2_arm_ethos_u55.pte --target=ethos-u55-128
Finally, run the resulting ELF file on the FVP:
executorch/backends/arm/scripts/run_fvp.sh --elf=executorch/mv2_arm_ethos_u55/cmake-out/arm_executor_runner --target=ethos-u55-128
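The three steps above can be chained from one Python script, for example in CI. This is a minimal sketch. It assumes the script is run from the directory that contains the executorch checkout and uses only the script paths and flags shown above. The helper names and the `dry_run` switch are hypothetical.

```python
import shlex
import subprocess
from typing import List

SCRIPTS = "executorch/backends/arm/scripts"


def ethos_u_commands(pte: str, target: str, elf: str) -> List[List[str]]:
    """Build the commands for the three steps: build libs, build runner, run on FVP."""
    return [
        [f"{SCRIPTS}/build_executorch.sh"],
        [f"{SCRIPTS}/build_executorch_runner.sh", f"--pte={pte}", f"--target={target}"],
        [f"{SCRIPTS}/run_fvp.sh", f"--elf={elf}", f"--target={target}"],
    ]


def run_all(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; unless dry_run, execute it and stop on the first failure."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


cmds = ethos_u_commands(
    pte="mv2_arm_ethos_u55.pte",
    target="ethos-u55-128",
    elf="executorch/mv2_arm_ethos_u55/cmake-out/arm_executor_runner",
)
run_all(cmds)  # pass dry_run=False to actually build and run
```

By default the script only prints the commands. Passing dry_run=False runs them in order, and `check=True` makes a failed build stop the pipeline before the FVP is launched.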