Link Search Menu Expand Document Documentation Menu

使用外部托管的交叉编码器模型按字段重新排序

2.18 版引入

在本教程中,您将学习如何使用托管在 Amazon SageMaker 上的交叉编码器模型重新排序搜索结果并提高搜索相关性。

要重新排序文档,您将配置一个搜索管道,该管道在查询时处理搜索结果。管道会拦截搜索结果并将其传递给 ml_inference 搜索响应处理器,该处理器会调用交叉编码器模型。该模型会生成用于通过 by_field 重新排序匹配文档的分数。

先决条件:在 Amazon SageMaker 上部署模型

运行以下代码以在 Amazon SageMaker 上部署模型。对于此示例,您将使用托管在 Amazon SageMaker 上的 ms-marco-MiniLM-L-6-v2 Hugging Face 交叉编码器模型。我们建议使用 GPU 以获得更好的性能。

import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel

sess = sagemaker.Session()
role = sagemaker.get_execution_role()

hub = {
    'HF_MODEL_ID':'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'HF_TASK':'text-classification'
}
huggingface_model = HuggingFaceModel(
    transformers_version='4.37.0',
    pytorch_version='2.1.0',
    py_version='py310',
    env=hub,
    role=role, 
)
predictor = huggingface_model.deploy(
    initial_instance_count=1, # number of instances
    instance_type='ml.m5.xlarge' # ec2 instance type
)

部署模型后,您可以通过访问 AWS 管理控制台中的 Amazon SageMaker 控制台,并在左侧选项卡中选择 Inference > Endpoints 来找到模型终端节点。记下创建的模型 URL;您将在后续步骤中使用它来创建连接器。

运行带重排的搜索

要运行带重排的搜索,请遵循以下步骤:

  1. 创建连接器.
  2. 注册模型.
  3. 将文档摄入到索引中.
  4. 创建搜索管道.
  5. 使用重排进行搜索.

步骤 1:创建连接器

通过在 actions.url 参数中提供模型 URL 来创建与交叉编码器模型的连接器:

POST /_plugins/_ml/connectors/_create
{
  "name": "SageMaker cross-encoder model",
  "description": "Test connector for SageMaker cross-encoder hosted model",
  "version": 1,
  "protocol": "aws_sigv4",
  "credential": {
		"access_key": "<YOUR_ACCESS_KEY>",
		"secret_key": "<YOUR_SECRET_KEY>",
		"session_token": "<YOUR_SESSION_TOKEN>"
  },
  "parameters": {
    "region": "<REGION>",
    "service_name": "sagemaker"
  },
  "actions": [
    {
      "action_type": "predict",
      "method": "POST",
      "url": "<YOUR_SAGEMAKER_ENDPOINT_URL>",
      "headers": {
        "content-type": "application/json"
      },
      "request_body": "{ \"inputs\": { \"text\": \"${parameters.text}\", \"text_pair\": \"${parameters.text_pair}\" }}"
    }
  ]
}

记下响应中包含的连接器 ID;您将在下一步中使用它。

步骤 2:注册模型

要注册模型,请在 connector_id 参数中提供连接器 ID:

POST /_plugins/_ml/models/_register
{
  "name": "Cross encoder model",
  "version": "1.0.1",
  "function_name": "remote",
  "description": "Using a SageMaker endpoint to apply a cross encoder model",
  "connector_id": "<YOUR_CONNECTOR_ID>"
} 

步骤 3:将文档摄入到索引中

创建一个索引并摄入包含纽约市行政区事实的示例文档:

POST /nyc_areas/_bulk
{ "index": { "_id": 1 } }
{ "borough": "Queens", "area_name": "Astoria", "description": "Astoria is a neighborhood in the western part of Queens, New York City, known for its diverse community and vibrant cultural scene.", "population": 93000, "facts": "Astoria is home to many artists and has a large Greek-American community. The area also boasts some of the best Mediterranean food in NYC." } 
{ "index": { "_id": 2 } }
{ "borough": "Queens", "area_name": "Flushing", "description": "Flushing is a neighborhood in the northern part of Queens, famous for its Asian-American population and bustling business district.", "population": 227000, "facts": "Flushing is one of the most ethnically diverse neighborhoods in NYC, with a large Chinese and Korean population. It is also home to the USTA Billie Jean King National Tennis Center." } 
{ "index": { "_id": 3 } }
{ "borough": "Brooklyn", "area_name": "Williamsburg", "description": "Williamsburg is a trendy neighborhood in Brooklyn known for its hipster culture, vibrant art scene, and excellent restaurants.", "population": 150000, "facts": "Williamsburg is a hotspot for young professionals and artists. The neighborhood has seen rapid gentrification over the past two decades." } 
{ "index": { "_id": 4 } }
{ "borough": "Manhattan", "area_name": "Harlem", "description": "Harlem is a historic neighborhood in Upper Manhattan, known for its significant African-American cultural heritage.", "population": 116000, "facts": "Harlem was the birthplace of the Harlem Renaissance, a cultural movement that celebrated Black culture through art, music, and literature." } 
{ "index": { "_id": 5 } }
{ "borough": "The Bronx", "area_name": "Riverdale", "description": "Riverdale is a suburban-like neighborhood in the Bronx, known for its leafy streets and affluent residential areas.", "population": 48000, "facts": "Riverdale is one of the most affluent areas in the Bronx, with beautiful parks, historic homes, and excellent schools." } 
{ "index": { "_id": 6 } }
{ "borough": "Staten Island", "area_name": "St. George", "description": "St. George is the main commercial and cultural center of Staten Island, offering stunning views of Lower Manhattan.", "population": 15000, "facts": "St. George is home to the Staten Island Ferry terminal and is a gateway to Staten Island, offering stunning views of the Statue of Liberty and Ellis Island." }

步骤 4:创建搜索管道

接下来,创建用于重排的搜索管道。在搜索管道配置中,input_mapoutput_map 定义了如何为交叉编码器模型准备输入数据以及如何解释模型的输出以进行重排:

  • input_map 指定了搜索文档和查询中的哪些字段应作为模型输入:
    • text 字段映射到索引文档中的 facts 字段。它提供了模型将分析的特定于文档的内容。
    • text_pair 字段从搜索请求中动态检索搜索查询文本 (multi_match.query)。

    text(文档 facts)和 text_pair(搜索 query)的组合允许交叉编码器模型比较文档与查询的相关性,同时考虑它们的语义关系。

  • output_map 字段指定了模型输出如何映射到响应中的字段:
    • 响应中的 rank_score 字段将存储模型的相关性分数,该分数将用于执行重排。

当使用 by_field 重排类型时,rank_score 字段将包含与 _score 字段相同的分数。要从搜索结果中删除 rank_score 字段,请将 remove_target_field 设置为 true。通过将 keep_previous_score 设置为 true,原始的 BM25 分数(重排前)会包含在内,用于调试目的。这使您能够将原始分数与重排后的分数进行比较,以评估搜索相关性的改进。

要创建搜索管道,请发送以下请求:

PUT /_search/pipeline/my_pipeline
{
  "response_processors": [
    {
      "ml_inference": {
        "tag": "ml_inference",
        "description": "This processor runs ml inference during search response",
        "model_id": "<model_id_from_step_3>",
        "function_name": "REMOTE",
        "input_map": [
          {
            "text": "facts",
            "text_pair":"$._request.query.multi_match.query"
          }
        ],
        "output_map": [
          {
            "rank_score": "$.score"
          }
        ],
        "full_response_path": false,
        "model_config": {},
        "ignore_missing": false,
        "ignore_failure": false,
        "one_to_one": true
      },
       
      "rerank": {
        "by_field": {
          "target_field": "rank_score",
          "remove_target_field": true,
          "keep_previous_score" : true
          }
      }
    
    }
  ]
}

步骤 5:使用重排进行搜索

使用以下请求搜索索引文档并使用交叉编码器模型对其进行重排。该请求检索在 descriptionfacts 字段中包含任何指定术语的文档。然后使用这些术语来比较和重排匹配的文档:

POST /nyc_areas/_search?search_pipeline=my_pipeline
{
  "query": {
    "multi_match": {
      "query": "artists art creative community",
      "fields": ["description", "facts"]
    }
  }
}

在响应中,previous_score 字段包含文档的 BM25 分数,如果未应用管道,它将获得该分数。请注意,虽然 BM25 将“Astoria”排在最高位,但交叉编码器模型优先考虑了“Harlem”,因为它匹配了更多的搜索词:

{
  "took": 4,
  "timed_out": false,
  "_shards": {
    "total": 1,
    "successful": 1,
    "skipped": 0,
    "failed": 0
  },
  "hits": {
    "total": {
      "value": 3,
      "relation": "eq"
    },
    "max_score": 0.03418137,
    "hits": [
      {
        "_index": "nyc_areas",
        "_id": "4",
        "_score": 0.03418137,
        "_source": {
          "area_name": "Harlem",
          "description": "Harlem is a historic neighborhood in Upper Manhattan, known for its significant African-American cultural heritage.",
          "previous_score": 1.6489418,
          "borough": "Manhattan",
          "facts": "Harlem was the birthplace of the Harlem Renaissance, a cultural movement that celebrated Black culture through art, music, and literature.",
          "population": 116000
        }
      },
      {
        "_index": "nyc_areas",
        "_id": "1",
        "_score": 0.0090838,
        "_source": {
          "area_name": "Astoria",
          "description": "Astoria is a neighborhood in the western part of Queens, New York City, known for its diverse community and vibrant cultural scene.",
          "previous_score": 2.519608,
          "borough": "Queens",
          "facts": "Astoria is home to many artists and has a large Greek-American community. The area also boasts some of the best Mediterranean food in NYC.",
          "population": 93000
        }
      },
      {
        "_index": "nyc_areas",
        "_id": "3",
        "_score": 0.0032599436,
        "_source": {
          "area_name": "Williamsburg",
          "description": "Williamsburg is a trendy neighborhood in Brooklyn known for its hipster culture, vibrant art scene, and excellent restaurants.",
          "previous_score": 1.5632852,
          "borough": "Brooklyn",
          "facts": "Williamsburg is a hotspot for young professionals and artists. The neighborhood has seen rapid gentrification over the past two decades.",
          "population": 150000
        }
      }
    ]
  },
  "profile": {
    "shards": []
  }
}