Skip to content

Commit 01a0264

Browse files
lucylqfacebook-github-bot
authored andcommitted
Update jni runner
Summary: Add data map to JNI layer and LlamaModule ctor. Reviewed By: kirklandsign Differential Revision: D70597652
1 parent 5ac4a3c commit 01a0264

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

extension/android/jni/jni_layer_llama.cpp

+15-5
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,18 @@ class ExecuTorchLlamaJni
132132
jint model_type_category,
133133
facebook::jni::alias_ref<jstring> model_path,
134134
facebook::jni::alias_ref<jstring> tokenizer_path,
135-
jfloat temperature) {
135+
jfloat temperature,
136+
facebook::jni::alias_ref<jstring> data_path) {
136137
return makeCxxInstance(
137-
model_type_category, model_path, tokenizer_path, temperature);
138+
model_type_category, model_path, tokenizer_path, temperature, data_path);
138139
}
139140

140141
ExecuTorchLlamaJni(
141142
jint model_type_category,
142143
facebook::jni::alias_ref<jstring> model_path,
143144
facebook::jni::alias_ref<jstring> tokenizer_path,
144-
jfloat temperature) {
145+
jfloat temperature,
146+
facebook::jni::alias_ref<jstring> data_path = nullptr) {
145147
#if defined(ET_USE_THREADPOOL)
146148
// Reserve 1 thread for the main thread.
147149
uint32_t num_performant_cores =
@@ -160,10 +162,18 @@ class ExecuTorchLlamaJni
160162
tokenizer_path->toStdString().c_str(),
161163
temperature);
162164
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
163-
runner_ = std::make_unique<example::Runner>(
165+
if (data_path != nullptr) {
166+
runner_ = std::make_unique<example::Runner>(
164167
model_path->toStdString().c_str(),
165168
tokenizer_path->toStdString().c_str(),
166-
temperature);
169+
temperature,
170+
data_path->toStdString().c_str());
171+
} else {
172+
runner_ = std::make_unique<example::Runner>(
173+
model_path->toStdString().c_str(),
174+
tokenizer_path->toStdString().c_str(),
175+
temperature);
176+
}
167177
#if defined(EXECUTORCH_BUILD_MEDIATEK)
168178
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
169179
runner_ = std::make_unique<MTKLlamaRunner>(

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,16 @@ public class LlamaModule {
3939

4040
@DoNotStrip
4141
private static native HybridData initHybrid(
42-
int modelType, String modulePath, String tokenizerPath, float temperature);
42+
int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath);
4343

44-
/** Constructs a LLAMA Module for a model with given path, tokenizer, and temperature. */
45-
public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
46-
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature);
44+
/** Constructs a LLAMA Module for a model with given model path, tokenizer, temperature and data path. */
45+
public LlamaModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
46+
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
4747
}
4848

4949
/** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
5050
public LlamaModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
51-
mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature);
51+
mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null);
5252
}
5353

5454
public void resetNative() {

0 commit comments

Comments
 (0)