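Reranking search results using a reranker in Amazon SageMaker
A reranking pipeline can rerank search results by assigning each document in the results a relevance score with respect to the search query. The relevance score is calculated by a reranking model.
This tutorial shows you how to rerank search results in self-managed OpenSearch and in Amazon OpenSearch Service. It uses the Hugging Face BAAI/bge-reranker-v2-m3 model hosted on Amazon SageMaker.
Replace the placeholders beginning with the prefix your_ with your own values.
Prerequisite: Deploy the model to Amazon SageMaker
Use the following code to deploy the model to Amazon SageMaker. We recommend using a GPU for better performance.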
import json
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
from sagemaker.serverless import ServerlessInferenceConfig
try:
    role = sagemaker.get_execution_role()
except ValueError:
    # Outside a SageMaker environment, look up the execution role ARN by name.
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
# Hub Model configuration. https://hugging-face.cn/models
hub = {
    'HF_MODEL_ID': 'BAAI/bge-reranker-v2-m3'
}
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    image_uri=get_huggingface_llm_image_uri("huggingface-tei", version="1.2.3"),
    env=hub,
    role=role,
)
# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
)
For more information, see how to deploy this model using Amazon SageMaker.
To test reranking, use the following code:
result = predictor.predict(data={
    "query": "What is the capital city of America?",
    "texts": [
        "Carson City is the capital city of the American state of Nevada.",
        "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
        "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
        "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
    ]
})
print(json.dumps(result, indent=2))
The response contains the reranked results ordered by relevance score:
[
{
"index": 2,
"score": 0.92879725
},
{
"index": 0,
"score": 0.013636836
},
{
"index": 1,
"score": 0.000593021
},
{
"index": 3,
"score": 0.00012148176
}
]
To sort the results by index, use the following code:
print(json.dumps(sorted(result, key=lambda x: x['index']), indent=2))
The sorted results are as follows:
[
{
"index": 0,
"score": 0.013636836
},
{
"index": 1,
"score": 0.000593021
},
{
"index": 2,
"score": 0.92879725
},
{
"index": 3,
"score": 0.00012148176
}
]
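Because the model returns only indexes and scores, it can help to map each score back to the text it refers to. The following is a minimal sketch that assumes the result format shown above; the helper name attach_texts and the shortened sample texts are illustrative and not part of the SageMaker SDK.

```python
def attach_texts(result, texts):
    """Return (score, text) pairs ordered from most to least relevant.

    `result` is a list of {"index": int, "score": float} items, where
    `index` is the position of the document in the original `texts` list.
    """
    ranked = sorted(result, key=lambda item: item["score"], reverse=True)
    return [(item["score"], texts[item["index"]]) for item in ranked]


# Shortened sample texts (illustrative), in the same order as the request.
texts = [
    "Carson City is the capital city of the American state of Nevada.",
    "The Commonwealth of the Northern Mariana Islands ... Its capital is Saipan.",
    "Washington, D.C. ... is the capital of the United States.",
    "Capital punishment (the death penalty) has existed in the United States ...",
]
result = [
    {"index": 2, "score": 0.92879725},
    {"index": 0, "score": 0.013636836},
    {"index": 1, "score": 0.000593021},
    {"index": 3, "score": 0.00012148176},
]

for score, text in attach_texts(result, texts):
    print(f"{score:.6f}  {text}")
```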
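Note the model inference endpoint; you'll use it to create a connector in the next step. You can confirm the inference endpoint URL using the following code: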
region_name = boto3.Session().region_name
endpoint_name = predictor.endpoint_name
endpoint_url = f"https://runtime.sagemaker.{region_name}.amazonaws.com/endpoints/{endpoint_name}/invocations"
print(endpoint_url)
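If the predictor object is no longer available (for example, in a new session), the same URL can be rebuilt from the AWS Region and the endpoint name. This is a small sketch that follows the URL pattern used above; the function name build_invocation_url and the sample values are hypothetical.

```python
def build_invocation_url(region_name, endpoint_name):
    """Build the SageMaker runtime invocation URL for an endpoint.

    Mirrors the f-string used in the tutorial; both arguments must be
    non-empty strings.
    """
    if not region_name or not endpoint_name:
        raise ValueError("Both region_name and endpoint_name are required.")
    return (
        f"https://runtime.sagemaker.{region_name}.amazonaws.com"
        f"/endpoints/{endpoint_name}/invocations"
    )


# Hypothetical values for illustration only.
print(build_invocation_url("us-west-2", "my-reranker-endpoint"))
```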
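Step 1: Create a connector and register the model
To create a connector for the model, send the following request.
If you are using self-managed OpenSearch, provide your AWS credentials: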
POST /_plugins/_ml/connectors/_create
{
"name": "Sagemaker cross-encoder model",
"description": "Test connector for Sagemaker cross-encoder model",
"version": 1,
"protocol": "aws_sigv4",
"credential": {
"access_key": "your_access_key",
"secret_key": "your_secret_key",
"session_token": "your_session_token"
},
"parameters": {
"region": "your_sagemaker_model_region_like_us-west-2",
"service_name": "sagemaker"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "your_sagemaker_model_inference_endpoint_created_in_last_step",
"headers": {
"content-type": "application/json"
},
"pre_process_function": """
def query_text = params.query_text;
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ "query": "' + query_text + '", "texts": ' + textDocsBuilder.toString() + ' }';
return '{"parameters": ' + parameters + '}';
""",
"request_body": """
{
"query": "${parameters.query}",
"texts": ${parameters.texts}
}
""",
"post_process_function": """
if (params.result == null || params.result.length == 0) {
throw new IllegalArgumentException("Post process function input is empty.");
}
def outputs = params.result;
def scores = new Double[outputs.length];
for (int i=0; i<outputs.length; i++) {
def index = new BigDecimal(outputs[i].index.toString()).intValue();
scores[index] = outputs[i].score;
}
def resultBuilder = new StringBuilder('[');
for (int i=0; i<scores.length; i++) {
resultBuilder.append(' {"name": "similarity", "data_type": "FLOAT32", "shape": [1],');
resultBuilder.append('"data": [');
resultBuilder.append(scores[i]);
resultBuilder.append(']}');
if (i<outputs.length - 1) {
resultBuilder.append(',');
}
}
resultBuilder.append(']');
return resultBuilder.toString();
"""
}
]
}
If you are using Amazon OpenSearch Service, you can provide an AWS Identity and Access Management (IAM) role Amazon Resource Name (ARN) that allows access to the SageMaker model inference endpoint:
POST /_plugins/_ml/connectors/_create
{
"name": "Sagemaker cross-encoder model",
"description": "Test connector for Sagemaker cross-encoder model",
"version": 1,
"protocol": "aws_sigv4",
"credential": {
"roleArn": "your_role_arn_which_allows_access_to_sagemaker_model_inference_endpoint"
},
"parameters": {
"region": "your_sagemaker_model_region_like_us-west-2",
"service_name": "sagemaker"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "your_sagemaker_model_inference_endpoint_created_in_last_step",
"headers": {
"content-type": "application/json"
},
"pre_process_function": """
def query_text = params.query_text;
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ "query": "' + query_text + '", "texts": ' + textDocsBuilder.toString() + ' }';
return '{"parameters": ' + parameters + '}';
""",
"request_body": """
{
"query": "${parameters.query}",
"texts": ${parameters.texts}
}
""",
"post_process_function": """
if (params.result == null || params.result.length == 0) {
throw new IllegalArgumentException("Post process function input is empty.");
}
def outputs = params.result;
def scores = new Double[outputs.length];
for (int i=0; i<outputs.length; i++) {
def index = new BigDecimal(outputs[i].index.toString()).intValue();
scores[index] = outputs[i].score;
}
def resultBuilder = new StringBuilder('[');
for (int i=0; i<scores.length; i++) {
resultBuilder.append(' {"name": "similarity", "data_type": "FLOAT32", "shape": [1],');
resultBuilder.append('"data": [');
resultBuilder.append(scores[i]);
resultBuilder.append(']}');
if (i<outputs.length - 1) {
resultBuilder.append(',');
}
}
resultBuilder.append(']');
return resultBuilder.toString();
"""
}
]
}
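For more information, see the AWS documentation, this tutorial, and the AIConnectorHelper notebook.
Use the connector ID from the response to register and deploy the model: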
POST /_plugins/_ml/models/_register?deploy=true
{
"name": "Sagemaker Cross-Encoder model",
"function_name": "remote",
"description": "test rerank model",
"connector_id": "your_connector_id"
}
Note the model ID in the response; you'll use it in the following steps.
Test the model using the Predict API:
POST _plugins/_ml/models/your_model_id/_predict
{
"parameters": {
"query": "What is the capital city of America?",
"texts": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
]
}
}
Alternatively, you can test the model as follows:
POST _plugins/_ml/_predict/text_similarity/your_model_id
{
"query_text": "What is the capital city of America?",
"text_docs": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
]
}
The connector pre_process_function transforms the input into the format required by the parameters shown previously.
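The transformation can be illustrated outside OpenSearch with a short Python sketch. It takes query_text and text_docs, as in the text_similarity request, and produces the parameters wrapper with query and texts. The function name pre_process is only for illustration. One difference from the Painless function in the connector: that function concatenates strings without escaping them, whereas this sketch uses json.dumps, so quotation marks inside a document are escaped.

```python
import json


def pre_process(query_text, text_docs):
    """Wrap the query and documents in the shape the connector expects.

    Returns a JSON string of the form:
    {"parameters": {"query": "...", "texts": ["...", "..."]}}
    """
    parameters = {"query": query_text, "texts": list(text_docs)}
    return json.dumps({"parameters": parameters})


payload = pre_process(
    "What is the capital city of America?",
    ["Carson City is the capital city of the American state of Nevada."],
)
print(payload)
```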
By default, the model output has the following format:
[
{
"index": 2,
"score": 0.92879725
},
{
"index": 0,
"score": 0.013636836
},
{
"index": 1,
"score": 0.000593021
},
{
"index": 3,
"score": 0.00012148176
}
]
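The connector post_process_function converts the model's output into a format that the Rerank processor can interpret and orders the results by index. The adapted format is as follows: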
{
"inference_results": [
{
"output": [
{
"name": "similarity",
"data_type": "FLOAT32",
"shape": [
1
],
"data": [
0.013636836
]
},
{
"name": "similarity",
"data_type": "FLOAT32",
"shape": [
1
],
"data": [
0.000593021
]
},
{
"name": "similarity",
"data_type": "FLOAT32",
"shape": [
1
],
"data": [
0.92879725
]
},
{
"name": "similarity",
"data_type": "FLOAT32",
"shape": [
1
],
"data": [
0.00012148176
]
}
],
"status_code": 200
}
]
}
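The response contains four similarity objects, one for each input document. In each similarity object, the data array contains the relevance score of the corresponding document with respect to the query. The similarity objects are provided in the order of the input documents: the first object corresponds to the first document.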
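The reordering performed by the post_process_function can be sketched in Python as follows. This is an illustrative equivalent of the Painless logic in the connector, not code that OpenSearch runs; the function name post_process is an assumption.

```python
def post_process(outputs):
    """Reorder model scores by document index and wrap them as tensors.

    `outputs` is the raw model response: a list of
    {"index": int, "score": float} items sorted by score.
    """
    if not outputs:
        raise ValueError("Post process function input is empty.")
    # Place each score at the position of the document it belongs to.
    scores = [None] * len(outputs)
    for item in outputs:
        scores[int(item["index"])] = item["score"]
    return [
        {"name": "similarity", "data_type": "FLOAT32", "shape": [1], "data": [score]}
        for score in scores
    ]


model_output = [
    {"index": 2, "score": 0.92879725},
    {"index": 0, "score": 0.013636836},
    {"index": 1, "score": 0.000593021},
    {"index": 3, "score": 0.00012148176},
]
for tensor in post_process(model_output):
    print(tensor["data"][0])
```

The printed scores follow the input document order, so the third value is the score of the third document.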
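Step 2: Configure a reranking pipeline
Follow these steps to configure a reranking pipeline.
Step 2.1: Ingest test data
Send a bulk request to ingest test data: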
POST _bulk
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "Carson City is the capital city of the American state of Nevada." }
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan." }
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district." }
{ "index": { "_index": "my-test-data" } }
{ "passage_text" : "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." }
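When the passages come from a program rather than a console, the bulk body can be generated. The sketch below builds the same newline-delimited JSON as the request above; the function name build_bulk_body is hypothetical, and sending the body to the cluster is left out. The Bulk API expects the body to end with a newline character.

```python
import json


def build_bulk_body(index_name, passages):
    """Build a newline-delimited JSON body for the _bulk endpoint.

    Each passage becomes an action line followed by a document line.
    """
    lines = []
    for passage in passages:
        lines.append(json.dumps({"index": {"_index": index_name}}))
        lines.append(json.dumps({"passage_text": passage}))
    # The bulk body must be terminated by a newline.
    return "\n".join(lines) + "\n"


body = build_bulk_body(
    "my-test-data",
    [
        "Carson City is the capital city of the American state of Nevada.",
        "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
    ],
)
print(body, end="")
```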
Step 2.2: Create a reranking pipeline
Create a reranking pipeline using the cross-encoder model:
PUT /_search/pipeline/rerank_pipeline_sagemaker
{
"description": "Pipeline for reranking with Sagemaker cross-encoder model",
"response_processors": [
{
"rerank": {
"ml_opensearch": {
"model_id": "your_model_id_created_in_step1"
},
"context": {
"document_fields": ["passage_text"]
}
}
}
]
}
If you provide multiple field names in document_fields, the values of all the fields are first concatenated, and then reranking is performed.
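As an illustration, the pipeline body can be generated for any list of fields. This sketch only builds the request body shown above and does not send it; the function name build_rerank_pipeline and the second field name title are hypothetical.

```python
import json


def build_rerank_pipeline(model_id, document_fields):
    """Build the body of a rerank search pipeline for the given fields."""
    if not document_fields:
        raise ValueError("At least one document field is required.")
    return {
        "description": "Pipeline for reranking with Sagemaker cross-encoder model",
        "response_processors": [
            {
                "rerank": {
                    "ml_opensearch": {"model_id": model_id},
                    "context": {"document_fields": list(document_fields)},
                }
            }
        ],
    }


# "title" is a hypothetical second field used only for illustration.
pipeline = build_rerank_pipeline(
    "your_model_id_created_in_step1", ["passage_text", "title"]
)
print(json.dumps(pipeline, indent=2))
```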
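Step 2.3: Test the reranking
To limit the number of returned results, you can specify the size parameter. For example, set "size": 2 to return the top two documents.
First, test the query without using the reranking pipeline: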
POST my-test-data/_search
{
"query": {
"match": {
"passage_text": "What is the capital city of America?"
}
},
"highlight": {
"pre_tags": ["<strong>"],
"post_tags": ["</strong>"],
"fields": {"passage_text": {}}
},
"_source": false,
"fields": ["passage_text"]
}
The first document in the response is "Carson City is the capital city of the American state of Nevada.", which is incorrect:
{
"took": 2,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 4,
"relation": "eq"
},
"max_score": 2.5045562,
"hits": [
{
"_index": "my-test-data",
"_id": "1",
"_score": 2.5045562,
"fields": {
"passage_text": [
"Carson City is the capital city of the American state of Nevada."
]
},
"highlight": {
"passage_text": [
"Carson <strong>City</strong> <strong>is</strong> <strong>the</strong> <strong>capital</strong> <strong>city</strong> <strong>of</strong> <strong>the</strong> American state <strong>of</strong> Nevada."
]
}
},
{
"_index": "my-test-data",
"_id": "2",
"_score": 0.5807494,
"fields": {
"passage_text": [
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
]
},
"highlight": {
"passage_text": [
"<strong>The</strong> Commonwealth <strong>of</strong> <strong>the</strong> Northern Mariana Islands <strong>is</strong> a group <strong>of</strong> islands in <strong>the</strong> Pacific Ocean.",
"Its <strong>capital</strong> <strong>is</strong> Saipan."
]
}
},
{
"_index": "my-test-data",
"_id": "3",
"_score": 0.5261191,
"fields": {
"passage_text": [
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
]
},
"highlight": {
"passage_text": [
"(also known as simply Washington or D.C., and officially as <strong>the</strong> District <strong>of</strong> Columbia) <strong>is</strong> <strong>the</strong> <strong>capital</strong>",
"<strong>of</strong> <strong>the</strong> United States.",
"It <strong>is</strong> a federal district."
]
}
},
{
"_index": "my-test-data",
"_id": "4",
"_score": 0.5083029,
"fields": {
"passage_text": [
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
]
},
"highlight": {
"passage_text": [
"<strong>Capital</strong> punishment (<strong>the</strong> death penalty) has existed in <strong>the</strong> United States since beforethe United States",
"As <strong>of</strong> 2017, <strong>capital</strong> punishment <strong>is</strong> legal in 30 <strong>of</strong> <strong>the</strong> 50 states."
]
}
}
]
}
}
Next, test the query using the reranking pipeline:
POST my-test-data/_search?search_pipeline=rerank_pipeline_sagemaker
{
"query": {
"match": {
"passage_text": "What is the capital city of America?"
}
},
"ext": {
"rerank": {
"query_context": {
"query_text": "What is the capital city of America?"
}
}
},
"highlight": {
"pre_tags": ["<strong>"],
"post_tags": ["</strong>"],
"fields": {"passage_text": {}}
},
"_source": false,
"fields": ["passage_text"]
}
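The first document in the response is "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", which is correct: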
{
"took": 2,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 4,
"relation": "eq"
},
"max_score": 0.92879725,
"hits": [
{
"_index": "my-test-data",
"_id": "3",
"_score": 0.92879725,
"fields": {
"passage_text": [
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
]
},
"highlight": {
"passage_text": [
"(also known as simply Washington or D.C., and officially as <strong>the</strong> District <strong>of</strong> Columbia) <strong>is</strong> <strong>the</strong> <strong>capital</strong>",
"<strong>of</strong> <strong>the</strong> United States.",
"It <strong>is</strong> a federal district."
]
}
},
{
"_index": "my-test-data",
"_id": "1",
"_score": 0.013636836,
"fields": {
"passage_text": [
"Carson City is the capital city of the American state of Nevada."
]
},
"highlight": {
"passage_text": [
"Carson <strong>City</strong> <strong>is</strong> <strong>the</strong> <strong>capital</strong> <strong>city</strong> <strong>of</strong> <strong>the</strong> American state <strong>of</strong> Nevada."
]
}
},
{
"_index": "my-test-data",
"_id": "2",
"_score": 0.000593021,
"fields": {
"passage_text": [
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
]
},
"highlight": {
"passage_text": [
"<strong>The</strong> Commonwealth <strong>of</strong> <strong>the</strong> Northern Mariana Islands <strong>is</strong> a group <strong>of</strong> islands in <strong>the</strong> Pacific Ocean.",
"Its <strong>capital</strong> <strong>is</strong> Saipan."
]
}
},
{
"_index": "my-test-data",
"_id": "4",
"_score": 0.00012148176,
"fields": {
"passage_text": [
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
]
},
"highlight": {
"passage_text": [
"<strong>Capital</strong> punishment (<strong>the</strong> death penalty) has existed in <strong>the</strong> United States since beforethe United States",
"As <strong>of</strong> 2017, <strong>capital</strong> punishment <strong>is</strong> legal in 30 <strong>of</strong> <strong>the</strong> 50 states."
]
}
}
]
},
"profile": {
"shards": []
}
}
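Conceptually, reranking replaces each hit's original score with the model's relevance score and then sorts the hits again. The following sketch reproduces that idea on the document IDs and scores from this tutorial. It is a simplified illustration, not the Rerank processor's actual implementation, and the function name rerank_hits is an assumption.

```python
def rerank_hits(hits, scores):
    """Replace each hit's _score with the model score and sort descending.

    `scores` must be in the same order as `hits`, which matches the
    order in which the documents were sent to the model.
    """
    if len(hits) != len(scores):
        raise ValueError("Expected exactly one score per hit.")
    rescored = [{**hit, "_score": score} for hit, score in zip(hits, scores)]
    return sorted(rescored, key=lambda hit: hit["_score"], reverse=True)


# Hits in their original order, with their original scores.
hits = [
    {"_id": "1", "_score": 2.5045562},
    {"_id": "2", "_score": 0.5807494},
    {"_id": "3", "_score": 0.5261191},
    {"_id": "4", "_score": 0.5083029},
]
# Model relevance scores, in the same order as the hits.
scores = [0.013636836, 0.000593021, 0.92879725, 0.00012148176]

for hit in rerank_hits(hits, scores):
    print(hit["_id"], hit["_score"])
```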
To avoid writing the query twice, use query_text_path instead of query_text, as follows:
POST my-test-data/_search?search_pipeline=rerank_pipeline_sagemaker
{
"query": {
"match": {
"passage_text": "What is the capital city of America?"
}
},
"ext": {
"rerank": {
"query_context": {
"query_text_path": "query.match.passage_text.query"
}
}
},
"highlight": {
"pre_tags": ["<strong>"],
"post_tags": ["</strong>"],
"fields": {"passage_text": {}}
},
"_source": false,
"fields": ["passage_text"]
}
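The value of query_text_path is a dot-separated path to the query text within the search request. The sketch below shows how such a path can be followed through a nested structure. It assumes the match query is written in its expanded form, with the text under a query key, which is what the path query.match.passage_text.query refers to. How OpenSearch resolves the path internally is not shown here, and the function name resolve_path is hypothetical.

```python
def resolve_path(body, path):
    """Follow a dot-separated path through nested dictionaries."""
    node = body
    for key in path.split("."):
        if not isinstance(node, dict) or key not in node:
            raise KeyError(f"Path segment {key!r} not found in {path!r}")
        node = node[key]
    return node


# Expanded form of the match query (an assumption made for this sketch).
request = {
    "query": {
        "match": {
            "passage_text": {"query": "What is the capital city of America?"}
        }
    }
}

print(resolve_path(request, "query.match.passage_text.query"))
```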