Link Search Menu Expand Document Documentation Menu

批量预测

ML Commons 可以在离线异步模式下,使用部署在外部模型服务器上的模型,对大型数据集执行推理。要使用批量预测 API,您必须提供外部托管模型的 model_id。目前,Amazon SageMaker、Cohere 和 OpenAI 是唯一支持此 API 的经过验证的外部服务器。

有关此 API 的用户访问信息,请参阅模型访问控制注意事项

有关外部托管模型的信息,请参阅连接到外部托管模型

有关如何设置批量推理和连接器蓝图的说明,请参阅以下内容

端点

POST /_plugins/_ml/models/<model_id>/_batch_predict

先决条件

在使用批量预测 API 之前,您需要创建与外部托管模型的连接器。对于每个操作,请指定描述该操作的 action_type 参数

  • batch_predict: 运行批量预测操作。
  • batch_predict_status: 检查批量预测操作状态。
  • cancel_batch_predict: 取消批量预测操作。

例如,要创建与 OpenAI text-embedding-ada-002 模型的连接器,请发送以下请求。cancel_batch_predict 操作是可选的,并支持取消在 OpenAI 上运行的批量作业

POST /_plugins/_ml/connectors/_create
{
  "name": "OpenAI Embedding model",
  "description": "OpenAI embedding model for testing offline batch",
  "version": "1",
  "protocol": "http",
  "parameters": {
    "model": "text-embedding-ada-002",
    "input_file_id": "<your input file id in OpenAI>",
    "endpoint": "/v1/embeddings"
  },
  "credential": {
    "openAI_key": "<your openAI key>"
  },
  "actions": [
    {
      "action_type": "predict",
      "method": "POST",
      "url": "https://api.openai.com/v1/embeddings",
      "headers": {
        "Authorization": "Bearer ${credential.openAI_key}"
      },
      "request_body": "{ \"input\": ${parameters.input}, \"model\": \"${parameters.model}\" }",
      "pre_process_function": "connector.pre_process.openai.embedding",
      "post_process_function": "connector.post_process.openai.embedding"
    },
    {
      "action_type": "batch_predict",
      "method": "POST",
      "url": "https://api.openai.com/v1/batches",
      "headers": {
        "Authorization": "Bearer ${credential.openAI_key}"
      },
      "request_body": "{ \"input_file_id\": \"${parameters.input_file_id}\", \"endpoint\": \"${parameters.endpoint}\", \"completion_window\": \"24h\" }"
    },
    {
      "action_type": "batch_predict_status",
      "method": "GET",
      "url": "https://api.openai.com/v1/batches/${parameters.id}",
      "headers": {
        "Authorization": "Bearer ${credential.openAI_key}"
      }
    },
    {
      "action_type": "cancel_batch_predict",
      "method": "POST",
      "url": "https://api.openai.com/v1/batches/${parameters.id}/cancel",
      "headers": {
        "Authorization": "Bearer ${credential.openAI_key}"
      }
    }
  ]
}

响应包含您将在后续步骤中使用的连接器 ID

{
  "connector_id": "XU5UiokBpXT9icfOM0vt"
}

接下来,注册一个外部托管模型,并提供所创建连接器的连接器 ID

POST /_plugins/_ml/models/_register?deploy=true
{
    "name": "OpenAI model for realtime embedding and offline batch inference",
    "function_name": "remote",
    "description": "OpenAI text embedding model",
    "connector_id": "XU5UiokBpXT9icfOM0vt"
}

响应包含注册操作的任务 ID

{
  "task_id": "rMormY8B8aiZvtEZIO_j",
  "status": "CREATED",
  "model_id": "lyjxwZABNrAVdFa9zrcZ"
}

要检查操作状态,请将任务 ID 提供给 任务 API。注册完成后,任务的 state 将变为 COMPLETED

请求示例

完成先决条件步骤后,您可以调用批量预测 API。批量预测请求中的参数将覆盖连接器中定义的参数

POST /_plugins/_ml/models/lyjxwZABNrAVdFa9zrcZ/_batch_predict
{
  "parameters": {
    "model": "text-embedding-3-large"
  }
}

示例响应

响应包含批量预测操作的任务 ID

{
  "task_id": "KYZSv5EBqL2d0mFvs80C",
  "status": "CREATED"
}

要检查批量预测作业的状态,请将任务 ID 提供给 任务 API。您可以在任务的 remote_job 字段中找到作业详细信息。预测完成后,任务的 state 将变为 COMPLETED

请求示例

GET /_plugins/_ml/tasks/KYZSv5EBqL2d0mFvs80C

示例响应

响应在 remote_job 字段中包含批量预测操作的详细信息

{
  "model_id": "JYZRv5EBqL2d0mFvKs1E",
  "task_type": "BATCH_PREDICTION",
  "function_name": "REMOTE",
  "state": "RUNNING",
  "input_type": "REMOTE",
  "worker_node": [
    "Ee5OCIq0RAy05hqQsNI1rg"
  ],
  "create_time": 1725491751455,
  "last_update_time": 1725491751455,
  "is_async": false,
  "remote_job": {
    "cancelled_at": null,
    "metadata": null,
    "request_counts": {
      "total": 3,
      "completed": 3,
      "failed": 0
    },
    "input_file_id": "file-XXXXXXXXXXXX",
    "output_file_id": "file-XXXXXXXXXXXXX",
    "error_file_id": null,
    "created_at": 1725491753,
    "in_progress_at": 1725491753,
    "expired_at": null,
    "finalizing_at": 1725491757,
    "completed_at": null,
    "endpoint": "/v1/embeddings",
    "expires_at": 1725578153,
    "cancelling_at": null,
    "completion_window": "24h",
    "id": "batch_XXXXXXXXXXXXXXX",
    "failed_at": null,
    "errors": null,
    "object": "batch",
    "status": "in_progress"
  }
}

有关结果中每个字段的定义,请参阅 OpenAI 批量 API。批量推理完成后,您可以通过调用 OpenAI 文件 API 并提供响应的 id 字段中指定的文件名来下载输出。

取消批量预测作业

您还可以使用批量预测请求返回的任务 ID 来取消在远程平台上运行的批量预测操作。要添加此功能,请在创建连接器时,在连接器配置中将 action_type 设置为 cancel_batch_predict

请求示例

POST /_plugins/_ml/tasks/KYZSv5EBqL2d0mFvs80C/_cancel_batch

示例响应

{
  "status": "OK"
}
剩余 350 字符

有问题?

想贡献?