Link Search Menu Expand Document Documentation Menu

使用 Amazon SageMaker 中的交叉编码器对搜索结果进行重排序

一个重排序管道可以对搜索结果进行重排序,针对搜索查询为搜索结果中的每个文档提供相关性分数。相关性分数由交叉编码器模型计算。

本教程展示了如何在重排序管道中使用 Hugging Face ms-marco-MiniLM-L-6-v2 模型。

将以 your_ 为前缀的占位符替换为您自己的值。

先决条件

在开始之前,请在 Amazon SageMaker 上部署模型。为了获得更好的性能,请使用 GPU。

运行以下代码以在 Amazon SageMaker 上部署模型

import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel
sess = sagemaker.Session()
role = sagemaker.get_execution_role()

hub = {
    'HF_MODEL_ID':'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'HF_TASK':'text-classification'
}
huggingface_model = HuggingFaceModel(
    transformers_version='4.37.0',
    pytorch_version='2.1.0',
    py_version='py310',
    env=hub,
    role=role, 
)
predictor = huggingface_model.deploy(
    initial_instance_count=1, # number of instances
    instance_type='ml.m5.xlarge' # ec2 instance type
)

请记下模型推理端点;您将在下一步中使用它来创建连接器。

步骤 1:创建连接器并注册模型

要为模型创建连接器,请发送以下请求。

如果您使用自管理的 OpenSearch,请提供您的 AWS 凭证:

POST /_plugins/_ml/connectors/_create
{
  "name": "Sagemaker cross-encoder model",
  "description": "Test connector for Sagemaker cross-encoder model",
  "version": 1,
  "protocol": "aws_sigv4",
  "credential": {
    "access_key": "your_access_key",
    "secret_key": "your_secret_key",
    "session_token": "your_session_token"
  },
  "parameters": {
    "region": "your_sagemaker_model_region_like_us-west-2",
    "service_name": "sagemaker"
  },
  "actions": [
    {
      "action_type": "predict",
      "method": "POST",
      "url": "your_sagemaker_model_inference_endpoint_created_in_last_step",
      "headers": {
        "content-type": "application/json"
      },
      "request_body": "{ \"inputs\": ${parameters.inputs} }",
      "pre_process_function": "\n    String escape(def input) { \n       if (input.contains(\"\\\\\")) {\n        input = input.replace(\"\\\\\", \"\\\\\\\\\");\n      }\n      if (input.contains(\"\\\"\")) {\n        input = input.replace(\"\\\"\", \"\\\\\\\"\");\n      }\n      if (input.contains('\r')) {\n        input = input = input.replace('\r', '\\\\r');\n      }\n      if (input.contains(\"\\\\t\")) {\n        input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n      }\n      if (input.contains('\n')) {\n        input = input.replace('\n', '\\\\n');\n      }\n      if (input.contains('\b')) {\n        input = input.replace('\b', '\\\\b');\n      }\n      if (input.contains('\f')) {\n        input = input.replace('\f', '\\\\f');\n      }\n      return input;\n    }\n\n   String query = params.query_text;\n   StringBuilder builder = new StringBuilder('[');\n    \n    for (int i=0; i<params.text_docs.length; i ++) {\n      builder.append('{\"text\":\"');\n      builder.append(escape(query));\n      builder.append('\", \"text_pair\":\"');\n      builder.append(escape(params.text_docs[i]));\n      builder.append('\"}');\n      if (i<params.text_docs.length - 1) {\n        builder.append(',');\n      }\n    }\n    builder.append(']');\n    \n    def parameters = '{ \"inputs\": ' + builder + ' }';\n    return  '{\"parameters\": ' + parameters + '}';\n     ",
      "post_process_function": "\n      \n      def dataType = \"FLOAT32\";\n      \n      \n      if (params.result == null)\n      {\n          return 'no result generated';\n          //return params.response;\n      }\n      def outputs = params.result;\n      \n      \n      def resultBuilder = new StringBuilder('[ ');\n      for (int i=0; i<outputs.length; i++) {\n        resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n        //resultBuilder.append('{\"name\": \"similarity\"}');\n        \n        resultBuilder.append('\"data\": [');\n        resultBuilder.append(outputs[i].score);\n        resultBuilder.append(']}');\n        if (i<outputs.length - 1) {\n          resultBuilder.append(',');\n        }\n      }\n      resultBuilder.append(']');\n      \n      return resultBuilder.toString();\n    "
    }
  ]
}

如果您正在使用 Amazon OpenSearch 服务,您可以提供一个允许访问 SageMaker 模型推理端点的 AWS Identity and Access Management (IAM) 角色的 Amazon 资源名称 (ARN):

POST /_plugins/_ml/connectors/_create
{
  "name": "Sagemaker cross-encoder model",
  "description": "Test connector for Sagemaker cross-encoder model",
  "version": 1,
  "protocol": "aws_sigv4",
  "credential": {
    "roleArn": "your_role_arn_which_allows_access_to_sagemaker_model_inference_endpoint"
  },
  "parameters": {
    "region": "your_sagemaker_model_region_like_us-west-2",
    "service_name": "sagemaker"
  },
  "actions": [
    {
      "action_type": "predict",
      "method": "POST",
      "url": "your_sagemaker_model_inference_endpoint_created_in_last_step",
      "headers": {
        "content-type": "application/json"
      },
      "request_body": "{ \"inputs\": ${parameters.inputs} }",
      "pre_process_function": "\n    String escape(def input) { \n       if (input.contains(\"\\\\\")) {\n        input = input.replace(\"\\\\\", \"\\\\\\\\\");\n      }\n      if (input.contains(\"\\\"\")) {\n        input = input.replace(\"\\\"\", \"\\\\\\\"\");\n      }\n      if (input.contains('\r')) {\n        input = input = input.replace('\r', '\\\\r');\n      }\n      if (input.contains(\"\\\\t\")) {\n        input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n      }\n      if (input.contains('\n')) {\n        input = input.replace('\n', '\\\\n');\n      }\n      if (input.contains('\b')) {\n        input = input.replace('\b', '\\\\b');\n      }\n      if (input.contains('\f')) {\n        input = input.replace('\f', '\\\\f');\n      }\n      return input;\n    }\n\n   String query = params.query_text;\n   StringBuilder builder = new StringBuilder('[');\n    \n    for (int i=0; i<params.text_docs.length; i ++) {\n      builder.append('{\"text\":\"');\n      builder.append(escape(query));\n      builder.append('\", \"text_pair\":\"');\n      builder.append(escape(params.text_docs[i]));\n      builder.append('\"}');\n      if (i<params.text_docs.length - 1) {\n        builder.append(',');\n      }\n    }\n    builder.append(']');\n    \n    def parameters = '{ \"inputs\": ' + builder + ' }';\n    return  '{\"parameters\": ' + parameters + '}';\n     ",
      "post_process_function": "\n      \n      def dataType = \"FLOAT32\";\n      \n      \n      if (params.result == null)\n      {\n          return 'no result generated';\n          //return params.response;\n      }\n      def outputs = params.result;\n      \n      \n      def resultBuilder = new StringBuilder('[ ');\n      for (int i=0; i<outputs.length; i++) {\n        resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n        //resultBuilder.append('{\"name\": \"similarity\"}');\n        \n        resultBuilder.append('\"data\": [');\n        resultBuilder.append(outputs[i].score);\n        resultBuilder.append(']}');\n        if (i<outputs.length - 1) {\n          resultBuilder.append(',');\n        }\n      }\n      resultBuilder.append(']');\n      \n      return resultBuilder.toString();\n    "
    }
  ]
}

更多信息请参阅 AWS 文档、本教程以及 AIConnectorHelper Notebook。

使用响应中的连接器 ID 注册和部署模型:

POST /_plugins/_ml/models/_register?deploy=true
{
    "name": "Sagemaker Cross-Encoder model",
    "function_name": "remote",
    "description": "test rerank model",
    "connector_id": "your_connector_id"
}

记下响应中的模型 ID;您将在后续步骤中使用它。

要测试模型,请调用 Predict API

POST _plugins/_ml/models/your_model_id/_predict
{
  "parameters": {
    "inputs": [
      {
        "text": "I like you",
        "text_pair": "I hate you"
      },
      {
        "text": "I like you",
        "text_pair": "I love you"
      }
    ]
  }
}

inputs 数组中的每一项都包含一个 query_text(对应 text 字段)和一个 text_docs 字符串(对应 text_pair 字段),由 ` . ` 分隔。

或者,您可以按如下方式测试模型:

POST _plugins/_ml/_predict/text_similarity/your_model_id
{
  "query_text": "I like you",
  "text_docs": ["I hate you", "I love you"]
}

连接器的 pre_process_function 会将输入转换为上一个 Predict API 请求中显示的 inputs 参数所需的格式。

默认情况下,SageMaker 模型输出采用以下格式

[
  {
    "label": "LABEL_0",
    "score": 0.054037678986787796
  },
  {
    "label": "LABEL_0",
    "score": 0.5877784490585327
  }
]

连接器的 post_process_function 会将模型输出转换为可由重排序处理器解释的以下格式:

{
  "inference_results": [
    {
      "output": [
        {
          "name": "similarity",
          "data_type": "FLOAT32",
          "shape": [
            1
          ],
          "data": [
            0.054037678986787796
          ]
        },
        {
          "name": "similarity",
          "data_type": "FLOAT32",
          "shape": [
            1
          ],
          "data": [
            0.5877784490585327
          ]
        }
      ],
      "status_code": 200
    }
  ]
}

响应包含两个 similarity 输出。对于每个 similarity 输出,data 数组都包含针对查询的每个文档的相关性分数。这些 similarity 输出按照输入文档的顺序提供:第一个相似性结果对应第一个文档。

步骤 2:配置重排序管道

按照以下步骤配置重排序管道。

步骤 2.1:摄取测试数据

发送批量请求以摄取测试数据:

POST _bulk
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "Carson City is the capital city of the American state of Nevada." }
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan." }
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district." }
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." }

步骤 2.2:创建重排序管道

使用 MS MARCO 交叉编码器模型创建重排序管道

PUT /_search/pipeline/rerank_pipeline_sagemaker
{
    "description": "Pipeline for reranking with Sagemaker cross-encoder model",
    "response_processors": [
        {
            "rerank": {
                "ml_opensearch": {
                    "model_id": "your_model_id_created_in_step1"
                },
                "context": {
                    "document_fields": ["passage_text"]
                }
            }
        }
    ]
}

如果您在 document_fields 中提供多个字段名称,则所有字段的值将首先被连接起来,然后执行重排序。

步骤 2.3:测试重排序

要限制返回结果的数量,您可以指定 size 参数。例如,设置 "size": 4 以返回前四个文档

GET my-test-data/_search?search_pipeline=rerank_pipeline_sagemaker
{
  "query": {
    "match_all": {}
  },
  "size": 4,
  "ext": {
    "rerank": {
      "query_context": {
         "query_text": "What is the capital of the United States?"
      }
    }
  }
}

响应包含四个最相关的文档

{
  "took": 3,
  "timed_out": false,
  "_shards": {
    "total": 1,
    "successful": 1,
    "skipped": 0,
    "failed": 0
  },
  "hits": {
    "total": {
      "value": 4,
      "relation": "eq"
    },
    "max_score": 0.9997217,
    "hits": [
      {
        "_index": "my-test-data",
        "_id": "U0xye5AB9ZeWZdmDjWZn",
        "_score": 0.9997217,
        "_source": {
          "passage_text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
        }
      },
      {
        "_index": "my-test-data",
        "_id": "VExye5AB9ZeWZdmDjWZn",
        "_score": 0.55655104,
        "_source": {
          "passage_text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
        }
      },
      {
        "_index": "my-test-data",
        "_id": "UUxye5AB9ZeWZdmDjWZn",
        "_score": 0.115356825,
        "_source": {
          "passage_text": "Carson City is the capital city of the American state of Nevada."
        }
      },
      {
        "_index": "my-test-data",
        "_id": "Ukxye5AB9ZeWZdmDjWZn",
        "_score": 0.00021142483,
        "_source": {
          "passage_text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
        }
      }
    ]
  },
  "profile": {
    "shards": []
  }
}

要将这些结果与未重排序的结果进行比较,请在没有重排序管道的情况下运行搜索

GET my-test-data/_search
{
  "query": {
    "match_all": {}
  },
  "ext": {
    "rerank": {
      "query_context": {
         "query_text": "What is the capital of the United States?"
      }
    }
  }
}

响应中的第一个文档与卡森城相关,而卡森城并非美国首都

{
  "took": 1,
  "timed_out": false,
  "_shards": {
    "total": 1,
    "successful": 1,
    "skipped": 0,
    "failed": 0
  },
  "hits": {
    "total": {
      "value": 4,
      "relation": "eq"
    },
    "max_score": 1,
    "hits": [
      {
        "_index": "my-test-data",
        "_id": "UUxye5AB9ZeWZdmDjWZn",
        "_score": 1,
        "_source": {
          "passage_text": "Carson City is the capital city of the American state of Nevada."
        }
      },
      {
        "_index": "my-test-data",
        "_id": "Ukxye5AB9ZeWZdmDjWZn",
        "_score": 1,
        "_source": {
          "passage_text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
        }
      },
      {
        "_index": "my-test-data",
        "_id": "U0xye5AB9ZeWZdmDjWZn",
        "_score": 1,
        "_source": {
          "passage_text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
        }
      },
      {
        "_index": "my-test-data",
        "_id": "VExye5AB9ZeWZdmDjWZn",
        "_score": 1,
        "_source": {
          "passage_text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
        }
      }
    ]
  }
}