批量预测
ML Commons 可以在离线异步模式下,使用部署在外部模型服务器上的模型,对大型数据集执行推理。要使用批量预测 API,您必须提供外部托管模型的 model_id
。目前,Amazon SageMaker、Cohere 和 OpenAI 是唯一支持此 API 的经过验证的外部服务器。
有关此 API 的用户访问信息,请参阅模型访问控制注意事项。
有关外部托管模型的信息,请参阅连接到外部托管模型。
有关如何设置批量推理和连接器蓝图的说明,请参阅以下内容
端点
POST /_plugins/_ml/models/<model_id>/_batch_predict
先决条件
在使用批量预测 API 之前,您需要创建与外部托管模型的连接器。对于每个操作,请指定描述该操作的 action_type
参数
batch_predict
: 运行批量预测操作。batch_predict_status
: 检查批量预测操作状态。cancel_batch_predict
: 取消批量预测操作。
例如,要创建与 OpenAI text-embedding-ada-002
模型的连接器,请发送以下请求。cancel_batch_predict
操作是可选的,并支持取消在 OpenAI 上运行的批量作业
POST /_plugins/_ml/connectors/_create
{
"name": "OpenAI Embedding model",
"description": "OpenAI embedding model for testing offline batch",
"version": "1",
"protocol": "http",
"parameters": {
"model": "text-embedding-ada-002",
"input_file_id": "<your input file id in OpenAI>",
"endpoint": "/v1/embeddings"
},
"credential": {
"openAI_key": "<your openAI key>"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://api.openai.com/v1/embeddings",
"headers": {
"Authorization": "Bearer ${credential.openAI_key}"
},
"request_body": "{ \"input\": ${parameters.input}, \"model\": \"${parameters.model}\" }",
"pre_process_function": "connector.pre_process.openai.embedding",
"post_process_function": "connector.post_process.openai.embedding"
},
{
"action_type": "batch_predict",
"method": "POST",
"url": "https://api.openai.com/v1/batches",
"headers": {
"Authorization": "Bearer ${credential.openAI_key}"
},
"request_body": "{ \"input_file_id\": \"${parameters.input_file_id}\", \"endpoint\": \"${parameters.endpoint}\", \"completion_window\": \"24h\" }"
},
{
"action_type": "batch_predict_status",
"method": "GET",
"url": "https://api.openai.com/v1/batches/${parameters.id}",
"headers": {
"Authorization": "Bearer ${credential.openAI_key}"
}
},
{
"action_type": "cancel_batch_predict",
"method": "POST",
"url": "https://api.openai.com/v1/batches/${parameters.id}/cancel",
"headers": {
"Authorization": "Bearer ${credential.openAI_key}"
}
}
]
}
响应包含您将在后续步骤中使用的连接器 ID
{
"connector_id": "XU5UiokBpXT9icfOM0vt"
}
接下来,注册一个外部托管模型,并提供所创建连接器的连接器 ID
POST /_plugins/_ml/models/_register?deploy=true
{
"name": "OpenAI model for realtime embedding and offline batch inference",
"function_name": "remote",
"description": "OpenAI text embedding model",
"connector_id": "XU5UiokBpXT9icfOM0vt"
}
响应包含注册操作的任务 ID
{
"task_id": "rMormY8B8aiZvtEZIO_j",
"status": "CREATED",
"model_id": "lyjxwZABNrAVdFa9zrcZ"
}
要检查操作状态,请将任务 ID 提供给 任务 API。注册完成后,任务的 state
将变为 COMPLETED
。
请求示例
完成先决条件步骤后,您可以调用批量预测 API。批量预测请求中的参数将覆盖连接器中定义的参数
POST /_plugins/_ml/models/lyjxwZABNrAVdFa9zrcZ/_batch_predict
{
"parameters": {
"model": "text-embedding-3-large"
}
}
示例响应
响应包含批量预测操作的任务 ID
{
"task_id": "KYZSv5EBqL2d0mFvs80C",
"status": "CREATED"
}
要检查批量预测作业的状态,请将任务 ID 提供给 任务 API。您可以在任务的 remote_job
字段中找到作业详细信息。预测完成后,任务的 state
将变为 COMPLETED
。
请求示例
GET /_plugins/_ml/tasks/KYZSv5EBqL2d0mFvs80C
示例响应
响应在 remote_job
字段中包含批量预测操作的详细信息
{
"model_id": "JYZRv5EBqL2d0mFvKs1E",
"task_type": "BATCH_PREDICTION",
"function_name": "REMOTE",
"state": "RUNNING",
"input_type": "REMOTE",
"worker_node": [
"Ee5OCIq0RAy05hqQsNI1rg"
],
"create_time": 1725491751455,
"last_update_time": 1725491751455,
"is_async": false,
"remote_job": {
"cancelled_at": null,
"metadata": null,
"request_counts": {
"total": 3,
"completed": 3,
"failed": 0
},
"input_file_id": "file-XXXXXXXXXXXX",
"output_file_id": "file-XXXXXXXXXXXXX",
"error_file_id": null,
"created_at": 1725491753,
"in_progress_at": 1725491753,
"expired_at": null,
"finalizing_at": 1725491757,
"completed_at": null,
"endpoint": "/v1/embeddings",
"expires_at": 1725578153,
"cancelling_at": null,
"completion_window": "24h",
"id": "batch_XXXXXXXXXXXXXXX",
"failed_at": null,
"errors": null,
"object": "batch",
"status": "in_progress"
}
}
有关结果中每个字段的定义,请参阅 OpenAI 批量 API。批量推理完成后,您可以通过调用 OpenAI 文件 API 并提供响应的 id
字段中指定的文件名来下载输出。
取消批量预测作业
您还可以使用批量预测请求返回的任务 ID 来取消在远程平台上运行的批量预测操作。要添加此功能,请在创建连接器时,在连接器配置中将 action_type
设置为 cancel_batch_predict
。
请求示例
POST /_plugins/_ml/tasks/KYZSv5EBqL2d0mFvs80C/_cancel_batch
示例响应
{
"status": "OK"
}