@@ -132,16 +132,18 @@ class ExecuTorchLlamaJni
132
132
jint model_type_category,
133
133
facebook::jni::alias_ref<jstring> model_path,
134
134
facebook::jni::alias_ref<jstring> tokenizer_path,
135
- jfloat temperature) {
135
+ jfloat temperature,
136
+ facebook::jni::alias_ref<jstring> data_path) {
136
137
return makeCxxInstance (
137
- model_type_category, model_path, tokenizer_path, temperature);
138
+ model_type_category, model_path, tokenizer_path, temperature, data_path );
138
139
}
139
140
140
141
ExecuTorchLlamaJni (
141
142
jint model_type_category,
142
143
facebook::jni::alias_ref<jstring> model_path,
143
144
facebook::jni::alias_ref<jstring> tokenizer_path,
144
- jfloat temperature) {
145
+ jfloat temperature,
146
+ facebook::jni::alias_ref<jstring> data_path = nullptr ) {
145
147
#if defined(ET_USE_THREADPOOL)
146
148
// Reserve 1 thread for the main thread.
147
149
uint32_t num_performant_cores =
@@ -160,10 +162,18 @@ class ExecuTorchLlamaJni
160
162
tokenizer_path->toStdString ().c_str (),
161
163
temperature);
162
164
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
163
- runner_ = std::make_unique<example::Runner>(
165
+ if (data_path != nullptr ) {
166
+ runner_ = std::make_unique<example::Runner>(
164
167
model_path->toStdString ().c_str (),
165
168
tokenizer_path->toStdString ().c_str (),
166
- temperature);
169
+ temperature,
170
+ data_path->toStdString ().c_str ());
171
+ } else {
172
+ runner_ = std::make_unique<example::Runner>(
173
+ model_path->toStdString ().c_str (),
174
+ tokenizer_path->toStdString ().c_str (),
175
+ temperature);
176
+ }
167
177
#if defined(EXECUTORCH_BUILD_MEDIATEK)
168
178
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
169
179
runner_ = std::make_unique<MTKLlamaRunner>(
0 commit comments