SageMaker AI examples using SDK for JavaScript (v3) (original) (raw)
The following file excerpt contains functions that use the SageMaker AI client to manage a pipeline.
import { readFileSync } from "node:fs";
import {
CreateRoleCommand,
DeleteRoleCommand,
CreatePolicyCommand,
DeletePolicyCommand,
AttachRolePolicyCommand,
DetachRolePolicyCommand,
GetRoleCommand,
ListPoliciesCommand,
} from "@aws-sdk/client-iam";
import {
PublishLayerVersionCommand,
DeleteLayerVersionCommand,
CreateFunctionCommand,
Runtime,
DeleteFunctionCommand,
CreateEventSourceMappingCommand,
DeleteEventSourceMappingCommand,
GetFunctionCommand,
} from "@aws-sdk/client-lambda";
import {
PutObjectCommand,
CreateBucketCommand,
DeleteBucketCommand,
DeleteObjectCommand,
GetObjectCommand,
ListObjectsV2Command,
} from "@aws-sdk/client-s3";
import {
CreatePipelineCommand,
DeletePipelineCommand,
DescribePipelineCommand,
DescribePipelineExecutionCommand,
PipelineExecutionStatus,
StartPipelineExecutionCommand,
} from "@aws-sdk/client-sagemaker";
import { VectorEnrichmentJobDocumentType } from "@aws-sdk/client-sagemaker-geospatial";
import {
CreateQueueCommand,
DeleteQueueCommand,
GetQueueAttributesCommand,
GetQueueUrlCommand,
} from "@aws-sdk/client-sqs";
import { dirnameFromMetaUrl } from "@aws-doc-sdk-examples/lib/utils/util-fs.js";
import { retry } from "@aws-doc-sdk-examples/lib/utils/util-timers.js";
/**
 * Create the AWS IAM role that AWS Lambda assumes when it runs the function.
 * If a role with the same name already exists, that role is looked up and reused.
 * @param {{ name: string, iamClient: import('@aws-sdk/client-iam').IAMClient }} props
 * @returns {Promise<{ arn: string, cleanUp: () => Promise<void> }>} The role ARN and a
 * callback that deletes the role by name (whether it was created or reused).
 */
export async function createLambdaExecutionRole({ name, iamClient }) {
  // Trust policy: only the Lambda service is allowed to assume this role.
  const trustPolicy = {
    Version: "2012-10-17",
    Statement: [
      {
        Effect: "Allow",
        Action: ["sts:AssumeRole"],
        Principal: { Service: ["lambda.amazonaws.com"] },
      },
    ],
  };

  const createOrGetRole = async () => {
    try {
      const created = await iamClient.send(
        new CreateRoleCommand({
          RoleName: name,
          AssumeRolePolicyDocument: JSON.stringify(trustPolicy),
        }),
      );
      return created.Role;
    } catch (caught) {
      const roleExists =
        caught instanceof Error &&
        caught.name === "EntityAlreadyExistsException";
      if (!roleExists) {
        throw caught;
      }
      // A role with this name is already present, so fetch it instead.
      const existing = await iamClient.send(
        new GetRoleCommand({ RoleName: name }),
      );
      return existing.Role;
    }
  };

  const lambdaRole = await createOrGetRole();

  const deleteRole = async () => {
    await iamClient.send(new DeleteRoleCommand({ RoleName: name }));
  };

  return { arn: lambdaRole.Arn, cleanUp: deleteRole };
}
/**
 * Create an AWS IAM policy that will be attached to the AWS IAM role assumed by the AWS Lambda function.
 * The policy grants permission to work with Amazon SQS, Amazon CloudWatch, and Amazon SageMaker.
 *
 * If a policy with the same name already exists, the existing customer managed policy is
 * looked up (following pagination) and reused. The existing policy's document is not
 * compared against the one built here.
 * @param {{name: string, iamClient: import('@aws-sdk/client-iam').IAMClient, pipelineExecutionRoleArn: string}} props
 * @returns {Promise<{ arn: string | undefined, policyConfig: object, cleanUp: () => Promise<void> }>}
 * The policy ARN, the policy document that was requested, and a callback that deletes the policy.
 * @throws {Error} If the policy already exists but cannot be found among the account's
 * customer managed policies.
 */
export async function createLambdaExecutionPolicy({
  name,
  iamClient,
  pipelineExecutionRoleArn,
}) {
  const policyConfig = {
    Version: "2012-10-17",
    Statement: [
      {
        Effect: "Allow",
        Action: [
          "sqs:ReceiveMessage",
          "sqs:DeleteMessage",
          "sqs:GetQueueAttributes",
          "logs:CreateLogGroup",
          "logs:CreateLogStream",
          "logs:PutLogEvents",
          "sagemaker-geospatial:StartVectorEnrichmentJob",
          "sagemaker-geospatial:GetVectorEnrichmentJob",
          "sagemaker:SendPipelineExecutionStepFailure",
          "sagemaker:SendPipelineExecutionStepSuccess",
          "sagemaker-geospatial:ExportVectorEnrichmentJob",
        ],
        Resource: "*",
      },
      {
        Effect: "Allow",
        // The AWS Lambda function needs permission to pass the pipeline execution role to
        // the StartVectorEnrichmentCommand. This restriction prevents an AWS Lambda function
        // from elevating privileges. For more information, see:
        // https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_use_passrole.html
        Action: ["iam:PassRole"],
        Resource: `${pipelineExecutionRoleArn}`,
        Condition: {
          StringEquals: {
            "iam:PassedToService": [
              "sagemaker.amazonaws.com",
              "sagemaker-geospatial.amazonaws.com",
            ],
          },
        },
      },
    ],
  };

  // Find an existing customer managed policy by name. ListPolicies is paginated and,
  // without a scope, also returns AWS managed policies, so a single unscoped call
  // is not enough to reliably find the policy.
  const findExistingPolicy = async () => {
    let marker;
    do {
      const {
        Policies = [],
        IsTruncated,
        Marker,
      } = await iamClient.send(
        new ListPoliciesCommand({ Scope: "Local", Marker: marker }),
      );
      const match = Policies.find((p) => p.PolicyName === name);
      if (match) {
        return match;
      }
      marker = IsTruncated ? Marker : undefined;
    } while (marker);
    return undefined;
  };

  let policy = null;
  try {
    const { Policy } = await iamClient.send(
      new CreatePolicyCommand({
        PolicyDocument: JSON.stringify(policyConfig),
        PolicyName: name,
      }),
    );
    policy = Policy;
  } catch (caught) {
    const alreadyExists =
      caught instanceof Error &&
      caught.name === "EntityAlreadyExistsException";
    if (!alreadyExists) {
      throw caught;
    }
    policy = await findExistingPolicy();
    if (!policy) {
      // Fail loudly instead of returning an undefined ARN that breaks later steps.
      throw new Error(
        `Policy "${name}" already exists but was not found among the customer managed policies.`,
      );
    }
  }

  return {
    arn: policy?.Arn,
    policyConfig,
    cleanUp: async () => {
      await iamClient.send(new DeletePolicyCommand({ PolicyArn: policy?.Arn }));
    },
  };
}
/**
 * Attach an AWS IAM policy to an AWS IAM role.
 * @param {{roleName: string, policyArn: string, iamClient: import('@aws-sdk/client-iam').IAMClient}} props
 * @returns {Promise<{ cleanUp: () => Promise<void> }>} A callback that detaches the same
 * policy from the same role.
 */
export async function attachPolicy({ roleName, policyArn, iamClient }) {
  // Attach and detach address the same role/policy pair.
  const buildInput = () => ({ RoleName: roleName, PolicyArn: policyArn });

  await iamClient.send(new AttachRolePolicyCommand(buildInput()));

  const detachPolicy = async () => {
    await iamClient.send(new DetachRolePolicyCommand(buildInput()));
  };

  return { cleanUp: detachPolicy };
}
/**
 * Publish an AWS Lambda layer that makes the Amazon SageMaker and Amazon SageMaker Geospatial
 * clients available in the function runtime. The default runtime supports v3.188.0 of the
 * JavaScript SDK, but the Amazon SageMaker Geospatial client wasn't introduced until v3.221.0.
 * The layer content is read from `lambda/nodejs.zip` next to this module.
 * @param {{ name: string, lambdaClient: import('@aws-sdk/client-lambda').LambdaClient }} props
 * @returns {Promise<{ versionArn: string, version: number, cleanUp: () => Promise<void> }>}
 * The published layer version and a callback that deletes that version.
 */
export async function createLambdaLayer({ name, lambdaClient }) {
  const moduleDir = dirnameFromMetaUrl(import.meta.url);
  const archive = readFileSync(`${moduleDir}lambda/nodejs.zip`);

  const published = await lambdaClient.send(
    new PublishLayerVersionCommand({
      LayerName: name,
      Content: { ZipFile: Uint8Array.from(archive) },
    }),
  );
  const { LayerVersionArn: versionArn, Version: version } = published;

  const deleteLayerVersion = async () => {
    await lambdaClient.send(
      new DeleteLayerVersionCommand({
        LayerName: name,
        VersionNumber: version,
      }),
    );
  };

  return { versionArn, version, cleanUp: deleteLayerVersion };
}
/**
 * Deploy the AWS Lambda function that responds to Amazon SageMaker pipeline
 * execution steps. The code is read from `lambda/dist/index.mjs.zip` next to this module.
 * @param {{roleArn: string, name: string, lambdaClient: import('@aws-sdk/client-lambda').LambdaClient, layerVersionArn: string}} props
 * @returns {Promise<{ arn: string, cleanUp: () => Promise<void> }>} The function ARN and a
 * callback that deletes the function by name.
 */
export async function createLambdaFunction({
  name,
  roleArn,
  lambdaClient,
  layerVersionArn,
}) {
  const archivePath = `${dirnameFromMetaUrl(
    import.meta.url,
  )}lambda/dist/index.mjs.zip`;

  // One deployment attempt. When the name is already taken, the existing
  // function's configuration is returned instead. By default the name is
  // "sagemaker-wkflw-lambda-function", so collisions are unlikely.
  async function deployOrLookUp() {
    let deployed;
    try {
      deployed = await lambdaClient.send(
        new CreateFunctionCommand({
          Code: {
            ZipFile: Uint8Array.from(readFileSync(archivePath)),
          },
          Runtime: Runtime.nodejs18x,
          Handler: "index.handler",
          Layers: [layerVersionArn],
          FunctionName: name,
          Role: roleArn,
        }),
      );
    } catch (caught) {
      const nameTaken =
        caught instanceof Error &&
        caught.name === "ResourceConflictException";
      if (!nameTaken) {
        throw caught;
      }
      const existing = await lambdaClient.send(
        new GetFunctionCommand({ FunctionName: name }),
      );
      return existing.Configuration;
    }
    return deployed;
  }

  // Creation fails while the IAM role is not ready yet, so keep retrying
  // until an attempt succeeds or the retry budget is used up.
  const functionConfig = await retry(
    { intervalInMs: 1000, maxRetries: 60 },
    deployOrLookUp,
  );

  const deleteFunction = async () => {
    await lambdaClient.send(new DeleteFunctionCommand({ FunctionName: name }));
  };

  return { arn: functionConfig.FunctionArn, cleanUp: deleteFunction };
}
/**
 * Upload the sample coordinate data to an Amazon S3 bucket under the key
 * "input/sample_data.csv". The Amazon SageMaker Geospatial vector enrichment job takes the
 * simple Lat/Long coordinates in this file and augments them with more detailed location data.
 * @param {{bucketName: string, s3Client: import('@aws-sdk/client-s3').S3Client}} props
 * @returns {Promise<void>}
 */
export async function uploadCSVDataToS3({ bucketName, s3Client }) {
  const moduleDir = dirnameFromMetaUrl(import.meta.url);
  const csvPath = `${moduleDir}../../../../../scenarios/features/sagemaker_pipelines/resources/latlongtest.csv`;

  const upload = new PutObjectCommand({
    Bucket: bucketName,
    Key: "input/sample_data.csv",
    Body: readFileSync(csvPath),
  });
  await s3Client.send(upload);
}
/**
 * Create the AWS IAM role that the Amazon SageMaker pipeline assumes. The role trusts both
 * the SageMaker and SageMaker Geospatial services. If a role with the same name already
 * exists, it is looked up and reused without waiting.
 * @param {{name: string, iamClient: import('@aws-sdk/client-iam').IAMClient, wait: (duration: number) => Promise<void>}} props
 * `wait` is called with 10 after a new role is created. NOTE(review): the unit is presumably
 * seconds, as waitForPipelineComplete passes a seconds value to the same helper — confirm.
 * @returns {Promise<{ arn: string, cleanUp: () => Promise<void> }>} The role ARN and a
 * callback that deletes the role by name.
 */
export async function createSagemakerRole({ name, iamClient, wait }) {
  const trustedServices = [
    "sagemaker.amazonaws.com",
    "sagemaker-geospatial.amazonaws.com",
  ];
  const assumeRolePolicy = JSON.stringify({
    Version: "2012-10-17",
    Statement: [
      {
        Effect: "Allow",
        Action: ["sts:AssumeRole"],
        Principal: { Service: trustedServices },
      },
    ],
  });

  let sagemakerRole;
  try {
    const created = await iamClient.send(
      new CreateRoleCommand({
        RoleName: name,
        AssumeRolePolicyDocument: assumeRolePolicy,
      }),
    );
    sagemakerRole = created.Role;
    // Give the freshly created role time to become usable.
    await wait(10);
  } catch (caught) {
    const alreadyExists =
      caught instanceof Error &&
      caught.name === "EntityAlreadyExistsException";
    if (!alreadyExists) {
      throw caught;
    }
    const existing = await iamClient.send(
      new GetRoleCommand({ RoleName: name }),
    );
    sagemakerRole = existing.Role;
  }

  const deleteRole = async () => {
    await iamClient.send(new DeleteRoleCommand({ RoleName: name }));
  };

  return { arn: sagemakerRole.Arn, cleanUp: deleteRole };
}
/**
 * Create the Amazon SageMaker execution policy. This policy grants permission to
 * invoke the AWS Lambda function, read/write to the Amazon S3 bucket, and send messages to
 * the Amazon SQS queue.
 *
 * If a policy with the same name already exists, the existing customer managed policy is
 * looked up (following pagination) and reused. The existing policy's document is not
 * compared against the one built here.
 * @param {{ name: string, sqsQueueArn: string, lambdaArn: string, iamClient: import('@aws-sdk/client-iam').IAMClient, s3BucketName: string}} props
 * @returns {Promise<{ arn: string | undefined, policyConfig: object, cleanUp: () => Promise<void> }>}
 * The policy ARN, the policy document that was requested, and a callback that deletes the policy.
 * @throws {Error} If the policy already exists but cannot be found among the account's
 * customer managed policies.
 */
export async function createSagemakerExecutionPolicy({
  sqsQueueArn,
  lambdaArn,
  iamClient,
  name,
  s3BucketName,
}) {
  const policyConfig = {
    Version: "2012-10-17",
    Statement: [
      {
        Effect: "Allow",
        Action: ["lambda:InvokeFunction"],
        Resource: lambdaArn,
      },
      {
        Effect: "Allow",
        Action: ["s3:*"],
        // Both the bucket itself and every object inside it.
        Resource: [
          `arn:aws:s3:::${s3BucketName}`,
          `arn:aws:s3:::${s3BucketName}/*`,
        ],
      },
      {
        Effect: "Allow",
        Action: ["sqs:SendMessage"],
        Resource: sqsQueueArn,
      },
    ],
  };

  // Find an existing customer managed policy by name. ListPolicies is paginated and,
  // without a scope, also returns AWS managed policies, so a single unscoped call
  // is not enough to reliably find the policy.
  const findExistingPolicy = async () => {
    let marker;
    do {
      const {
        Policies = [],
        IsTruncated,
        Marker,
      } = await iamClient.send(
        new ListPoliciesCommand({ Scope: "Local", Marker: marker }),
      );
      const match = Policies.find((p) => p.PolicyName === name);
      if (match) {
        return match;
      }
      marker = IsTruncated ? Marker : undefined;
    } while (marker);
    return undefined;
  };

  let policy = null;
  try {
    const { Policy } = await iamClient.send(
      new CreatePolicyCommand({
        PolicyDocument: JSON.stringify(policyConfig),
        PolicyName: name,
      }),
    );
    policy = Policy;
  } catch (caught) {
    const alreadyExists =
      caught instanceof Error &&
      caught.name === "EntityAlreadyExistsException";
    if (!alreadyExists) {
      throw caught;
    }
    policy = await findExistingPolicy();
    if (!policy) {
      // Fail loudly instead of returning an undefined ARN that breaks later steps.
      throw new Error(
        `Policy "${name}" already exists but was not found among the customer managed policies.`,
      );
    }
  }

  return {
    arn: policy?.Arn,
    policyConfig,
    cleanUp: async () => {
      await iamClient.send(new DeletePolicyCommand({ PolicyArn: policy?.Arn }));
    },
  };
}
/**
 * Create the Amazon SageMaker pipeline from a JSON pipeline definition read from disk.
 * Every "*FUNCTION_ARN*" placeholder in the definition is replaced with `functionArn`.
 * The definition can also be provided as an Amazon S3 object using PipelineDefinitionS3Location.
 * If the pipeline name is already in use, the existing pipeline's ARN is returned instead.
 * @param {{roleArn: string, name: string, functionArn: string, sagemakerClient: import('@aws-sdk/client-sagemaker').SageMakerClient}} props
 * `roleArn` assumes an AWS IAM role has been created for this pipeline, and `functionArn`
 * assumes an AWS Lambda function has been created for this pipeline.
 * @returns {Promise<{ arn: string, cleanUp: () => Promise<void> }>} The pipeline ARN and a
 * callback that deletes the pipeline by name.
 */
export async function createSagemakerPipeline({
  roleArn,
  name,
  functionArn,
  sagemakerClient,
}) {
  // dirnameFromMetaUrl is a local utility function. You can find its implementation
  // on GitHub.
  const definitionPath = `${dirnameFromMetaUrl(
    import.meta.url,
  )}../../../../../scenarios/features/sagemaker_pipelines/resources/GeoSpatialPipeline.json`;
  const definitionTemplate = readFileSync(definitionPath).toString();
  const pipelineDefinition = definitionTemplate.replace(
    /\*FUNCTION_ARN\*/g,
    functionArn,
  );

  // SageMaker reports a name collision as a ValidationException with this message.
  const isNameCollision = (err) =>
    err instanceof Error &&
    err.name === "ValidationException" &&
    err.message.includes(
      "Pipeline names must be unique within an AWS account and region",
    );

  let response;
  try {
    response = await sagemakerClient.send(
      new CreatePipelineCommand({
        PipelineName: name,
        PipelineDefinition: pipelineDefinition,
        RoleArn: roleArn,
      }),
    );
  } catch (caught) {
    if (!isNameCollision(caught)) {
      throw caught;
    }
    response = await sagemakerClient.send(
      new DescribePipelineCommand({ PipelineName: name }),
    );
  }

  const deletePipeline = async () => {
    await sagemakerClient.send(
      new DeletePipelineCommand({ PipelineName: name }),
    );
  };

  return { arn: response.PipelineArn, cleanUp: deletePipeline };
}
/**
 * Create an Amazon SQS queue. The Amazon SageMaker pipeline sends messages
 * to this queue, and the AWS Lambda function processes them. If the queue
 * name is already taken, the existing queue's URL is looked up instead.
 * @param {{name: string, sqsClient: import('@aws-sdk/client-sqs').SQSClient}} props
 * @returns {Promise<{ queueUrl: string, queueArn: string, cleanUp: () => Promise<void> }>}
 * The queue URL and ARN, and a callback that deletes the queue.
 */
export async function createSQSQueue({ name, sqsClient }) {
  const resolveQueueUrl = async () => {
    try {
      const created = await sqsClient.send(
        new CreateQueueCommand({
          QueueName: name,
          Attributes: {
            DelaySeconds: "5",
            ReceiveMessageWaitTimeSeconds: "5",
            VisibilityTimeout: "300",
          },
        }),
      );
      return created.QueueUrl;
    } catch (caught) {
      const nameTaken =
        caught instanceof Error && caught.name === "QueueNameExists";
      if (!nameTaken) {
        throw caught;
      }
      const existing = await sqsClient.send(
        new GetQueueUrlCommand({ QueueName: name }),
      );
      return existing.QueueUrl;
    }
  };

  const queueUrl = await resolveQueueUrl();

  // The ARN lookup is retried until it succeeds or the retry budget is used up.
  const fetchQueueArn = () =>
    sqsClient.send(
      new GetQueueAttributesCommand({
        QueueUrl: queueUrl,
        AttributeNames: ["QueueArn"],
      }),
    );
  const attributesResponse = await retry(
    { intervalInMs: 1000, maxRetries: 60 },
    fetchQueueArn,
  );

  const deleteQueue = async () => {
    await sqsClient.send(new DeleteQueueCommand({ QueueUrl: queueUrl }));
  };

  return {
    queueUrl,
    queueArn: attributesResponse.Attributes.QueueArn,
    cleanUp: deleteQueue,
  };
}
/**
 * Configure the AWS Lambda function to long poll for messages from the Amazon SQS
 * queue by creating an event source mapping. If the mapping already exists, all
 * event source mappings are listed and the one that joins this queue and this
 * function is reused.
 * @param {{
 *   paginateListEventSourceMappings: (config: { client: import('@aws-sdk/client-lambda').LambdaClient }, input: object) => AsyncIterable<import('@aws-sdk/client-lambda').ListEventSourceMappingsCommandOutput>,
 *   lambdaName: string,
 *   queueArn: string,
 *   lambdaClient: import('@aws-sdk/client-lambda').LambdaClient}} props
 * `paginateListEventSourceMappings` is only called when the mapping already exists.
 * @returns {Promise<{ cleanUp: () => Promise<void> }>} A callback that deletes the
 * event source mapping.
 * @throws {Error} If the mapping is reported as existing but no listed mapping matches
 * the queue and function.
 */
export async function configureLambdaSQSEventSource({
  lambdaName,
  queueArn,
  lambdaClient,
  paginateListEventSourceMappings,
}) {
  let uuid = null;
  try {
    const { UUID } = await lambdaClient.send(
      new CreateEventSourceMappingCommand({
        EventSourceArn: queueArn,
        FunctionName: lambdaName,
      }),
    );
    uuid = UUID;
  } catch (caught) {
    const mappingExists =
      caught instanceof Error &&
      caught.name === "ResourceConflictException";
    if (!mappingExists) {
      throw caught;
    }

    /**
     * @type {import('@aws-sdk/client-lambda').EventSourceMappingConfiguration[]}
     */
    const eventSourceMappings = [];
    const paginator = paginateListEventSourceMappings(
      { client: lambdaClient },
      {},
    );
    for await (const page of paginator) {
      // Array.prototype.concat returns a new array, so pages must be pushed
      // into the accumulator for them to be kept.
      eventSourceMappings.push(...(page.EventSourceMappings ?? []));
    }

    // Mappings reference the function by ARN, so resolve the ARN from the name.
    const { Configuration } = await lambdaClient.send(
      new GetFunctionCommand({ FunctionName: lambdaName }),
    );
    const existingMapping = eventSourceMappings.find(
      (mapping) =>
        mapping.EventSourceArn === queueArn &&
        mapping.FunctionArn === Configuration?.FunctionArn,
    );
    if (!existingMapping) {
      throw new Error(
        `No existing event source mapping found between queue "${queueArn}" and function "${lambdaName}".`,
      );
    }
    uuid = existingMapping.UUID;
  }

  return {
    cleanUp: async () => {
      await lambdaClient.send(
        new DeleteEventSourceMappingCommand({
          UUID: uuid,
        }),
      );
    },
  };
}
/**
 * Create an Amazon S3 bucket that stores the simple coordinate file as input
 * and the output of the Amazon SageMaker Geospatial vector enrichment job.
 * @param {{
 *   s3Client: import('@aws-sdk/client-s3').S3Client,
 *   name: string,
 *   paginateListObjectsV2: (config: { client: import('@aws-sdk/client-s3').S3Client }, input: { Bucket: string }) => AsyncIterable<import('@aws-sdk/client-s3').ListObjectsV2CommandOutput>
 * }} props `paginateListObjectsV2` is only called by the returned `cleanUp` callback.
 * @returns {Promise<{ cleanUp: () => Promise<void> }>} A callback that deletes every
 * listed object, one at a time, and then deletes the bucket.
 */
export async function createS3Bucket({
  name,
  s3Client,
  paginateListObjectsV2,
}) {
  await s3Client.send(new CreateBucketCommand({ Bucket: name }));

  // Delete all listed objects first, then the bucket itself.
  const emptyAndDeleteBucket = async () => {
    const pages = paginateListObjectsV2({ client: s3Client }, { Bucket: name });
    for await (const { Contents } of pages) {
      for (const { Key } of Contents || []) {
        await s3Client.send(new DeleteObjectCommand({ Bucket: name, Key }));
      }
    }
    await s3Client.send(new DeleteBucketCommand({ Bucket: name }));
  };

  return { cleanUp: emptyAndDeleteBucket };
}
/**
 * Start an execution of the Amazon SageMaker pipeline. The pipeline parameters
 * passed here are used in the AWS Lambda function.
 * @param {{
 *   name: string,
 *   sagemakerClient: import('@aws-sdk/client-sagemaker').SageMakerClient,
 *   roleArn: string,
 *   queueUrl: string,
 *   bucketName: string,
 * }} props `bucketName` is the bucket holding "input/sample_data.csv"; output is
 * written under its "output/" prefix.
 * @returns {Promise<{ arn: string }>} The ARN of the started pipeline execution.
 */
export async function startPipelineExecution({
  sagemakerClient,
  name,
  bucketName,
  roleArn,
  queueUrl,
}) {
  const bucketUri = `s3://${bucketName}`;

  /**
   * The Vector Enrichment Job requests CSV data. This configuration points to a CSV
   * file in an Amazon S3 bucket.
   * @type {import("@aws-sdk/client-sagemaker-geospatial").VectorEnrichmentJobInputConfig}
   */
  const inputConfig = {
    DataSourceConfig: {
      S3Data: { S3Uri: `${bucketUri}/input/sample_data.csv` },
    },
    DocumentType: VectorEnrichmentJobDocumentType.CSV,
  };

  /**
   * The Vector Enrichment Job adds additional data to the source CSV. This configuration
   * points to an Amazon S3 prefix where the output will be stored.
   * @type {import("@aws-sdk/client-sagemaker-geospatial").ExportVectorEnrichmentJobOutputConfig}
   */
  const outputConfig = {
    S3Data: { S3Uri: `${bucketUri}/output/` },
  };

  /**
   * This job will be a Reverse Geocoding Vector Enrichment Job. Reverse Geocoding requires
   * latitude and longitude values.
   * @type {import("@aws-sdk/client-sagemaker-geospatial").VectorEnrichmentJobConfig}
   */
  const jobConfig = {
    ReverseGeocodingConfig: {
      XAttributeName: "Longitude",
      YAttributeName: "Latitude",
    },
  };

  // Pipeline parameter values are strings, so structured values are sent as JSON.
  const parameter = (parameterName, value) => ({
    Name: parameterName,
    Value: value,
  });
  const pipelineParameters = [
    parameter("parameter_execution_role", roleArn),
    parameter("parameter_queue_url", queueUrl),
    parameter("parameter_vej_input_config", JSON.stringify(inputConfig)),
    parameter("parameter_vej_export_config", JSON.stringify(outputConfig)),
    parameter("parameter_step_1_vej_config", JSON.stringify(jobConfig)),
  ];

  const started = await sagemakerClient.send(
    new StartPipelineExecutionCommand({
      PipelineName: name,
      PipelineExecutionDisplayName: `${name}-example-execution`,
      PipelineParameters: pipelineParameters,
    }),
  );

  return { arn: started.PipelineExecutionArn };
}
/**
 * Poll the executing pipeline until the status is 'SUCCEEDED', 'STOPPED', or 'FAILED'.
 * While the pipeline is still running, progress is logged and `wait` is called with
 * the polling interval (15, in seconds) before the next check.
 * @param {{ arn: string, sagemakerClient: import('@aws-sdk/client-sagemaker').SageMakerClient, wait: (seconds: number) => Promise<void>}} props
 * @returns {Promise<void>} Resolves once the pipeline execution has succeeded.
 * @throws {Error} If the pipeline execution failed or was stopped.
 */
export async function waitForPipelineComplete({ arn, sagemakerClient, wait }) {
  const command = new DescribePipelineExecutionCommand({
    PipelineExecutionArn: arn,
  });

  let complete = false;
  const intervalInSeconds = 15;

  const COMPLETION_STATUSES = [
    PipelineExecutionStatus.FAILED,
    PipelineExecutionStatus.STOPPED,
    PipelineExecutionStatus.SUCCEEDED,
  ];

  do {
    const { PipelineExecutionStatus: status, FailureReason } =
      await sagemakerClient.send(command);

    complete = COMPLETION_STATUSES.includes(status);

    if (!complete) {
      console.log(
        `Pipeline is ${status}. Waiting ${intervalInSeconds} seconds before checking again.`,
      );
      await wait(intervalInSeconds);
    } else if (status === PipelineExecutionStatus.FAILED) {
      throw new Error(`Pipeline failed because: ${FailureReason}`);
    } else if (status === PipelineExecutionStatus.STOPPED) {
      throw new Error("Pipeline was forcefully stopped.");
    } else {
      console.log(`Pipeline execution ${status}.`);
    }
  } while (!complete);
}
/**
 * Return the string content of the first ".csv" object found under the "output/"
 * prefix of an Amazon S3 bucket. The listing is paginated until a CSV object is
 * found or there are no more pages.
 * @param {{ bucket: string, s3Client: import('@aws-sdk/client-s3').S3Client}} param0
 * @returns {Promise<string>} The object body decoded as a string.
 * @throws {Error} If nothing exists under the prefix, or no listed key ends with ".csv".
 */
export async function getObject({ bucket, s3Client }) {
  const prefix = "output/";

  let outputObject;
  let foundAnyObject = false;
  let continuationToken;
  do {
    // Contents is absent (not an empty array) when a page has no objects.
    const { Contents = [], NextContinuationToken } = await s3Client.send(
      new ListObjectsV2Command({
        Bucket: bucket,
        Prefix: prefix,
        ContinuationToken: continuationToken,
      }),
    );
    foundAnyObject = foundAnyObject || Contents.length > 0;
    // Find the CSV file. It is not necessarily the first key under the prefix,
    // so the listing must not be limited to a single key.
    outputObject = Contents.find((obj) => obj.Key?.endsWith(".csv"));
    continuationToken = NextContinuationToken;
  } while (!outputObject && continuationToken);

  if (!foundAnyObject) {
    throw new Error("No objects found in bucket.");
  }

  if (!outputObject) {
    throw new Error(`No CSV file found in bucket with the prefix "${prefix}".`);
  }

  const { Body } = await s3Client.send(
    new GetObjectCommand({
      Bucket: bucket,
      Key: outputObject.Key,
    }),
  );

  return Body.transformToString();
}
The following class is an excerpt from a file that uses the preceding library functions to set up a SageMaker AI pipeline, execute it, and delete all created resources.
import { retry, wait } from "@aws-doc-sdk-examples/lib/utils/util-timers.js";
import { paginateListEventSourceMappings } from "@aws-sdk/client-lambda";
import { paginateListObjectsV2 } from "@aws-sdk/client-s3";
import {
  attachPolicy,
  configureLambdaSQSEventSource,
  createLambdaExecutionPolicy,
  createLambdaExecutionRole,
  createLambdaFunction,
  createLambdaLayer,
  createS3Bucket,
  createSQSQueue,
  createSagemakerExecutionPolicy,
  createSagemakerPipeline,
  createSagemakerRole,
  getObject,
  startPipelineExecution,
  uploadCSVDataToS3,
  waitForPipelineComplete,
} from "./lib.js";
import { MESSAGES } from "./messages.js";
/**
 * Interactive workflow that creates the resources a SageMaker pipeline needs
 * (IAM roles and policies, a Lambda layer and function, an SQS queue, an S3 bucket),
 * runs the pipeline, prints the first lines of its output, and then offers to
 * delete everything that was created.
 */
export class SageMakerPipelinesWkflw {
  // Resource names used by the workflow. Only the bucket name carries a timestamp.
  names = {
    LAMBDA_EXECUTION_ROLE: "sagemaker-wkflw-lambda-execution-role",
    LAMBDA_EXECUTION_ROLE_POLICY:
      "sagemaker-wkflw-lambda-execution-role-policy",
    LAMBDA_FUNCTION: "sagemaker-wkflw-lambda-function",
    LAMBDA_LAYER: "sagemaker-wkflw-lambda-layer",
    SAGE_MAKER_EXECUTION_ROLE: "sagemaker-wkflw-pipeline-execution-role",
    SAGE_MAKER_EXECUTION_ROLE_POLICY:
      "sagemaker-wkflw-pipeline-execution-role-policy",
    SAGE_MAKER_PIPELINE: "sagemaker-wkflw-pipeline",
    SQS_QUEUE: "sagemaker-wkflw-sqs-queue",
    S3_BUCKET: `sagemaker-wkflw-s3-bucket-${Date.now()}`,
  };

  // Stack of clean up callbacks, one per created resource. Run in reverse order.
  cleanUpFunctions = [];

  /**
   * @param {import("@aws-doc-sdk-examples/lib/prompter.js").Prompter} prompter
   * @param {import("@aws-doc-sdk-examples/lib/logger.js").Logger} logger
   * @param {{ IAM: import("@aws-sdk/client-iam").IAMClient, Lambda: import("@aws-sdk/client-lambda").LambdaClient, SageMaker: import("@aws-sdk/client-sagemaker").SageMakerClient, S3: import("@aws-sdk/client-s3").S3Client, SQS: import("@aws-sdk/client-sqs").SQSClient }} clients
   */
  constructor(prompter, logger, clients) {
    this.prompter = prompter;
    this.logger = logger;
    this.clients = clients;
  }

  /**
   * Run the workflow. Whether it succeeds or fails, the user is asked whether
   * the created resources should be cleaned up. Errors are logged and rethrown.
   */
  async run() {
    try {
      await this.startWorkflow();
    } catch (err) {
      console.error(err);
      throw err;
    } finally {
      this.logger.logSeparator();
      const doCleanUp = await this.prompter.confirm({
        message: "Clean up resources?",
      });
      if (doCleanUp) {
        await this.cleanUp();
      }
    }
  }

  /**
   * Run every registered clean up function, newest first, so resources are
   * removed in the reverse order of their creation.
   */
  async cleanUp() {
    // Each clean up function is retried. `swallowError` is passed so that a failing
    // clean up presumably does not stop the remaining ones from running
    // (see the retry utility for the exact behavior).
    for (let i = this.cleanUpFunctions.length - 1; i >= 0; i--) {
      await retry(
        { intervalInMs: 1000, maxRetries: 60, swallowError: true },
        this.cleanUpFunctions[i],
      );
    }
  }

  /**
   * Create all resources, execute the pipeline, wait for it to finish, and print
   * the first lines of the output. A clean up callback is registered after each
   * resource is created.
   */
  async startWorkflow() {
    this.logger.logSeparator(MESSAGES.greetingHeader);
    await this.logger.log(MESSAGES.greeting);

    this.logger.logSeparator();
    await this.logger.log(
      MESSAGES.creatingRole.replace(
        "${ROLE_NAME}",
        this.names.LAMBDA_EXECUTION_ROLE,
      ),
    );

    // Create an IAM role that will be assumed by the AWS Lambda function. This function
    // is triggered by Amazon SQS messages and calls SageMaker and SageMaker GeoSpatial actions.
    const { arn: lambdaExecutionRoleArn, cleanUp: lambdaExecutionRoleCleanUp } =
      await createLambdaExecutionRole({
        name: this.names.LAMBDA_EXECUTION_ROLE,
        iamClient: this.clients.IAM,
      });
    // Add a clean up step to a stack for every resource created.
    this.cleanUpFunctions.push(lambdaExecutionRoleCleanUp);

    await this.logger.log(
      MESSAGES.roleCreated.replace(
        "${ROLE_NAME}",
        this.names.LAMBDA_EXECUTION_ROLE,
      ),
    );

    this.logger.logSeparator();

    await this.logger.log(
      MESSAGES.creatingRole.replace(
        "${ROLE_NAME}",
        this.names.SAGE_MAKER_EXECUTION_ROLE,
      ),
    );

    // Create an IAM role that will be assumed by the SageMaker pipeline. The pipeline
    // sends messages to an Amazon SQS queue and puts/retrieves Amazon S3 objects.
    const {
      arn: pipelineExecutionRoleArn,
      cleanUp: pipelineExecutionRoleCleanUp,
    } = await createSagemakerRole({
      iamClient: this.clients.IAM,
      name: this.names.SAGE_MAKER_EXECUTION_ROLE,
      wait,
    });
    this.cleanUpFunctions.push(pipelineExecutionRoleCleanUp);

    await this.logger.log(
      MESSAGES.roleCreated.replace(
        "${ROLE_NAME}",
        this.names.SAGE_MAKER_EXECUTION_ROLE,
      ),
    );

    this.logger.logSeparator();

    // Create an IAM policy that allows the AWS Lambda function to invoke SageMaker APIs.
    // The helper returns the policy document as `policyConfig`.
    const {
      arn: lambdaExecutionPolicyArn,
      policyConfig: lambdaPolicy,
      cleanUp: lambdaExecutionPolicyCleanUp,
    } = await createLambdaExecutionPolicy({
      name: this.names.LAMBDA_EXECUTION_ROLE_POLICY,
      s3BucketName: this.names.S3_BUCKET,
      iamClient: this.clients.IAM,
      pipelineExecutionRoleArn,
    });
    this.cleanUpFunctions.push(lambdaExecutionPolicyCleanUp);

    console.log(JSON.stringify(lambdaPolicy, null, 2), "\n");

    await this.logger.log(
      MESSAGES.attachPolicy
        .replace("${POLICY_NAME}", this.names.LAMBDA_EXECUTION_ROLE_POLICY)
        .replace("${ROLE_NAME}", this.names.LAMBDA_EXECUTION_ROLE),
    );

    await this.prompter.checkContinue();

    // Attach the Lambda execution policy to the execution role.
    const { cleanUp: lambdaExecutionRolePolicyCleanUp } = await attachPolicy({
      roleName: this.names.LAMBDA_EXECUTION_ROLE,
      policyArn: lambdaExecutionPolicyArn,
      iamClient: this.clients.IAM,
    });
    this.cleanUpFunctions.push(lambdaExecutionRolePolicyCleanUp);

    await this.logger.log(MESSAGES.policyAttached);

    this.logger.logSeparator();

    // Create Lambda layer for SageMaker packages.
    const { versionArn: layerVersionArn, cleanUp: lambdaLayerCleanUp } =
      await createLambdaLayer({
        name: this.names.LAMBDA_LAYER,
        lambdaClient: this.clients.Lambda,
      });
    this.cleanUpFunctions.push(lambdaLayerCleanUp);

    await this.logger.log(
      MESSAGES.creatingFunction.replace(
        "${FUNCTION_NAME}",
        this.names.LAMBDA_FUNCTION,
      ),
    );

    // Create the Lambda function with the execution role.
    const { arn: lambdaArn, cleanUp: lambdaCleanUp } =
      await createLambdaFunction({
        roleArn: lambdaExecutionRoleArn,
        lambdaClient: this.clients.Lambda,
        name: this.names.LAMBDA_FUNCTION,
        layerVersionArn,
      });
    this.cleanUpFunctions.push(lambdaCleanUp);

    await this.logger.log(
      MESSAGES.functionCreated.replace(
        "${FUNCTION_NAME}",
        this.names.LAMBDA_FUNCTION,
      ),
    );

    this.logger.logSeparator();

    await this.logger.log(
      MESSAGES.creatingSQSQueue.replace("${QUEUE_NAME}", this.names.SQS_QUEUE),
    );

    // Create an SQS queue for the SageMaker pipeline.
    const {
      queueUrl,
      queueArn,
      cleanUp: queueCleanUp,
    } = await createSQSQueue({
      name: this.names.SQS_QUEUE,
      sqsClient: this.clients.SQS,
    });
    this.cleanUpFunctions.push(queueCleanUp);

    await this.logger.log(
      MESSAGES.sqsQueueCreated.replace("${QUEUE_NAME}", this.names.SQS_QUEUE),
    );

    this.logger.logSeparator();

    await this.logger.log(
      MESSAGES.configuringLambdaSQSEventSource
        .replace("${LAMBDA_NAME}", this.names.LAMBDA_FUNCTION)
        .replace("${QUEUE_NAME}", this.names.SQS_QUEUE),
    );

    // Configure the SQS queue as an event source for the Lambda.
    // The paginator is required by the helper to find a mapping that already exists.
    const { cleanUp: lambdaSQSEventSourceCleanUp } =
      await configureLambdaSQSEventSource({
        lambdaArn,
        lambdaName: this.names.LAMBDA_FUNCTION,
        queueArn,
        sqsClient: this.clients.SQS,
        lambdaClient: this.clients.Lambda,
        paginateListEventSourceMappings,
      });
    this.cleanUpFunctions.push(lambdaSQSEventSourceCleanUp);

    await this.logger.log(
      MESSAGES.lambdaSQSEventSourceConfigured
        .replace("${LAMBDA_NAME}", this.names.LAMBDA_FUNCTION)
        .replace("${QUEUE_NAME}", this.names.SQS_QUEUE),
    );

    this.logger.logSeparator();

    // Create an IAM policy that allows the SageMaker pipeline to invoke AWS Lambda
    // and send messages to the Amazon SQS queue.
    // The helper returns the policy document as `policyConfig`.
    const {
      arn: pipelineExecutionPolicyArn,
      policyConfig: sagemakerPolicy,
      cleanUp: pipelineExecutionPolicyCleanUp,
    } = await createSagemakerExecutionPolicy({
      sqsQueueArn: queueArn,
      lambdaArn,
      iamClient: this.clients.IAM,
      name: this.names.SAGE_MAKER_EXECUTION_ROLE_POLICY,
      s3BucketName: this.names.S3_BUCKET,
    });
    this.cleanUpFunctions.push(pipelineExecutionPolicyCleanUp);

    console.log(JSON.stringify(sagemakerPolicy, null, 2));

    await this.logger.log(
      MESSAGES.attachPolicy
        .replace("${POLICY_NAME}", this.names.SAGE_MAKER_EXECUTION_ROLE_POLICY)
        .replace("${ROLE_NAME}", this.names.SAGE_MAKER_EXECUTION_ROLE),
    );

    await this.prompter.checkContinue();

    // Attach the SageMaker execution policy to the execution role.
    const { cleanUp: pipelineExecutionRolePolicyCleanUp } = await attachPolicy({
      roleName: this.names.SAGE_MAKER_EXECUTION_ROLE,
      policyArn: pipelineExecutionPolicyArn,
      iamClient: this.clients.IAM,
    });
    this.cleanUpFunctions.push(pipelineExecutionRolePolicyCleanUp);

    // Wait for the role to be ready. If the role is used immediately,
    // the pipeline will fail.
    await wait(5);

    await this.logger.log(MESSAGES.policyAttached);

    this.logger.logSeparator();

    await this.logger.log(
      MESSAGES.creatingPipeline.replace(
        "${PIPELINE_NAME}",
        this.names.SAGE_MAKER_PIPELINE,
      ),
    );

    // Create the SageMaker pipeline.
    const { cleanUp: pipelineCleanUp } = await createSagemakerPipeline({
      roleArn: pipelineExecutionRoleArn,
      functionArn: lambdaArn,
      sagemakerClient: this.clients.SageMaker,
      name: this.names.SAGE_MAKER_PIPELINE,
    });
    this.cleanUpFunctions.push(pipelineCleanUp);

    await this.logger.log(
      MESSAGES.pipelineCreated.replace(
        "${PIPELINE_NAME}",
        this.names.SAGE_MAKER_PIPELINE,
      ),
    );

    this.logger.logSeparator();

    await this.logger.log(
      MESSAGES.creatingS3Bucket.replace("${BUCKET_NAME}", this.names.S3_BUCKET),
    );

    // Create an S3 bucket for storing inputs and outputs.
    // The paginator is required by the bucket's clean up step to list its objects.
    const { cleanUp: s3BucketCleanUp } = await createS3Bucket({
      name: this.names.S3_BUCKET,
      s3Client: this.clients.S3,
      paginateListObjectsV2,
    });
    this.cleanUpFunctions.push(s3BucketCleanUp);

    await this.logger.log(
      MESSAGES.s3BucketCreated.replace("${BUCKET_NAME}", this.names.S3_BUCKET),
    );

    this.logger.logSeparator();

    await this.logger.log(
      MESSAGES.uploadingInputData.replace(
        "${BUCKET_NAME}",
        this.names.S3_BUCKET,
      ),
    );

    // Upload CSV Lat/Long data to S3.
    await uploadCSVDataToS3({
      bucketName: this.names.S3_BUCKET,
      s3Client: this.clients.S3,
    });

    await this.logger.log(MESSAGES.inputDataUploaded);

    this.logger.logSeparator();

    await this.prompter.checkContinue(MESSAGES.executePipeline);

    // Execute the SageMaker pipeline.
    const { arn: pipelineExecutionArn } = await startPipelineExecution({
      name: this.names.SAGE_MAKER_PIPELINE,
      sagemakerClient: this.clients.SageMaker,
      roleArn: pipelineExecutionRoleArn,
      bucketName: this.names.S3_BUCKET,
      queueUrl,
    });

    // Wait for the pipeline execution to finish.
    await waitForPipelineComplete({
      arn: pipelineExecutionArn,
      sagemakerClient: this.clients.SageMaker,
      wait,
    });

    this.logger.logSeparator();

    await this.logger.log(MESSAGES.outputDelay);

    // The getObject function will throw an error if the output is not
    // found. The retry function will retry a failed function call once
    // every 10 seconds for 2 minutes.
    const output = await retry({ intervalInMs: 10000, maxRetries: 12 }, () =>
      getObject({
        bucket: this.names.S3_BUCKET,
        s3Client: this.clients.S3,
      }),
    );

    this.logger.logSeparator();
    await this.logger.log(MESSAGES.outputDataRetrieved);
    // Print only the first six lines of the output.
    console.log(output.split("\n").slice(0, 6).join("\n"));
  }
}
- For API details, see the following topics in AWS SDK for JavaScript API Reference.