[8.x] [ML] Inference duration and error metrics (#115876) #118700

Merged · 2 commits · Dec 13, 2024
docs/changelog/115876.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
pr: 115876
summary: Inference duration and error metrics
area: Machine Learning
type: enhancement
issues: []
@@ -105,7 +105,6 @@
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

import java.util.ArrayList;
@@ -240,7 +239,7 @@ public Collection<?> createComponents(PluginServices services) {
shardBulkInferenceActionFilter.set(actionFilter);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry));
var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));

return List.of(modelRegistry, registry, httpClientManager, stats);
}
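
For context on this hunk: `ApmInferenceStats` is gone, and `InferenceStats.create(meterRegistry)` now builds the instruments directly. A minimal, hypothetical usage sketch of the resulting binding — only `InferenceStats`, its `create` factory, and the two accessors come from this PR; the wrapper class and method are illustrative:

```java
// Hypothetical usage sketch; the InferenceStats API calls match this PR's diff.
import org.elasticsearch.inference.Model;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

class InferenceStatsWiringSketch {
    void recordOneRequest(MeterRegistry meterRegistry, Model model, long elapsedMillis) {
        InferenceStats stats = InferenceStats.create(meterRegistry);
        // Counter keyed by service, task_type, and (when present) model_id:
        stats.requestCount().incrementBy(1, InferenceStats.modelAttributes(model));
        // Duration histogram in ms; a null throwable yields status_code=200:
        stats.inferenceDuration().record(elapsedMillis, InferenceStats.responseAttributes(model, null));
    }
}
```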
@@ -7,13 +7,16 @@

package org.elasticsearch.xpack.inference.action;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InferenceServiceResults;
@@ -26,20 +29,22 @@
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;

import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;

public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";

private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();

private final ModelRegistry modelRegistry;
private final InferenceServiceRegistry serviceRegistry;
private final InferenceStats inferenceStats;
@@ -64,17 +69,22 @@ public TransportInferenceAction(

@Override
protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
var timer = InferenceTimer.start();

ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
var service = serviceRegistry.getService(unparsedModel.service());
if (service.isEmpty()) {
listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
recordMetrics(unparsedModel, timer, e);
listener.onFailure(e);
return;
}

if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
// not the wildcard task type and not the model task type
listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
recordMetrics(unparsedModel, timer, e);
listener.onFailure(e);
return;
}

@@ -85,20 +95,69 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
unparsedModel.settings(),
unparsedModel.secrets()
);
inferOnService(model, request, service.get(), delegate);
inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
}, e -> {
try {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
} catch (Exception metricsException) {
log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
}
listener.onFailure(e);
});

modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
}

private void inferOnService(
private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
try {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
} catch (Exception e) {
log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
}
}

private void inferOnServiceWithMetrics(
Model model,
InferenceAction.Request request,
InferenceService service,
InferenceTimer timer,
ActionListener<InferenceAction.Response> listener
) {
inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
if (request.isStreaming()) {
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
inferenceResults.publisher().subscribe(taskProcessor);

var instrumentedStream = new PublisherWithMetrics(timer, model);
taskProcessor.subscribe(instrumentedStream);

listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
} else {
recordMetrics(model, timer, null);
listener.onResponse(new InferenceAction.Response(inferenceResults));
}
}, e -> {
recordMetrics(model, timer, e);
listener.onFailure(e);
}));
}

private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
try {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
} catch (Exception e) {
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
}
}

private void inferOnService(
Model model,
InferenceAction.Request request,
InferenceService service,
ActionListener<InferenceServiceResults> listener
) {
if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
inferenceStats.incrementRequestCount(model);
service.infer(
model,
request.getQuery(),
@@ -107,7 +166,7 @@ private void inferOnService(
request.getTaskSettings(),
request.getInputType(),
request.getInferenceTimeout(),
createListener(request, listener)
listener
);
} else {
listener.onFailure(unsupportedStreamingTaskException(request, service));
@@ -135,20 +194,6 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference
}
}

private ActionListener<InferenceServiceResults> createListener(
InferenceAction.Request request,
ActionListener<InferenceAction.Response> listener
) {
if (request.isStreaming()) {
return listener.delegateFailureAndWrap((l, inferenceResults) -> {
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
inferenceResults.publisher().subscribe(taskProcessor);
l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
});
}
return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
}

private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
}
@@ -162,4 +207,37 @@ private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType
);
}

private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
private final InferenceTimer timer;
private final Model model;

private PublisherWithMetrics(InferenceTimer timer, Model model) {
this.timer = timer;
this.model = model;
}

@Override
protected void next(ChunkedToXContent item) {
downstream().onNext(item);
}

@Override
public void onError(Throwable throwable) {
recordMetrics(model, timer, throwable);
super.onError(throwable);
}

@Override
protected void onCancel() {
recordMetrics(model, timer, null);
super.onCancel();
}

@Override
public void onComplete() {
recordMetrics(model, timer, null);
super.onComplete();
}
}

}
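
The listener plumbing above guarantees exactly one `inferenceDuration` record per request, whichever way it terminates. Collapsed into a synchronous sketch for readability — the real code is listener based, and `doInference()` is a hypothetical stand-in for `service.infer(...)`; every other call appears verbatim in the diff:

```java
// Synchronous sketch of the invariant, not the actual control flow.
var timer = InferenceTimer.start();
try {
    var results = doInference(); // hypothetical stand-in for service.infer(...)
    inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, null));
    listener.onResponse(new InferenceAction.Response(results));
} catch (Exception e) {
    inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, e));
    listener.onFailure(e);
}
```

For streaming responses the same single record happens in `PublisherWithMetrics`, on whichever terminal event arrives first: `onComplete`, `onError`, or `onCancel`.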
@@ -61,11 +61,14 @@ public void request(long n) {
public void cancel() {
if (isClosed.compareAndSet(false, true) && upstream != null) {
upstream.cancel();
onCancel();
}
}
};
}

protected void onCancel() {}

@Override
public void onSubscribe(Flow.Subscription subscription) {
if (upstream != null) {
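
The new `onCancel()` hook fires at most once, because `cancel()` is guarded by `isClosed.compareAndSet(false, true)`. A minimal hypothetical subclass, mirroring how `PublisherWithMetrics` above uses the hook:

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical subclass: counts cancellations surfaced through the new hook.
class CancelCountingProcessor<T> extends DelegatingProcessor<T, T> {
    final AtomicInteger cancels = new AtomicInteger();

    @Override
    protected void next(T item) {
        downstream().onNext(item); // pass items through unchanged
    }

    @Override
    protected void onCancel() {
        cancels.incrementAndGet(); // runs once, after upstream.cancel()
    }
}
```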

This file was deleted. (Given the `ApmInferenceStats` import removed from InferencePlugin above, this is presumably ApmInferenceStats.java.)

@@ -7,15 +7,89 @@

package org.elasticsearch.xpack.inference.telemetry;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.telemetry.metric.LongCounter;
import org.elasticsearch.telemetry.metric.LongHistogram;
import org.elasticsearch.telemetry.metric.MeterRegistry;

public interface InferenceStats {
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Increment the counter for a particular value in a thread safe manner.
* @param model the model to increment request count for
*/
void incrementRequestCount(Model model);
import static java.util.Map.entry;
import static java.util.stream.Stream.concat;

InferenceStats NOOP = model -> {};
public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {

public InferenceStats {
Objects.requireNonNull(requestCount);
Objects.requireNonNull(inferenceDuration);
}

public static InferenceStats create(MeterRegistry meterRegistry) {
return new InferenceStats(
meterRegistry.registerLongCounter(
"es.inference.requests.count.total",
"Inference API request counts for a particular service, task type, model ID",
"operations"
),
meterRegistry.registerLongHistogram(
"es.inference.requests.time",
"Inference API request counts for a particular service, task type, model ID",
"ms"
)
);
}

public static Map<String, Object> modelAttributes(Model model) {
return toMap(modelAttributeEntries(model));
}

private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
var stream = Stream.<Map.Entry<String, Object>>builder()
.add(entry("service", model.getConfigurations().getService()))
.add(entry("task_type", model.getTaskType().toString()));
if (model.getServiceSettings().modelId() != null) {
stream.add(entry("model_id", model.getServiceSettings().modelId()));
}
return stream.build();
}

private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
}

public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
.add(entry("service", model.service()))
.add(entry("task_type", model.taskType().toString()))
.build();

return toMap(concat(unknownModelAttributes, errorAttributes(t)));
}

public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
return toMap(errorAttributes(t));
}

private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
if (t == null) {
return Stream.of(entry("status_code", 200));
} else if (t instanceof ElasticsearchStatusException ese) {
return Stream.<Map.Entry<String, Object>>builder()
.add(entry("status_code", ese.status().getStatus()))
.add(entry("error.type", String.valueOf(ese.status().getStatus())))
.build();
} else {
return Stream.of(entry("error.type", t.getClass().getSimpleName()));
}
}
}
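
To make the attribute shapes concrete, a small hypothetical driver (assumed to live in the same package as `InferenceStats`; the expected values follow directly from `errorAttributes` above):

```java
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;

class ResponseAttributesSketch {
    public static void main(String[] args) {
        // Success path: {status_code=200}
        System.out.println(InferenceStats.responseAttributes((Throwable) null));
        // ElasticsearchStatusException: {status_code=400, error.type=400}
        System.out.println(InferenceStats.responseAttributes(
            new ElasticsearchStatusException("bad request", RestStatus.BAD_REQUEST)));
        // Any other throwable: {error.type=NullPointerException}
        System.out.println(InferenceStats.responseAttributes(new NullPointerException()));
    }
}
```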