CSV and JSON multilines #9

Open · wants to merge 6 commits into master
2 changes: 1 addition & 1 deletion Dockerfile
@@ -9,7 +9,7 @@ WORKDIR /sagemaker-sparkml-model-server

RUN mvn clean package

RUN cp ./target/sparkml-serving-2.3.jar /usr/local/lib/sparkml-serving-2.3.jar
RUN cp ./target/sparkml-serving-2.4.jar /usr/local/lib/sparkml-serving-2.4.jar
RUN cp ./serve.sh /usr/local/bin/serve.sh

RUN chmod a+x /usr/local/bin/serve.sh
36 changes: 18 additions & 18 deletions README.md
@@ -223,20 +223,20 @@ Calling `CreateModel` is required for creating a `Model` in SageMaker with this
SageMaker works with Docker images stored in [Amazon ECR](https://aws.amazon.com/ecr/). The SageMaker team has prepared and uploaded the Docker images for the SageMaker SparkML Serving Container in all regions where SageMaker operates.
The region to ECR container URL mapping can be found below; a hedged Java sketch of calling `CreateModel` with one of these images follows the list. For a mapping from Region to Region Name, please see [here](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html).

* us-west-1 = 746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-sparkml-serving:2.2
* us-west-2 = 246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.2
* us-east-1 = 683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-sparkml-serving:2.2
* us-east-2 = 257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-sparkml-serving:2.2
* ap-northeast-1 = 354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-sparkml-serving:2.2
* ap-northeast-2 = 366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-sparkml-serving:2.2
* ap-southeast-1 = 121021644041.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-sparkml-serving:2.2
* ap-southeast-2 = 783357654285.dkr.ecr.ap-southeast-2.amazonaws.com/sagemaker-sparkml-serving:2.2
* ap-south-1 = 720646828776.dkr.ecr.ap-south-1.amazonaws.com/sagemaker-sparkml-serving:2.2
* eu-west-1 = 141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-sparkml-serving:2.2
* eu-west-2 = 764974769150.dkr.ecr.eu-west-2.amazonaws.com/sagemaker-sparkml-serving:2.2
* eu-central-1 = 492215442770.dkr.ecr.eu-central-1.amazonaws.com/sagemaker-sparkml-serving:2.2
* ca-central-1 = 341280168497.dkr.ecr.ca-central-1.amazonaws.com/sagemaker-sparkml-serving:2.2
* us-gov-west-1 = 414596584902.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-sparkml-serving:2.2
* us-west-1 = 746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-sparkml-serving:2.4
* us-west-2 = 246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4
* us-east-1 = 683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-sparkml-serving:2.4
* us-east-2 = 257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-sparkml-serving:2.4
* ap-northeast-1 = 354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-sparkml-serving:2.4
* ap-northeast-2 = 366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-sparkml-serving:2.4
* ap-southeast-1 = 121021644041.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-sparkml-serving:2.4
* ap-southeast-2 = 783357654285.dkr.ecr.ap-southeast-2.amazonaws.com/sagemaker-sparkml-serving:2.4
* ap-south-1 = 720646828776.dkr.ecr.ap-south-1.amazonaws.com/sagemaker-sparkml-serving:2.4
* eu-west-1 = 141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-sparkml-serving:2.4
* eu-west-2 = 764974769150.dkr.ecr.eu-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4
* eu-central-1 = 492215442770.dkr.ecr.eu-central-1.amazonaws.com/sagemaker-sparkml-serving:2.4
* ca-central-1 = 341280168497.dkr.ecr.ca-central-1.amazonaws.com/sagemaker-sparkml-serving:2.4
* us-gov-west-1 = 414596584902.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-sparkml-serving:2.4
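
As an illustration (not part of the original README), here is a minimal sketch of registering one of the images above as a SageMaker `Model` with the AWS SDK for Java v1; the model name, S3 artifact path, IAM role ARN, and schema value are hypothetical placeholders.

```
import com.amazonaws.services.sagemaker.AmazonSageMaker;
import com.amazonaws.services.sagemaker.AmazonSageMakerClientBuilder;
import com.amazonaws.services.sagemaker.model.ContainerDefinition;
import com.amazonaws.services.sagemaker.model.CreateModelRequest;
import java.util.Collections;

public class CreateModelSketch {
    public static void main(String[] args) {
        final AmazonSageMaker sageMaker = AmazonSageMakerClientBuilder.defaultClient();
        // Point the container at the regional image and pass the schema as an environment variable
        final ContainerDefinition container = new ContainerDefinition()
                .withImage("246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4")
                .withModelDataUrl("s3://my-bucket/sparkml/model.tar.gz") // hypothetical artifact location
                .withEnvironment(Collections.singletonMap("SAGEMAKER_SPARKML_SCHEMA", "{...}")); // your schema JSON
        sageMaker.createModel(new CreateModelRequest()
                .withModelName("sparkml-serving-model") // hypothetical name
                .withPrimaryContainer(container)
                .withExecutionRoleArn("arn:aws:iam::123456789012:role/SageMakerRole")); // hypothetical role
    }
}
```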

With [SageMaker Python SDK](https://github.com/aws/sagemaker-python-sdk)
------------------------------------------------------------------------
@@ -263,7 +263,7 @@ First you need to ensure that you have installed [Docker](https://www.docker.com/) o
In order to build the Docker image, you need to run a single Docker command:

```
docker build -t sagemaker-sparkml-serving:2.2 .
docker build -t sagemaker-sparkml-serving:2.4 .
```

#### Running the image locally
@@ -272,7 +272,7 @@ In order to run the Docker image, you need to run the following command. Please
The command will start the server on port 8080 and will also pass the schema as an environment variable to the Docker container. Alternatively, you can edit the `Dockerfile` to add `ENV SAGEMAKER_SPARKML_SCHEMA=schema` as well before building the Docker image.

```
docker run -p 8080:8080 -e SAGEMAKER_SPARKML_SCHEMA=schema -v /tmp/model:/opt/ml/model sagemaker-sparkml-serving:2.2 serve
docker run -p 8080:8080 -e SAGEMAKER_SPARKML_SCHEMA=schema -v /tmp/model:/opt/ml/model sagemaker-sparkml-serving:2.4 serve
```

#### Invoking with a payload
@@ -287,7 +287,7 @@ or
curl -i -H "content-type:application/json" -d "{\"data\":[feature_1,\"feature_2\",feature_3]}" http://localhost:8080/invocations
```
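
The curl examples above send a single record. Since this PR adds multi-record support, here is a hedged sketch (not from the original README) of posting a JSON Lines body, one `{"data": [...]}` object per line, from Java; it assumes the new handler is mapped to the `application/jsonlines` content-type string and that the container is running locally as shown above.

```
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;

public class JsonLinesClient {
    public static void main(String[] args) throws Exception {
        // One JSON object per line; each line carries one input row
        final String payload = "{\"data\":[1.0,\"feature_2\",3]}\n"
                + "{\"data\":[4.0,\"feature_2\",6]}";
        final HttpURLConnection conn =
                (HttpURLConnection) new URL("http://localhost:8080/invocations").openConnection();
        conn.setRequestMethod("POST");
        conn.setRequestProperty("Content-Type", "application/jsonlines"); // assumed media-type string
        conn.setDoOutput(true);
        try (OutputStream os = conn.getOutputStream()) {
            os.write(payload.getBytes(StandardCharsets.UTF_8));
        }
        System.out.println("HTTP " + conn.getResponseCode());
    }
}
```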

The `Dockerfile` can be found at the root directory of the package. The SageMaker SparkML Serving Container tags the Docker images using the Spark major version it is compatible with. Right now, it only supports Spark 2.2, and as a result, the Docker image is tagged with 2.2.
The `Dockerfile` can be found at the root directory of the package. The SageMaker SparkML Serving Container tags the Docker images using the Spark major version it is compatible with. Right now, it only supports Spark 2.4, and as a result, the Docker image is tagged with 2.4.

To save the effort of building the Docker image every time you make a code change, you can also install [Maven](http://maven.apache.org/) and run `mvn clean package` at the project root to verify that the code compiles and the unit tests pass.

@@ -310,7 +310,7 @@ aws ecr get-login --region us-west-2 --registry-ids 246618743249 --no-include-em
* Download the Docker image with the following command:

```
docker pull 246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.2
docker pull 246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4
```

To run the Docker image, please see the "Running the image locally" section above.
6 changes: 3 additions & 3 deletions ci/buildspec.yml
@@ -9,10 +9,10 @@ phases:
commands:
- echo Build started on `date`
- echo Building the Docker image...
- docker build -t sagemaker-sparkml-serving:2.3 .
- docker tag sagemaker-sparkml-serving:2.3 515193369038.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.3
- docker build -t sagemaker-sparkml-serving:2.4 .
- docker tag sagemaker-sparkml-serving:2.4 515193369038.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4
post_build:
commands:
- echo Build completed on `date`
- echo Pushing the Docker image...
- docker push 515193369038.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.3
- docker push 515193369038.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4
6 changes: 3 additions & 3 deletions pom.xml
@@ -24,7 +24,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.amazonaws.sagemaker</groupId>
<artifactId>sparkml-serving</artifactId>
<version>2.3</version>
<version>2.4</version>
<build>
<plugins>
<plugin>
@@ -154,7 +154,7 @@
<dependency>
<groupId>ml.combust.mleap</groupId>
<artifactId>mleap-runtime_2.11</artifactId>
<version>0.13.0</version>
<version>0.14.0</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
@@ -199,4 +199,4 @@
<properties>
<java.version>1.8</java.version>
</properties>
</project>
</project>
2 changes: 1 addition & 1 deletion serve.sh
@@ -1,3 +1,3 @@
#!/bin/bash
# This is needed to make sure Java correctly detects CPU/Memory set by the container limits
java -XX:+UnlockExperimentalVMOptions -XX:+UseCGroupMemoryLimitForHeap -jar /usr/local/lib/sparkml-serving-2.3.jar
java -XX:+UnlockExperimentalVMOptions -XX:+UseCGroupMemoryLimitForHeap -jar /usr/local/lib/sparkml-serving-2.4.jar
@@ -34,9 +34,14 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import ml.combust.mleap.runtime.frame.ArrayRow;
import ml.combust.mleap.runtime.frame.DefaultLeapFrame;
import ml.combust.mleap.runtime.frame.Row;
import ml.combust.mleap.runtime.frame.Transformer;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
@@ -99,23 +104,54 @@ public ResponseEntity returnBatchExecutionParameter() throws JsonProcessingExcep
}

/**
* Implements the invocations POST API for application/json input
*
* @param sro, the request object
* @param accept, accept parameter from request
* @return ResponseEntity with body as the expected payload JSON & proper status code based on the input
*/
@RequestMapping(path = "/invocations", method = POST, consumes = MediaType.APPLICATION_JSON_VALUE)
public ResponseEntity<String> transformRequestJson(@RequestBody final SageMakerRequestObject sro,
@RequestHeader(value = HttpHeaders.ACCEPT, required = false) final String accept) {
if (sro == null) {
LOG.error("Input passed to the request is empty");
return ResponseEntity.noContent().build();
}
try {
final String acceptVal = this.retrieveAndVerifyAccept(accept);
final DataSchema schema = this.retrieveAndVerifySchema(sro.getSchema(), mapper);
return this.processInputData(sro.getData(), schema, acceptVal);
return this.processInputData(Collections.singletonList(sro.getData()), schema, acceptVal);
} catch (final Exception ex) {
LOG.error("Error in processing current request", ex);
return ResponseEntity.badRequest().body(ex.getMessage());
}
}

/**
* Implements the invocations POST API for application/jsonlines input
*
* @param jsonLines, the request body with one JSON object per line
* @param accept, accept parameter from request
* @return ResponseEntity with body as the expected payload JSON & proper status code based on the input
*/
@RequestMapping(path = "/invocations", method = POST, consumes = AdditionalMediaType.APPLICATION_JSONLINES_VALUE)
public ResponseEntity<String> transformRequestJsonLines(@RequestBody final byte[] jsonLines,
@RequestHeader(value = HttpHeaders.ACCEPT, required = false) final String accept) {
if (jsonLines == null) {
LOG.error("Input passed to the request is empty");
return ResponseEntity.noContent().build();
}
try {
final String acceptVal = this.retrieveAndVerifyAccept(accept);
final DataSchema schema = this.retrieveAndVerifySchema(null, mapper);
final String[] jsonStringLines = new String(jsonLines, StandardCharsets.UTF_8).split("\\r?\\n");
final List<List<Object>> inputDatas = new ArrayList<>();
for (String jsonStringLine : jsonStringLines) {
    final SageMakerRequestObject sro = mapper.readValue(jsonStringLine, SageMakerRequestObject.class);
    inputDatas.add(sro.getData());
}
return this.processInputData(inputDatas, schema, acceptVal);
} catch (final Exception ex) {
LOG.error("Error in processing current request", ex);
return ResponseEntity.badRequest().body(ex.getMessage());
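
For readers following the diff, the handler above boils down to: decode the body, split on newlines, and parse each line independently with Jackson. A self-contained sketch of that strategy (using `Map.class` in place of the repo's `SageMakerRequestObject`, and an assumed payload shape) is:

```
import com.fasterxml.jackson.databind.ObjectMapper;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;

public class JsonLinesParsingSketch {
    public static void main(String[] args) throws Exception {
        final byte[] body = "{\"data\":[1.0,\"a\"]}\n{\"data\":[2.0,\"b\"]}"
                .getBytes(StandardCharsets.UTF_8);
        final ObjectMapper mapper = new ObjectMapper();
        // Each line is an independent JSON document: split first, then parse
        for (String line : new String(body, StandardCharsets.UTF_8).split("\\r?\\n")) {
            final Map<?, ?> parsed = mapper.readValue(line, Map.class);
            final List<?> data = (List<?>) parsed.get("data");
            System.out.println(data); // [1.0, a] then [2.0, b]
        }
    }
}
```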
@@ -169,14 +205,14 @@ protected DataSchema retrieveAndVerifySchema(final DataSchema schemaFromPayload,
: mapper.readValue(SystemUtils.getEnvironmentVariable("SAGEMAKER_SPARKML_SCHEMA"), DataSchema.class);
}

private ResponseEntity<String> processInputData(final List<Object> inputData, final DataSchema schema,
private ResponseEntity<String> processInputData(final List<List<Object>> inputDatas, final DataSchema schema,
final String acceptVal) throws JsonProcessingException {
final DefaultLeapFrame leapFrame = dataConversionHelper.convertInputToLeapFrame(schema, inputData);
final DefaultLeapFrame leapFrame = dataConversionHelper.convertInputToLeapFrame(schema, inputDatas);
// Making call to the MLeap executor to get the output
final DefaultLeapFrame totalLeapFrame = ScalaUtils.transformLeapFrame(mleapTransformer, leapFrame);
final DefaultLeapFrame predictionsLeapFrame = ScalaUtils
.selectFromLeapFrame(totalLeapFrame, schema.getOutput().getName());
final ArrayRow outputArrayRow = ScalaUtils.getOutputArrayRow(predictionsLeapFrame);
final List<Row> outputArrayRow = ScalaUtils.getOutputArrayRow(predictionsLeapFrame);
return transformToHttpResponse(schema, outputArrayRow, acceptVal);

}
@@ -186,17 +222,18 @@ private boolean checkEmptyAccept(final String acceptFromRequest) {
return (StringUtils.isBlank(acceptFromRequest) || StringUtils.equals(acceptFromRequest, MediaType.ALL_VALUE));
}

private ResponseEntity<String> transformToHttpResponse(final DataSchema schema, final ArrayRow predictionRow,
private ResponseEntity<String> transformToHttpResponse(final DataSchema schema, final List<Row> predictionRows,
final String accept) throws JsonProcessingException {

if (StringUtils.equals(schema.getOutput().getStruct(), DataStructureType.BASIC)) {
final Object output = dataConversionHelper
.convertMLeapBasicTypeToJavaType(predictionRow, schema.getOutput().getType());
.convertMLeapBasicTypeToJavaType(predictionRows.get(0), schema.getOutput().getType());
return responseHelper.sendResponseForSingleValue(output.toString(), accept);
} else {
// If not basic type, it can be vector or array type from Spark
return responseHelper.sendResponseForList(
ScalaUtils.getJavaObjectIteratorFromArrayRow(predictionRow, schema.getOutput().getStruct()), accept);
predictionRows.stream()
    .map(row -> ScalaUtils.getJavaObjectIteratorFromArrayRow(row, schema.getOutput().getStruct()))
    .collect(Collectors.toList()),
accept);
}
}

Expand Down
@@ -70,19 +70,26 @@ public DataConversionHelper(final LeapFrameBuilderSupport support, final LeapFra
* @return List of rows, where each row is a List of Objects and each Object corresponds to one feature of the input data
* @throws IOException, if there is an exception thrown in the try-with-resources block
*/
public List<Object> convertCsvToObjectList(final String csvInput, final DataSchema schema) throws IOException {
public List<List<Object>> convertCsvToObjectList(final String csvInput, final DataSchema schema) throws IOException {
try (final StringReader sr = new StringReader(csvInput)) {
final List<Object> valueList = Lists.newArrayList();
final CSVParser parser = CSVFormat.DEFAULT.parse(sr);
// We do not support multiple CSV lines as input currently
final CSVRecord record = parser.getRecords().get(0);
final List<CSVRecord> records = parser.getRecords();
final int inputLength = schema.getInput().size();
for (int idx = 0; idx < inputLength; ++idx) {
ColumnSchema sc = schema.getInput().get(idx);
// For CSV input, each value is treated as an individual feature by default
valueList.add(this.convertInputDataToJavaType(sc.getType(), DataStructureType.BASIC, record.get(idx)));

final List<List<Object>> returnList = Lists.newArrayList();

for (CSVRecord record : records) {
final List<Object> valueList = Lists.newArrayList();
for (int idx = 0; idx < inputLength; ++idx) {
ColumnSchema sc = schema.getInput().get(idx);
// For CSV input, each value is treated as an individual feature by default
valueList.add(this.convertInputDataToJavaType(sc.getType(), DataStructureType.BASIC, record.get(idx)));
}
returnList.add(valueList);
}
return valueList;

return returnList;
}
}
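
A self-contained sketch of the multi-record parsing introduced above, using the same Apache Commons CSV calls (`CSVFormat.DEFAULT.parse`, `getRecords()`, `record.get(idx)`) but plain strings in place of the repo's schema-driven type conversion:

```
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

public class CsvMultiLineSketch {
    public static void main(String[] args) throws Exception {
        final String csvInput = "1.0,foo,3\n4.0,bar,6";
        try (StringReader sr = new StringReader(csvInput);
             CSVParser parser = CSVFormat.DEFAULT.parse(sr)) {
            final List<List<String>> rows = new ArrayList<>();
            // One CSVRecord per input line; one value list per record
            for (CSVRecord record : parser.getRecords()) {
                final List<String> values = new ArrayList<>();
                for (int idx = 0; idx < record.size(); ++idx) {
                    values.add(record.get(idx));
                }
                rows.add(values);
            }
            System.out.println(rows); // [[1.0, foo, 3], [4.0, bar, 6]]
        }
    }
}
```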

@@ -91,38 +98,52 @@ public List<Object> convertCsvToObjectList(final String csvInput, final DataSche
* Convert input object to DefaultLeapFrame
*
* @param schema, the input schema received from request or environment variable
* @param data, the input data received from request as a list of objects
* @param datas, the input data received from the request as a list of rows, each row being a list of objects
* @return the DefaultLeapFrame object which MLeap transformer expects
*/
public DefaultLeapFrame convertInputToLeapFrame(final DataSchema schema, final List<Object> data) {
public DefaultLeapFrame convertInputToLeapFrame(final DataSchema schema, final List<List<Object>> datas) {

final int inputLength = schema.getInput().size();
final List<StructField> structFieldList = Lists.newArrayList();
final List<Object> valueList = Lists.newArrayList();

for (int idx = 0; idx < inputLength; ++idx) {
ColumnSchema sc = schema.getInput().get(idx);
structFieldList
.add(new StructField(sc.getName(), this.convertInputToMLeapInputType(sc.getType(), sc.getStruct())));
valueList.add(this.convertInputDataToJavaType(sc.getType(), sc.getStruct(), data.get(idx)));
}

final StructType mleapSchema = leapFrameBuilder.createSchema(structFieldList);
final Row currentRow = support.createRowFromIterable(valueList);

final List<Row> rows = Lists.newArrayList();
rows.add(currentRow);

for (List<Object> data : datas) {
    final Row currentRow = getRow(schema, data, inputLength);
    rows.add(currentRow);
}

return leapFrameBuilder.createFrame(mleapSchema, rows);
}

private Row getRow(final DataSchema schema, final List<Object> data, final int inputLength) {
final List<Object> valueList = Lists.newArrayList();

for (int idx = 0; idx < inputLength; ++idx) {
ColumnSchema sc = schema.getInput().get(idx);
valueList.add(this.convertInputDataToJavaType(sc.getType(), sc.getStruct(), data.get(idx)));
}

return support.createRowFromIterable(valueList);
}
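
`getRow` delegates the per-column conversion to `convertInputDataToJavaType`, which is not shown in full in this hunk. The following standalone sketch illustrates the idea, assuming basic types are keyed by strings such as `"int"`, `"double"`, and `"string"` (the actual constants live in `BasicDataType`):

```
import java.util.Arrays;
import java.util.List;

public class TypeConversionSketch {
    // Convert one raw value according to the basic type declared for its column
    static Object toJavaType(final String type, final Object value) {
        switch (type) {
            case "int":    return Integer.valueOf(value.toString());
            case "double": return Double.valueOf(value.toString());
            case "string": return value.toString();
            default: throw new IllegalArgumentException("Unknown type: " + type);
        }
    }

    public static void main(String[] args) {
        final List<String> columnTypes = Arrays.asList("double", "string", "int");
        final List<Object> rawRow = Arrays.asList("1.5", "foo", "3");
        for (int idx = 0; idx < columnTypes.size(); ++idx) {
            System.out.println(toJavaType(columnTypes.get(idx), rawRow.get(idx)));
        }
    }
}
```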

/**
* Convert basic types in the MLeap helper to Java types for output.
*
* @param predictionRow, the Row from the MLeap response
* @param type, the basic type to which the value should be cast, provided by the user via input
* @return the proper Java type
*/
public Object convertMLeapBasicTypeToJavaType(final ArrayRow predictionRow, final String type) {
public Object convertMLeapBasicTypeToJavaType(final Row predictionRow, final String type) {
switch (type) {
case BasicDataType.INTEGER:
return predictionRow.getInt(0);