Skip to content

Commit bf4cf3f

Browse files
author
zhenyanzhang
committed
[ExecuTorch][#10375] Add extension.BundledModule to Wrap extension.Module with Bundled Program Logic
#10375 # Context This issue is a step of #9638. In #9638, we want to have `extension.Module` as the single source of implementation in `pybindings`, which means that `pybindings.PyModule` should use `extension.Module` rather than its own `pybindings.Module`. The issue is that `pybindings.PyModule` is dependent on the `method` getter from `pybindings.Module`, which `extension.Module` does not have. Since we don't want to expose the `method` getter in `extension.Module`, we have to protect the getter, wrap the functions that are dependent on it and use the protected getter there, ultimately decoupling `pybindings` from a `method` getter. # Proposal Now that we have a protected `method` getter, we can introduce an `extension.BundledModule`, a child class inheriting `extension.Module` which wraps up bundled program logic that is dependent on the `method` getter. Differential Revision: [D73564125](https://our.internmc.facebook.com/intern/diff/D73564125/) ghstack-source-id: 279969988 Pull Request resolved: #10449
1 parent 59dce3d commit bf4cf3f

File tree

4 files changed

+132
-16
lines changed

4 files changed

+132
-16
lines changed

devtools/bundled_program/schema/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def define_common_targets():
7474
visibility = [
7575
"//executorch/devtools/bundled_program/...",
7676
"//executorch/extension/pybindings/...",
77+
"//executorch/extension/module/...",
7778
],
7879
exported_headers = {
7980
OUTPUT_BUNDLED_HEADER: ":{}[{}]".format(BUNDLED_GEN_RULE_NAME, OUTPUT_BUNDLED_HEADER),

extension/module/module.cpp

+54-8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
#include <executorch/extension/module/module.h>
1010

11+
#include <executorch/devtools/bundled_program/bundled_program.h>
12+
#include <executorch/devtools/bundled_program/schema/bundled_program_schema_generated.h>
13+
#include <executorch/extension/data_loader/buffer_data_loader.h>
1114
#include <executorch/extension/data_loader/file_data_loader.h>
1215
#include <executorch/extension/data_loader/mmap_data_loader.h>
1316
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
@@ -302,14 +305,57 @@ runtime::Error Module::set_output(
302305
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
303306
}
304307

305-
ET_NODISCARD inline runtime::Result<Method*> Module::get_method(
306-
const std::string& method_name) {
307-
ET_CHECK_OR_RETURN_ERROR(
308-
methods_.count(method_name) > 0,
309-
InvalidArgument,
310-
"no such method in program: %s",
311-
method_name.c_str());
312-
return methods_[method_name].method.get();
308+
namespace {
309+
std::unique_ptr<BufferDataLoader> program_data_loader(
310+
const void* bundled_program_ptr) {
311+
auto bundled_program =
312+
bundled_program_flatbuffer::GetBundledProgram(bundled_program_ptr);
313+
// the program inside the bundled program
314+
auto program = bundled_program->program();
315+
return std::make_unique<BufferDataLoader>(program->data(), program->size());
316+
}
317+
} // namespace
318+
319+
BundledModule::BundledModule(
320+
const void* bundled_program_ptr,
321+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
322+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
323+
std::unique_ptr<runtime::EventTracer> event_tracer,
324+
std::unique_ptr<runtime::DataLoader> data_map_loader)
325+
: Module(
326+
program_data_loader(bundled_program_ptr),
327+
std::move(memory_allocator),
328+
std::move(temp_allocator),
329+
std::move(event_tracer),
330+
std::move(data_map_loader)),
331+
bundled_program_ptr_(bundled_program_ptr) {}
332+
333+
/**
 * Loads `method_name` if necessary, fills its inputs from the bundled test set
 * at `testset_idx`, and then calls get_inputs on the method with the input
 * list this module keeps for it.
 *
 * @returns Error::Ok on success, or the first error reported by load_method,
 * the bundled program runtime, or get_inputs.
 */
runtime::Error BundledModule::load_bundled_input(
    const std::string& method_name,
    const size_t testset_idx) {
  ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));

  // load_method succeeded, so the entry for this method exists.
  auto& holder = methods_.at(method_name);

  const auto load_status =
      executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
          *holder.method, bundled_program_ptr_, testset_idx);
  ET_CHECK_OK_OR_RETURN_ERROR(
      load_status,
      "Bundled Program's load_bundled_input failed with status 0x%" PRIx32,
      static_cast<uint32_t>(load_status));

  return holder.method->get_inputs(
      holder.inputs.data(), holder.inputs.size());
}
349+
350+
/**
 * Loads `method_name` if necessary and forwards to the bundled program
 * runtime's verify_method_outputs with the stored bundled program pointer.
 *
 * @returns The error from load_method if loading fails, otherwise whatever the
 * bundled program runtime's verify_method_outputs returns.
 */
runtime::Error BundledModule::verify_method_outputs(
    const std::string& method_name,
    const size_t testset_idx,
    double rtol,
    double atol) {
  ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));

  // load_method succeeded, so the entry for this method exists.
  auto& loaded_method = *methods_.at(method_name).method;
  return executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
      loaded_method, bundled_program_ptr_, testset_idx, rtol, atol);
}
314360

315361
} // namespace extension

extension/module/module.h

+74-8
Original file line numberDiff line numberDiff line change
@@ -493,19 +493,85 @@ class Module {
493493
std::unique_ptr<NamedDataMap> data_map_;
494494

495495
protected:
496+
std::unordered_map<std::string, MethodHolder> methods_;
497+
498+
friend class ExecuTorchJni;
499+
};
500+
501+
/**
502+
* A facade class for loading bundled programs and executing methods within
503+
* them.
504+
*/
505+
class BundledModule : public Module {
506+
public:
496507
/**
497-
* Get a method by method name.
508+
* Constructs an instance with the bundled program buffer pointer.
498509
*
499-
* @param[in] method_name The name of the method to get.
510+
* This constructor reads the program from bundled program buffer to load the
511+
* module with data loader. The bundled program pointer is preserved so that
512+
* the portion outside of program is accessible.
500513
*
501-
* @returns A Result object containing either a pointer to the requested
502-
* method or an error to indicate failure.
514+
* @param[in] bundled_program_ptr A DataLoader used for loading program data.
515+
* @param[in] memory_allocator A MemoryAllocator used for memory management.
516+
* @param[in] temp_allocator A MemoryAllocator to use when allocating
517+
* temporary data during kernel or delegate execution.
518+
* @param[in] event_tracer A EventTracer used for tracking and logging events.
519+
* @param[in] data_map_loader A DataLoader used for loading external weights.
503520
*/
504-
ET_NODISCARD inline runtime::Result<Method*> get_method(
505-
const std::string& method_name);
506-
std::unordered_map<std::string, MethodHolder> methods_;
521+
explicit BundledModule(
522+
const void* bundled_program_ptr,
523+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
524+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
525+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
526+
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr);
507527

508-
friend class ExecuTorchJni;
528+
BundledModule(const BundledModule&) = delete;
529+
BundledModule& operator=(const BundledModule&) = delete;
530+
BundledModule(Module&&) = delete;
531+
BundledModule& operator=(BundledModule&&) = delete;
532+
533+
/**
534+
* Execute a specific method with the input value at the given testset_idx
535+
* from the program bundle.
536+
*
537+
* Before execution, this function loads the program and method with
538+
* load_bundled_input in bundled_program.
539+
*
540+
* @param[in] method_name The name of the method to execute.
541+
* @param[in] testset_idx The index of the input value to be passed to
542+
* the method.
543+
*
544+
* @returns A Result object containing either a vector of output values
545+
* from the method or an error to indicate failure.
546+
*/
547+
ET_NODISCARD
548+
runtime::Error load_bundled_input(
549+
const std::string& method_name,
550+
const size_t testset_idx);
551+
552+
/**
553+
* Verify the output of a specific method with the expected output from the
554+
* program bundle at the given testset_idx.
555+
*
556+
* This function is a wrapper of verify_method_outputs in bundled_program.
557+
*
558+
* @param[in] method_name The name of the method to extract outputs from.
559+
* @param[in] testset_idx The index of expected output needs to be compared.
560+
* @param[in] rtol Relative tolerance used for data comparsion.
561+
* @param[in] atol Absolute tolerance used for data comparsion.
562+
*
563+
* @returns Return Error::Ok if two outputs match, or the error happens during
564+
* execution.
565+
*/
566+
ET_NODISCARD
567+
runtime::Error verify_method_outputs(
568+
const std::string& method_name,
569+
const size_t testset_idx,
570+
double rtol = 1e-5,
571+
double atol = 1e-8);
572+
573+
private:
574+
const void* bundled_program_ptr_;
509575
};
510576

511577
} // namespace extension

extension/module/targets.bzl

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ def define_common_targets():
2222
"@EXECUTORCH_CLIENTS",
2323
],
2424
deps = [
25+
"//executorch/extension/data_loader:buffer_data_loader",
26+
"//executorch/devtools/bundled_program:runtime",
27+
"//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
2528
"//executorch/extension/memory_allocator:malloc_memory_allocator",
2629
"//executorch/extension/data_loader:file_data_loader",
2730
"//executorch/extension/data_loader:mmap_data_loader",

0 commit comments

Comments
 (0)