[ML] Add option to disable inference process cache by default (#108784)
* Add option to disable inference process cache by default

* Add test

* improve tests

* Update docs and improve code

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
maxhniebergall and elasticmachine committed May 19, 2024
1 parent a7e4423 commit a2008bd
Showing 5 changed files with 96 additions and 8 deletions.
@@ -25,9 +25,9 @@ Currently only `pytorch` models are supported for deployment. Once deployed
the model can be used by the <<inference-processor,{infer-cap} processor>>
in an ingest pipeline or directly in the <<infer-trained-model>> API.

A model can be deployed multiple times by using deployment IDs. A deployment ID
must be unique and should not match any other deployment ID or model ID, unless
it is the same as the ID of the model being deployed. If `deployment_id` is not
set, it defaults to the `model_id`.

Scaling inference performance can be achieved by setting the parameters
@@ -61,7 +61,7 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=model-id]
`cache_size`::
(Optional, <<byte-units,byte value>>)
The inference cache size (in memory outside the JVM heap) per node for the
model. The default value is the size of the model as reported by the
model. In serverless, the cache is disabled by default. Otherwise, the default value is the size of the model as reported by the
`model_size_bytes` field in the <<get-trained-models-stats>>. To disable the
cache, `0b` can be provided.
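
For example, a deployment can be started with the cache explicitly disabled by passing `cache_size=0b` as a request parameter. A sketch (not part of this commit's docs change, reusing the `my_model` ID from the examples further down):

[source,console]
--------------------------------------------------
POST _ml/trained_models/my_model/deployment/_start?cache_size=0b
--------------------------------------------------
// TEST[skip:TBD]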

@@ -165,8 +165,8 @@ The API returns the following results:
[[start-trained-model-deployment-deployment-id-example]]
=== Using deployment IDs

The following example starts a new deployment for the `my_model` trained model
with the ID `my_model_for_ingest`. The deployment ID can be used in {infer} API
calls or in {infer} processors.

[source,console]
@@ -181,4 +181,4 @@ The `my_model` trained model can be deployed again with a different ID:
--------------------------------------------------
POST _ml/trained_models/my_model/deployment/_start?deployment_id=my_model_for_search
--------------------------------------------------
// TEST[skip:TBD]
@@ -1475,7 +1475,7 @@ public List<RestHandler> getRestHandlers(
restHandlers.add(new RestCatDataFrameAnalyticsAction());
}
if (machineLearningExtension.get().isNlpEnabled()) {
restHandlers.add(new RestStartTrainedModelDeploymentAction());
restHandlers.add(new RestStartTrainedModelDeploymentAction(machineLearningExtension.get().disableInferenceProcessCache()));
restHandlers.add(new RestStopTrainedModelDeploymentAction());
restHandlers.add(new RestInferTrainedModelDeploymentAction());
restHandlers.add(new RestUpdateTrainedModelDeploymentAction());
@@ -29,6 +29,10 @@ default boolean isLearningToRankEnabled() {
return false;
}

default boolean disableInferenceProcessCache() {
return false;
}

String[] getAnalyticsDestIndexAllowedSettings();

AbstractNodeAvailabilityZoneMapper getNodeAvailabilityZoneMapper(Settings settings, ClusterSettings clusterSettings);
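
The default method added above is the entire opt-in surface: a distribution that wants the inference process cache off by default only needs to override it. A self-contained sketch of that pattern, with illustrative names that are not classes from this commit:

interface InferenceCacheExtension {
    // Mirrors the new extension hook: stock distributions keep the cache enabled.
    default boolean disableInferenceProcessCache() {
        return false;
    }
}

class ServerlessLikeExtension implements InferenceCacheExtension {
    @Override
    public boolean disableInferenceProcessCache() {
        // A serverless-style distribution flips the default so new deployments
        // start without a cache unless the caller asks for one explicitly.
        return true;
    }
}

public class ExtensionPatternDemo {
    public static void main(String[] args) {
        System.out.println(new InferenceCacheExtension() {}.disableInferenceProcessCache()); // false
        System.out.println(new ServerlessLikeExtension().disableInferenceProcessCache());    // true
    }
}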
@@ -36,6 +36,18 @@
@ServerlessScope(Scope.PUBLIC)
public class RestStartTrainedModelDeploymentAction extends BaseRestHandler {

public RestStartTrainedModelDeploymentAction(boolean disableInferenceProcessCache) {
super();
if (disableInferenceProcessCache) {
this.defaultCacheSize = ByteSizeValue.ZERO;
} else {
// Don't set the default cache size yet
defaultCacheSize = null;
}
}

private final ByteSizeValue defaultCacheSize;

@Override
public String getName() {
return "xpack_ml_start_trained_models_deployment_action";
@@ -98,6 +110,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
request.setCacheSize(
ByteSizeValue.parseBytesSizeValue(restRequest.param(CACHE_SIZE.getPreferredName()), CACHE_SIZE.getPreferredName())
);
} else if (defaultCacheSize != null) {
request.setCacheSize(defaultCacheSize);
}
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
request.setPriority(
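
Taken together with the new constructor, the cache size is resolved in this order: an explicit `cache_size` request parameter always wins, otherwise the zero-byte default applies when the distribution has disabled the cache, otherwise the value stays unset and the server falls back to the model's size. A plain-Java sketch of that resolution order (illustrative only, using long byte counts in place of Elasticsearch's ByteSizeValue):

public class CacheSizeResolutionDemo {

    // Precedence: explicit request parameter > distribution-wide "disabled" default > unset (null).
    static Long resolveCacheSizeBytes(String cacheSizeParam, boolean disableCacheByDefault) {
        if (cacheSizeParam != null) {
            return parseBytes(cacheSizeParam);  // e.g. "0b", "512mb", "1gb"
        } else if (disableCacheByDefault) {
            return 0L;                          // plays the role of ByteSizeValue.ZERO
        }
        return null;                            // unset: the server defaults to model_size_bytes
    }

    // Tiny byte-unit parser for the sketch; real code would use ByteSizeValue.parseBytesSizeValue.
    static Long parseBytes(String value) {
        String v = value.trim().toLowerCase();
        if (v.endsWith("gb")) return Long.parseLong(v.substring(0, v.length() - 2)) * 1024 * 1024 * 1024;
        if (v.endsWith("mb")) return Long.parseLong(v.substring(0, v.length() - 2)) * 1024 * 1024;
        if (v.endsWith("kb")) return Long.parseLong(v.substring(0, v.length() - 2)) * 1024;
        if (v.endsWith("b")) return Long.parseLong(v.substring(0, v.length() - 1));
        return Long.parseLong(v);
    }

    public static void main(String[] args) {
        System.out.println(resolveCacheSizeBytes("1gb", true));  // 1073741824 - explicit parameter wins
        System.out.println(resolveCacheSizeBytes(null, true));   // 0 - cache disabled by default
        System.out.println(resolveCacheSizeBytes(null, false));  // null - fall back to the model size
    }
}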
@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.rest.inference;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.test.rest.RestActionTestCase;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;

public class RestStartTrainedModelDeploymentActionTests extends RestActionTestCase {

public void testCacheDisabled() {
final boolean disableInferenceProcessCache = true;
controller().registerHandler(new RestStartTrainedModelDeploymentAction(disableInferenceProcessCache));
SetOnce<Boolean> executeCalled = new SetOnce<>();
verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
assertThat(actionRequest, instanceOf(StartTrainedModelDeploymentAction.Request.class));

var request = (StartTrainedModelDeploymentAction.Request) actionRequest;
assertThat(request.getCacheSize(), is(ByteSizeValue.ZERO));

executeCalled.set(true);
return createResponse();
}));

RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath("_ml/trained_models/test_id/deployment/_start")
.build();
dispatchRequest(inferenceRequest);
assertThat(executeCalled.get(), equalTo(true));
}

public void testCacheEnabled() {
final boolean disableInferenceProcessCache = false;
controller().registerHandler(new RestStartTrainedModelDeploymentAction(disableInferenceProcessCache));
SetOnce<Boolean> executeCalled = new SetOnce<>();
verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
assertThat(actionRequest, instanceOf(StartTrainedModelDeploymentAction.Request.class));

var request = (StartTrainedModelDeploymentAction.Request) actionRequest;
assertNull(request.getCacheSize());

executeCalled.set(true);
return createResponse();
}));

RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath("_ml/trained_models/test_id/deployment/_start")
.build();
dispatchRequest(inferenceRequest);
assertThat(executeCalled.get(), equalTo(true));
}

private static CreateTrainedModelAssignmentAction.Response createResponse() {
return new CreateTrainedModelAssignmentAction.Response(TrainedModelAssignmentTests.randomInstance());
}
}
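
To run just this new test class locally, a command along these lines should work (the `:x-pack:plugin:ml` Gradle project path is assumed from the usual x-pack module layout and is not stated in this commit):

./gradlew ":x-pack:plugin:ml:test" --tests "org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentActionTests"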
