Skip to content

Commit 0b6a3cd

Browse files
ymao1elasticsearchmachinedavidkyle
authored
Expose input_type option at root level for text_embedding task type in Perform Inference API (#122638)
* wip * wip * [CI] Auto commit changes from spotless * Adding internal input types * [CI] Auto commit changes from spotless * Throwing validation exception for services that don't support input type * linting * hugging face * voyage ai * google ai studio * bedrock updates * Fixing tests * Fixing tests * Fixing tests * bedrock updates * elasticsearch * azure openai * [CI] Auto commit changes from spotless * Refactoring all the things * [CI] Auto commit changes from spotless * Everything compiles * spotless * external actions tests * external request tests * service tests * Fixing integration tests * Cleanup * Update docs/changelog/122638.yaml * Cleanup * Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java Co-authored-by: David Kyle <[email protected]> * Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java Co-authored-by: David Kyle <[email protected]> * PR feedback --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: David Kyle <[email protected]>
1 parent 053938a commit 0b6a3cd

File tree

183 files changed

+3286
-1505
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

183 files changed

+3286
-1505
lines changed

docs/changelog/122638.yaml

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
pr: 122638
2+
summary: Expose `input_type` option at root level for `text_embedding` task type in
3+
Perform Inference API
4+
area: Machine Learning
5+
type: enhancement
6+
issues:
7+
- 117856

server/src/main/java/org/elasticsearch/inference/InputType.java

+23-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
import java.util.Locale;
1313

14+
import static org.elasticsearch.core.Strings.format;
15+
1416
/**
1517
* Defines the type of request, whether the request is to ingest a document or search for a document.
1618
*/
@@ -19,7 +21,11 @@ public enum InputType {
1921
SEARCH,
2022
UNSPECIFIED,
2123
CLASSIFICATION,
22-
CLUSTERING;
24+
CLUSTERING,
25+
26+
// Use the following enums when calling the inference API internally
27+
INTERNAL_SEARCH,
28+
INTERNAL_INGEST;
2329

2430
@Override
2531
public String toString() {
@@ -29,4 +35,20 @@ public String toString() {
2935
public static InputType fromString(String name) {
3036
return valueOf(name.trim().toUpperCase(Locale.ROOT));
3137
}
38+
39+
public static InputType fromRestString(String name) {
40+
var inputType = InputType.fromString(name);
41+
if (inputType == InputType.INTERNAL_INGEST || inputType == InputType.INTERNAL_SEARCH) {
42+
throw new IllegalArgumentException(format("Unrecognized input_type [%s]", inputType));
43+
}
44+
return inputType;
45+
}
46+
47+
public static boolean isInternalTypeOrUnspecified(InputType inputType) {
48+
return inputType == InputType.INTERNAL_INGEST || inputType == InputType.INTERNAL_SEARCH || inputType == InputType.UNSPECIFIED;
49+
}
50+
51+
public static boolean isSpecified(InputType inputType) {
52+
return inputType != null && inputType != InputType.UNSPECIFIED;
53+
}
3254
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

+15-2
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,15 @@ public static class Request extends BaseInferenceActionRequest {
5757

5858
public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30);
5959
public static final ParseField INPUT = new ParseField("input");
60+
public static final ParseField INPUT_TYPE = new ParseField("input_type");
6061
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
6162
public static final ParseField QUERY = new ParseField("query");
6263
public static final ParseField TIMEOUT = new ParseField("timeout");
6364

6465
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
6566
static {
6667
PARSER.declareStringArray(Request.Builder::setInput, INPUT);
68+
PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
6769
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
6870
PARSER.declareString(Request.Builder::setQuery, QUERY);
6971
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
@@ -80,8 +82,6 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
8082
Request.Builder builder = PARSER.apply(parser, null);
8183
builder.setInferenceEntityId(inferenceEntityId);
8284
builder.setTaskType(taskType);
83-
// For rest requests we won't know what the input type is
84-
builder.setInputType(InputType.UNSPECIFIED);
8585
builder.setContext(context);
8686
return builder;
8787
}
@@ -227,6 +227,14 @@ public ActionRequestValidationException validate() {
227227
}
228228
}
229229

230+
if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
231+
&& taskType.equals(TaskType.ANY) == false
232+
&& (inputType != null && InputType.isInternalTypeOrUnspecified(inputType) == false)) {
233+
var e = new ActionRequestValidationException();
234+
e.addValidationError(format("Field [input_type] cannot be specified for task type [%s]", taskType));
235+
return e;
236+
}
237+
230238
return null;
231239
}
232240

@@ -335,6 +343,11 @@ public Builder setInputType(InputType inputType) {
335343
return this;
336344
}
337345

346+
public Builder setInputType(String inputType) {
347+
this.inputType = InputType.fromRestString(inputType);
348+
return this;
349+
}
350+
338351
public Builder setTaskSettings(Map<String, Object> taskSettings) {
339352
this.taskSettings = taskSettings;
340353
return this;

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java

+70
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,76 @@ public void testValidation_Rerank_Empty() {
174174
assertThat(queryEmptyError.getMessage(), is("Validation Failed: 1: Field [query] cannot be empty for task type [rerank];"));
175175
}
176176

177+
public void testValidation_Rerank_WithInputType() {
178+
InferenceAction.Request request = new InferenceAction.Request(
179+
TaskType.RERANK,
180+
"model",
181+
"query",
182+
List.of("input"),
183+
null,
184+
InputType.SEARCH,
185+
null,
186+
false
187+
);
188+
ActionRequestValidationException queryError = request.validate();
189+
assertNotNull(queryError);
190+
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [rerank];"));
191+
}
192+
193+
public void testValidation_SparseEmbedding_WithInputType() {
194+
InferenceAction.Request queryRequest = new InferenceAction.Request(
195+
TaskType.SPARSE_EMBEDDING,
196+
"model",
197+
"",
198+
List.of("input"),
199+
null,
200+
InputType.SEARCH,
201+
null,
202+
false
203+
);
204+
ActionRequestValidationException queryError = queryRequest.validate();
205+
assertNotNull(queryError);
206+
assertThat(
207+
queryError.getMessage(),
208+
is("Validation Failed: 1: Field [input_type] cannot be specified for task type [sparse_embedding];")
209+
);
210+
}
211+
212+
public void testValidation_Completion_WithInputType() {
213+
InferenceAction.Request queryRequest = new InferenceAction.Request(
214+
TaskType.COMPLETION,
215+
"model",
216+
"",
217+
List.of("input"),
218+
null,
219+
InputType.SEARCH,
220+
null,
221+
false
222+
);
223+
ActionRequestValidationException queryError = queryRequest.validate();
224+
assertNotNull(queryError);
225+
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [completion];"));
226+
}
227+
228+
public void testValidation_ChatCompletion_WithInputType() {
229+
InferenceAction.Request queryRequest = new InferenceAction.Request(
230+
TaskType.CHAT_COMPLETION,
231+
"model",
232+
"",
233+
List.of("input"),
234+
null,
235+
InputType.SEARCH,
236+
null,
237+
false
238+
);
239+
ActionRequestValidationException queryError = queryRequest.validate();
240+
assertNotNull(queryError);
241+
assertThat(
242+
queryError.getMessage(),
243+
is("Validation Failed: 1: Field [input_type] cannot be specified for task type [chat_completion];")
244+
);
245+
}
246+
177247
public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
178248
String singleInputRequest = """
179249
{

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionCreator.java

+4-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
99

10-
import org.elasticsearch.inference.InputType;
1110
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1211
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1312
import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -32,15 +31,15 @@ public AlibabaCloudSearchActionCreator(Sender sender, ServiceComponents serviceC
3231
}
3332

3433
@Override
35-
public ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
36-
var overriddenModel = AlibabaCloudSearchEmbeddingsModel.of(model, taskSettings, inputType);
34+
public ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings) {
35+
var overriddenModel = AlibabaCloudSearchEmbeddingsModel.of(model, taskSettings);
3736

3837
return new AlibabaCloudSearchEmbeddingsAction(sender, overriddenModel, serviceComponents);
3938
}
4039

4140
@Override
42-
public ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings, InputType inputType) {
43-
var overriddenModel = AlibabaCloudSearchSparseModel.of(model, taskSettings, inputType);
41+
public ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings) {
42+
var overriddenModel = AlibabaCloudSearchSparseModel.of(model, taskSettings);
4443

4544
return new AlibabaCloudSearchSparseAction(sender, overriddenModel, serviceComponents);
4645
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchActionVisitor.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
99

10-
import org.elasticsearch.inference.InputType;
1110
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1211
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
1312
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
@@ -17,9 +16,9 @@
1716
import java.util.Map;
1817

1918
public interface AlibabaCloudSearchActionVisitor {
20-
ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
19+
ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings);
2120

22-
ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings, InputType inputType);
21+
ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings);
2322

2423
ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String, Object> taskSettings);
2524

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionAction.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
2020
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
2121
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
22-
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
22+
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
2323
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
2424
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2525
import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -51,7 +51,7 @@ public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompl
5151

5252
@Override
5353
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
54-
if (inferenceInputs instanceof DocumentsOnlyInput == false) {
54+
if (inferenceInputs instanceof EmbeddingsInput == false) {
5555
listener.onFailure(
5656
new ElasticsearchStatusException(
5757
format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION),
@@ -61,7 +61,7 @@ public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionLi
6161
return;
6262
}
6363

64-
var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
64+
var docsOnlyInput = (EmbeddingsInput) inferenceInputs;
6565
if (docsOnlyInput.getInputs().size() % 2 == 0) {
6666
listener.onFailure(
6767
new ElasticsearchStatusException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.external.action.cohere;
99

10-
import org.elasticsearch.inference.InputType;
1110
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1211
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1312
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
@@ -40,8 +39,8 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
4039
}
4140

4241
@Override
43-
public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
44-
var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings, inputType);
42+
public ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings) {
43+
var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings);
4544
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere embeddings");
4645
// TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager
4746
var requestCreator = CohereEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionVisitor.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.external.action.cohere;
99

10-
import org.elasticsearch.inference.InputType;
1110
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1211
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;
1312
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
@@ -16,7 +15,7 @@
1615
import java.util.Map;
1716

1817
public interface CohereActionVisitor {
19-
ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
18+
ExecutableAction create(CohereEmbeddingsModel model, Map<String, Object> taskSettings);
2019

2120
ExecutableAction create(CohereRerankModel model, Map<String, Object> taskSettings);
2221

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreator.java

+2-11
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.external.action.elastic;
99

10-
import org.elasticsearch.inference.InputType;
1110
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1211
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1312
import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceSparseEmbeddingsRequestManager;
@@ -30,23 +29,15 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
3029

3130
private final TraceContext traceContext;
3231

33-
private final InputType inputType;
34-
35-
public ElasticInferenceServiceActionCreator(
36-
Sender sender,
37-
ServiceComponents serviceComponents,
38-
TraceContext traceContext,
39-
InputType inputType
40-
) {
32+
public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) {
4133
this.sender = Objects.requireNonNull(sender);
4234
this.serviceComponents = Objects.requireNonNull(serviceComponents);
4335
this.traceContext = traceContext;
44-
this.inputType = inputType;
4536
}
4637

4738
@Override
4839
public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
49-
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext, inputType);
40+
var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext);
5041
var errorMessage = constructFailedToSendRequestMessage(
5142
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
5243
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionCreator.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.external.action.googlevertexai;
99

10-
import org.elasticsearch.inference.InputType;
1110
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1211
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1312
import org.elasticsearch.xpack.inference.external.http.sender.GoogleVertexAiEmbeddingsRequestManager;
@@ -34,8 +33,8 @@ public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceCompo
3433
}
3534

3635
@Override
37-
public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
38-
var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, taskSettings, inputType);
36+
public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings) {
37+
var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, taskSettings);
3938
var requestManager = new GoogleVertexAiEmbeddingsRequestManager(
4039
overriddenModel,
4140
serviceComponents.truncator(),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionVisitor.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.external.action.googlevertexai;
99

10-
import org.elasticsearch.inference.InputType;
1110
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1211
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
1312
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
@@ -16,7 +15,7 @@
1615

1716
public interface GoogleVertexAiActionVisitor {
1817

19-
ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
18+
ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings);
2019

2120
ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);
2221
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionVisitor.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
1313

1414
public interface HuggingFaceActionVisitor {
15-
ExecutableAction create(HuggingFaceEmbeddingsModel mode);
15+
ExecutableAction create(HuggingFaceEmbeddingsModel model);
1616

17-
ExecutableAction create(HuggingFaceElserModel mode);
17+
ExecutableAction create(HuggingFaceElserModel model);
1818
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/jinaai/JinaAIActionCreator.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.external.action.jinaai;
99

10-
import org.elasticsearch.inference.InputType;
1110
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1211
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1312
import org.elasticsearch.xpack.inference.external.http.sender.JinaAIEmbeddingsRequestManager;
@@ -35,8 +34,8 @@ public JinaAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
3534
}
3635

3736
@Override
38-
public ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
39-
var overriddenModel = JinaAIEmbeddingsModel.of(model, taskSettings, inputType);
37+
public ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings) {
38+
var overriddenModel = JinaAIEmbeddingsModel.of(model, taskSettings);
4039
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("JinaAI embeddings");
4140
var requestCreator = JinaAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
4241
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);

0 commit comments

Comments
 (0)