Reranking search results using a cross-encoder in Amazon SageMaker
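A reranking pipeline reorders search results by assigning each document in the results a relevance score with respect to the search query. The relevance scores are computed by a cross-encoder model.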
This tutorial shows you how to use the Hugging Face ms-marco-MiniLM-L-6-v2 model in a reranking pipeline.
Replace the placeholders beginning with the prefix your_ with your own values.
Prerequisite
Before you start, deploy the model on Amazon SageMaker. For better performance, use a GPU instance.
Run the following code to deploy the model on Amazon SageMaker:
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel

sess = sagemaker.Session()
role = sagemaker.get_execution_role()

# Hugging Face Hub configuration: the model to load and the pipeline task to serve
hub = {
    'HF_MODEL_ID': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'HF_TASK': 'text-classification'
}

huggingface_model = HuggingFaceModel(
    transformers_version='4.37.0',
    pytorch_version='2.1.0',
    py_version='py310',
    env=hub,
    role=role,
)

predictor = huggingface_model.deploy(
    initial_instance_count=1,  # number of instances
    instance_type='ml.m5.xlarge'  # EC2 instance type (CPU); a GPU type such as ml.g5.xlarge performs better
)

print(predictor.endpoint_name)  # name of the deployed endpoint
Note the model inference endpoint; you'll use it to create a connector in the next step.
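The connector's url field expects the full HTTPS URL of the endpoint, not just its name. The following sketch assembles that URL from the endpoint name (available as predictor.endpoint_name after deployment) and the AWS Region, assuming the standard SageMaker Runtime InvokeEndpoint URL pattern for commercial Regions. The helper name build_invocation_url is hypothetical.

```python
def build_invocation_url(region: str, endpoint_name: str) -> str:
    """Return the SageMaker Runtime InvokeEndpoint URL for an endpoint.

    Assumes the commercial-partition host name pattern; other partitions
    (for example, China Regions) may use a different domain.
    """
    if not region or not endpoint_name:
        raise ValueError("region and endpoint_name must be non-empty")
    return (
        f"https://runtime.sagemaker.{region}.amazonaws.com"
        f"/endpoints/{endpoint_name}/invocations"
    )


if __name__ == "__main__":
    # In practice, pass predictor.endpoint_name from the deployment step.
    print(build_invocation_url("us-west-2", "my-cross-encoder-endpoint"))
```

Use the printed value in place of your_sagemaker_model_inference_endpoint_created_in_last_step.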
Step 1: Create a connector and register the model
To create a connector for the model, send the following request.
If you are using self-managed OpenSearch, supply your AWS credentials:
POST /_plugins/_ml/connectors/_create
{
"name": "Sagemaker cross-encoder model",
"description": "Test connector for Sagemaker cross-encoder model",
"version": 1,
"protocol": "aws_sigv4",
"credential": {
"access_key": "your_access_key",
"secret_key": "your_secret_key",
"session_token": "your_session_token"
},
"parameters": {
"region": "your_sagemaker_model_region_like_us-west-2",
"service_name": "sagemaker"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "your_sagemaker_model_inference_endpoint_created_in_last_step",
"headers": {
"content-type": "application/json"
},
"request_body": "{ \"inputs\": ${parameters.inputs} }",
"pre_process_function": "\n String escape(def input) { \n if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n if (input.contains('\r')) {\n input = input = input.replace('\r', '\\\\r');\n }\n if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n return input;\n }\n\n String query = params.query_text;\n StringBuilder builder = new StringBuilder('[');\n \n for (int i=0; i<params.text_docs.length; i ++) {\n builder.append('{\"text\":\"');\n builder.append(escape(query));\n builder.append('\", \"text_pair\":\"');\n builder.append(escape(params.text_docs[i]));\n builder.append('\"}');\n if (i<params.text_docs.length - 1) {\n builder.append(',');\n }\n }\n builder.append(']');\n \n def parameters = '{ \"inputs\": ' + builder + ' }';\n return '{\"parameters\": ' + parameters + '}';\n ",
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n \n def resultBuilder = new StringBuilder('[ ');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
}
]
}
If you are using Amazon OpenSearch Service, you can instead provide an AWS Identity and Access Management (IAM) role Amazon Resource Name (ARN) that allows access to the SageMaker model inference endpoint:
POST /_plugins/_ml/connectors/_create
{
"name": "Sagemaker cross-encoder model",
"description": "Test connector for Sagemaker cross-encoder model",
"version": 1,
"protocol": "aws_sigv4",
"credential": {
"roleArn": "your_role_arn_which_allows_access_to_sagemaker_model_inference_endpoint"
},
"parameters": {
"region": "your_sagemaker_model_region_like_us-west-2",
"service_name": "sagemaker"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "your_sagemaker_model_inference_endpoint_created_in_last_step",
"headers": {
"content-type": "application/json"
},
"request_body": "{ \"inputs\": ${parameters.inputs} }",
"pre_process_function": "\n String escape(def input) { \n if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n if (input.contains('\r')) {\n input = input = input.replace('\r', '\\\\r');\n }\n if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n return input;\n }\n\n String query = params.query_text;\n StringBuilder builder = new StringBuilder('[');\n \n for (int i=0; i<params.text_docs.length; i ++) {\n builder.append('{\"text\":\"');\n builder.append(escape(query));\n builder.append('\", \"text_pair\":\"');\n builder.append(escape(params.text_docs[i]));\n builder.append('\"}');\n if (i<params.text_docs.length - 1) {\n builder.append(',');\n }\n }\n builder.append(']');\n \n def parameters = '{ \"inputs\": ' + builder + ' }';\n return '{\"parameters\": ' + parameters + '}';\n ",
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n \n def resultBuilder = new StringBuilder('[ ');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
}
]
}
For more information, see the AWS documentation, this tutorial, and the AIConnectorHelper notebook.
Use the connector ID from the response to register and deploy the model:
POST /_plugins/_ml/models/_register?deploy=true
{
"name": "Sagemaker Cross-Encoder model",
"function_name": "remote",
"description": "test rerank model",
"connector_id": "your_connector_id"
}
Note the model ID in the response; you'll use it in the following steps.
To test the model, call the Predict API:
POST _plugins/_ml/models/your_model_id/_predict
{
"parameters": {
"inputs": [
{
"text": "I like you",
"text_pair": "I hate you"
},
{
"text": "I like you",
"text_pair": "I love you"
}
]
}
}
Each item in the inputs array is one query-document pair: text contains the query, and text_pair contains the document to be scored against it.
Alternatively, you can test the model as follows:
POST _plugins/_ml/_predict/text_similarity/your_model_id
{
"query_text": "I like you",
"text_docs": ["I hate you", "I love you"]
}
The connector's pre_process_function transforms this input into the format required by the inputs parameter shown in the previous Predict API request.
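To make the transformation concrete, here is a Python sketch that mirrors what the Painless pre_process_function does: it pairs query_text with every entry in text_docs as text/text_pair objects and wraps the list under parameters.inputs. It relies on json.dumps for escaping instead of the script's hand-written escape function, so treat it as an approximation of the connector's behavior rather than its exact output. The name build_inputs is hypothetical.

```python
import json
from typing import List


def build_inputs(query_text: str, text_docs: List[str]) -> str:
    """Pair the query with each document, as the connector's pre_process_function does."""
    pairs = [{"text": query_text, "text_pair": doc} for doc in text_docs]
    # json.dumps escapes quotes, backslashes, and control characters.
    return json.dumps({"parameters": {"inputs": pairs}})


if __name__ == "__main__":
    print(build_inputs("I like you", ["I hate you", "I love you"]))
```

The printed JSON has the same shape as the body of the first Predict API request.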
By default, the SageMaker model output has the following format:
[
{
"label": "LABEL_0",
"score": 0.054037678986787796
},
{
"label": "LABEL_0",
"score": 0.5877784490585327
}
]
The connector's post_process_function transforms the model output into the following format, which the rerank processor can interpret:
{
"inference_results": [
{
"output": [
{
"name": "similarity",
"data_type": "FLOAT32",
"shape": [
1
],
"data": [
0.054037678986787796
]
},
{
"name": "similarity",
"data_type": "FLOAT32",
"shape": [
1
],
"data": [
0.5877784490585327
]
}
],
"status_code": 200
}
]
}
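The response contains two similarity outputs, one per document. The data array of each similarity output holds that document's relevance score with respect to the query. The similarity outputs are returned in the same order as the input documents: the first similarity result corresponds to the first document.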
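The following Python sketch mirrors the post_process_function: it turns each label/score item returned by SageMaker into a similarity tensor whose data array holds that single score, preserving input order. The name to_similarity_outputs is hypothetical; the real conversion runs as Painless inside the connector.

```python
from typing import Dict, List


def to_similarity_outputs(sagemaker_result: List[Dict]) -> List[Dict]:
    """Convert SageMaker text-classification output into similarity tensors."""
    outputs = []
    for item in sagemaker_result:  # order matches the input documents
        outputs.append({
            "name": "similarity",
            "data_type": "FLOAT32",
            "shape": [1],
            "data": [item["score"]],
        })
    return outputs


if __name__ == "__main__":
    raw = [
        {"label": "LABEL_0", "score": 0.054037678986787796},
        {"label": "LABEL_0", "score": 0.5877784490585327},
    ]
    docs = ["I hate you", "I love you"]
    # Because order is preserved, zip() pairs each document with its score.
    for doc, out in zip(docs, to_similarity_outputs(raw)):
        print(doc, out["data"][0])
```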
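Step 2: Configure a reranking pipeline
Follow these steps to configure a reranking pipeline.
Step 2.1: Ingest test data
Send a bulk request to ingest test data: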
POST _bulk
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "Carson City is the capital city of the American state of Nevada." }
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan." }
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district." }
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." }
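If you prefer to generate the bulk body programmatically, the following minimal Python sketch emits one action line and one source line per passage and ends with the trailing newline that the _bulk API requires. The helper name build_bulk_body is hypothetical; the index and field names match the request above.

```python
import json
from typing import List


def build_bulk_body(index: str, passages: List[str]) -> str:
    """Build an NDJSON body for POST _bulk that indexes each passage as a document."""
    lines = []
    for text in passages:
        lines.append(json.dumps({"index": {"_index": index}}))
        lines.append(json.dumps({"passage_text": text}))
    # The bulk API requires the body to end with a newline.
    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    body = build_bulk_body("my-test-data", [
        "Carson City is the capital city of the American state of Nevada.",
    ])
    print(body, end="")
```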
Step 2.2: Create a reranking pipeline
Create a reranking pipeline with the MS MARCO cross-encoder model:
PUT /_search/pipeline/rerank_pipeline_sagemaker
{
"description": "Pipeline for reranking with Sagemaker cross-encoder model",
"response_processors": [
{
"rerank": {
"ml_opensearch": {
"model_id": "your_model_id_created_in_step1"
},
"context": {
"document_fields": ["passage_text"]
}
}
}
]
}
If you provide multiple field names in document_fields, the values of all fields are first concatenated, and then reranking is performed.
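Conceptually, multi-field concatenation works like the sketch below, which builds one string per document from the listed top-level fields before that string is scored. Joining with a single space and skipping missing fields are assumptions made for illustration, not confirmed processor behavior; check the rerank processor documentation for the exact separator. The helper document_text is hypothetical.

```python
from typing import Dict, List


def document_text(source: Dict[str, object], document_fields: List[str]) -> str:
    """Concatenate the listed fields of a document's _source into one string.

    Assumption: values are joined with a single space and missing fields are skipped.
    """
    parts = [str(source[field]) for field in document_fields if field in source]
    return " ".join(parts)


if __name__ == "__main__":
    doc = {"title": "Washington, D.C.", "passage_text": "It is a federal district."}
    print(document_text(doc, ["title", "passage_text"]))
```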
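Step 2.3: Test the reranking
To limit the number of returned results, you can specify the size parameter. For example, set "size": 4 to return the top four documents: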
GET my-test-data/_search?search_pipeline=rerank_pipeline_sagemaker
{
"query": {
"match_all": {}
},
"size": 4,
"ext": {
"rerank": {
"query_context": {
"query_text": "What is the capital of the United States?"
}
}
}
}
The response contains the four most relevant documents:
{
"took": 3,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 4,
"relation": "eq"
},
"max_score": 0.9997217,
"hits": [
{
"_index": "my-test-data",
"_id": "U0xye5AB9ZeWZdmDjWZn",
"_score": 0.9997217,
"_source": {
"passage_text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
}
},
{
"_index": "my-test-data",
"_id": "VExye5AB9ZeWZdmDjWZn",
"_score": 0.55655104,
"_source": {
"passage_text": "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
}
},
{
"_index": "my-test-data",
"_id": "UUxye5AB9ZeWZdmDjWZn",
"_score": 0.115356825,
"_source": {
"passage_text": "Carson City is the capital city of the American state of Nevada."
}
},
{
"_index": "my-test-data",
"_id": "Ukxye5AB9ZeWZdmDjWZn",
"_score": 0.00021142483,
"_source": {
"passage_text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
}
}
]
},
"profile": {
"shards": []
}
}
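The ordering above follows from the cross-encoder scores: each hit's _score is the model's relevance score, and the hits are sorted in descending order of that score. The following Python sketch reproduces this step on the retrieved hits, using the scores from this response. The function rerank_hits and the short label field are illustrative only and are not part of the rerank processor's API.

```python
from typing import Dict, List


def rerank_hits(hits: List[Dict], scores: List[float]) -> List[Dict]:
    """Assign one relevance score per hit (same order) and sort by it, highest first."""
    if len(hits) != len(scores):
        raise ValueError("expected exactly one score per hit")
    rescored = [dict(hit, _score=score) for hit, score in zip(hits, scores)]
    return sorted(rescored, key=lambda hit: hit["_score"], reverse=True)


if __name__ == "__main__":
    # Hits in their original match_all order, paired with the cross-encoder
    # scores taken from the reranked response above.
    hits = [
        {"_id": "UUxye5AB9ZeWZdmDjWZn", "label": "Carson City"},
        {"_id": "Ukxye5AB9ZeWZdmDjWZn", "label": "Northern Mariana Islands"},
        {"_id": "U0xye5AB9ZeWZdmDjWZn", "label": "Washington, D.C."},
        {"_id": "VExye5AB9ZeWZdmDjWZn", "label": "Capital punishment"},
    ]
    scores = [0.115356825, 0.00021142483, 0.9997217, 0.55655104]
    for hit in rerank_hits(hits, scores):
        print(hit["_score"], hit["label"])
```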
To compare these results with results that are not reranked, run the same search without the reranking pipeline:
GET my-test-data/_search
{
"query": {
"match_all": {}
},
"ext": {
"rerank": {
"query_context": {
"query_text": "What is the capital of the United States?"
}
}
}
}
The first document in the response pertains to Carson City, which is not the capital of the United States:
{
"took": 1,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 4,
"relation": "eq"
},
"max_score": 1,
"hits": [
{
"_index": "my-test-data",
"_id": "UUxye5AB9ZeWZdmDjWZn",
"_score": 1,
"_source": {
"passage_text": "Carson City is the capital city of the American state of Nevada."
}
},
{
"_index": "my-test-data",
"_id": "Ukxye5AB9ZeWZdmDjWZn",
"_score": 1,
"_source": {
"passage_text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
}
},
{
"_index": "my-test-data",
"_id": "U0xye5AB9ZeWZdmDjWZn",
"_score": 1,
"_source": {
"passage_text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
}
},
{
"_index": "my-test-data",
"_id": "VExye5AB9ZeWZdmDjWZn",
"_score": 1,
"_source": {
"passage_text": "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
}
}
]
}
}
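To quantify how much reranking changed the order, you can compare the _id sequences of the two responses. A minimal Python sketch, assuming each response has already been parsed into a dictionary (for example, with json.loads); the helpers hit_ids and rank_changes are hypothetical, and the document IDs in the usage example are shortened for readability.

```python
from typing import Dict, List


def hit_ids(response: Dict) -> List[str]:
    """Return hit IDs in the order they appear in a search response."""
    return [hit["_id"] for hit in response["hits"]["hits"]]


def rank_changes(baseline: Dict, reranked: Dict) -> Dict[str, int]:
    """Map each document ID to how many positions it moved up (negative means down)."""
    before = {doc_id: pos for pos, doc_id in enumerate(hit_ids(baseline))}
    return {
        doc_id: before[doc_id] - pos
        for pos, doc_id in enumerate(hit_ids(reranked))
        if doc_id in before
    }


if __name__ == "__main__":
    # Shortened IDs: UUx = Carson City, Ukx = Northern Mariana Islands,
    # U0x = Washington, D.C., VEx = capital punishment.
    baseline = {"hits": {"hits": [{"_id": i} for i in ["UUx", "Ukx", "U0x", "VEx"]]}}
    reranked = {"hits": {"hits": [{"_id": i} for i in ["U0x", "VEx", "UUx", "Ukx"]]}}
    print(rank_changes(baseline, reranked))
```

For these two responses, the Washington, D.C. document moves up two positions to the top, while the Carson City document drops two positions.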