💾 监控 FastAPI 模型端点#
在本教程中,你将学习如何监控 FastAPI 推理端点的预测,并将模型预测记录在 Argilla 数据集中。它将引导你完成 4 个基本的 MLOps 步骤
💾 加载你想使用的模型。
🔄 将模型输出转换为 Argilla 格式。
💻 创建一个 FastAPI 端点。
🤖 添加中间件以自动化记录到 Argilla 的过程
简介#
模型通常通过 HTTP API 端点部署,客户端调用该端点以获取模型的预测。借助 FastAPI 和 Argilla,你可以轻松监控这些预测,并将它们记录到 Argilla 数据集中。由于其以人为中心的 UX,Argilla 数据集可以被你组织的任何团队成员舒适地查看和探索。此外,Argilla 还提供自动计算的指标,这两者都有助于你跟踪你的预测器并及早发现潜在问题。
FastAPI 和 Argilla 允许你部署和监控任何你喜欢的模型,但在本教程中,我们将重点关注 NLP 领域中最常见的两个框架:spaCy 和 transformers。让我们开始吧!
运行 Argilla#
在本教程中,你将需要运行一个 Argilla 服务器。部署和运行 Argilla 主要有两种选择
在 Hugging Face Spaces 上部署 Argilla:如果你想使用外部笔记本(例如 Google Colab)运行教程,并且你在 Hugging Face 上有一个帐户,你可以通过几次点击在 Spaces 上部署 Argilla
有关配置部署的详细信息,请查看 官方 Hugging Face Hub 指南。
使用 Argilla 的快速入门 Docker 镜像启动 Argilla:如果你想在 本地机器上运行 Argilla,这是推荐的选项。请注意,此选项仅允许你在本地运行本教程,而不能使用外部笔记本服务。
有关部署选项的更多信息,请查看文档的“部署”部分。
提示
本教程是一个 Jupyter Notebook。有两种运行它的选项
使用此页面顶部的“在 Colab 中打开”按钮。此选项允许你直接在 Google Colab 上运行笔记本。不要忘记将运行时类型更改为 GPU,以加快模型训练和推理速度。
通过点击页面顶部的“查看源代码”链接下载 .ipynb 文件。此选项允许你下载笔记本并在本地机器或你选择的 Jupyter 笔记本工具上运行它。
设置#
要完成本教程,你将需要使用 pip
安装 Argilla 客户端和一些第三方库
[ ]:
%pip install argilla fastapi uvicorn[standard] spacy transformers[torch] -qqq
让我们导入 Argilla 模块以进行数据读取和写入
[ ]:
import argilla as rg
如果你正在使用 Docker 快速入门镜像或 Hugging Face Spaces 运行 Argilla,则需要使用 URL
和 API_KEY
初始化 Argilla 客户端
[ ]:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
rg.init(
api_url="https://#:6900",
api_key="admin.apikey"
)
如果你正在运行私有的 Hugging Face Space,你还需要按如下方式设置 HF_TOKEN
[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"
# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# rg.init(
# api_url="https://[your-owner-name]-[your_space_name].hf.space",
# api_key="admin.apikey",
# extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )
最后,让我们包含我们需要的导入
[74]:
from fastapi import FastAPI
import spacy
from transformers import pipeline
from typing import List
# for adding logging to API endpoints
from argilla.monitoring.asgi import (
ArgillaLogHTTPMiddleware,
text_classification_mapper,
token_classification_mapper,
)
# Instantiate our FastAPI app
app = FastAPI()
启用遥测#
我们从你与我们教程的互动中获得宝贵的见解。为了改进我们自身,以便为你提供最合适的内容,使用以下代码行将帮助我们了解本教程是否有效地为你服务。虽然这是完全匿名的,但如果你愿意,可以选择跳过此步骤。有关更多信息,请查看遥测页面。
[ ]:
try:
from argilla.utils.telemetry import tutorial_running
tutorial_running()
except ImportError:
print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")
1. 加载模型#
作为第一步,让我们加载我们的模型。对于 spaCy,我们需要先下载模型,然后才能使用它实例化 spaCy 管道。这里我们使用小型英语模型 en_core_web_sm
,但你可以在他们的 hub 上选择任何可用的模型。
[ ]:
!python -m spacy download en_core_web_sm
transformers 的“文本分类”管道会为你下载模型,默认情况下它将使用 distilbert-base-uncased-finetuned-sst-2-english
模型。但是你可以在他们的 hub 上使用任何兼容的模型实例化管道。
[76]:
spacy_pipeline = spacy.load("en_core_web_sm")
transformers_pipeline = pipeline("text-classification", return_all_scores=True)
No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://hugging-face.cn/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.
/usr/local/lib/python3.8/dist-packages/transformers/pipelines/text_classification.py:104: UserWarning: `return_all_scores` is now deprecated, if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.
warnings.warn(
有关将 transformers
库与 Argilla 一起使用的更多信息,请查看教程 如何标注你的数据并微调 🤗 情感分类器
模型输出#
让我们在这个例子中尝试 transformers 的管道
[78]:
batch = ["I really like argilla!"]
predictions = transformers_pipeline(batch)
print(predictions)
[[{'label': 'NEGATIVE', 'score': 0.0029897126369178295}, {'label': 'POSITIVE', 'score': 0.9970102310180664}]]
看起来 predictions
是一个列表,其中包含两个元素列表
第一个字典包含
NEGATIVE
情感标签及其分数。第二个字典包含相同的数据,但用于
POSITIVE
情感。
2. 将输出转换为 Argilla 格式#
要将输出记录到 Argilla,我们应该提供一个字典列表,每个字典包含两个键
labels
:值是一个字符串列表,每个字符串都是情感的标签。scores
:值是一个浮点数列表,每个浮点数都是情感的概率。
[79]:
argilla_format = [
{
"labels": [p["label"] for p in prediction],
"scores": [p["score"] for p in prediction],
}
for prediction in predictions
]
argilla_format
[79]:
[{'labels': ['NEGATIVE', 'POSITIVE'],
'scores': [0.0029897126369178295, 0.9970102310180664]}]
3. 创建预测端点#
[80]:
# prediction endpoint using transformers pipeline
@app.post("/sentiment/")
def predict_transformers(batch: List[str]):
predictions = transformers_pipeline(batch)
return [
{
"labels": [p["label"] for p in prediction],
"scores": [p["score"] for p in prediction],
}
for prediction in predictions
]
4. 将 Argilla 日志记录中间件添加到应用程序#
[82]:
def text2records(batch: List[str], outputs: List[dict]):
return [
text_classification_mapper(data, prediction)
for data, prediction in zip(batch, outputs)
]
app.add_middleware(
ArgillaLogHTTPMiddleware,
api_endpoint="/transformers/", # the endpoint that will be logged
dataset="monitoring_transformers", # your dataset name
records_mapper=text2records, # your post-process func to adapt service inputs and outputs into an Argilla record
)
5. 使用 spaCy 的 NER 端点#
我们将添加一个自定义映射器,以将 spaCy 的输出转换为 TokenClassificationRecord
格式
[83]:
def token2records(batch: List[str], outputs: List[dict]):
return [
token_classification_mapper(data, prediction)
for data, prediction in zip(batch, outputs)
]
app.add_middleware(
ArgillaLogHTTPMiddleware,
api_endpoint="/spacy/",
dataset="monitoring_spacy",
records_mapper=token2records,
)
# prediction endpoint using spacy pipeline
@app.post("/ner/")
def predict_spacy(batch: List[str]):
predictions = []
for text in batch:
doc = spacy_pipeline(text) # spaCy Doc creation
# Entity annotations
entities = [
{"label": ent.label_, "start": ent.start_char, "end": ent.end_char}
for ent in doc.ents
]
prediction = {
"text": text,
"entities": entities,
}
predictions.append(prediction)
return predictions
现在我们可以添加方法来检查服务器是否已启动并运行
[85]:
@app.get("/")
def root():
return {"message": "alive"}
启动和测试 API#
要启动应用程序,请将整个代码复制到一个名为 main.py
的文件中(你可以在附录中找到完整的文件内容)。
创建文件后,你可以运行以下命令。我们添加 nohup 命令是为了防止你在 Colab 上运行,否则你可以删除它
[ ]:
!nohup uvicorn main:app
测试我们的端点和预测日志记录#
如果我们现在开始调用我们的 API,我们的模型输入和输出应该被记录到它们对应的 Argilla Datasets 中
[8]:
import requests
response = requests.post(
"https://#:8000/sentiment/",
json=["I like Argilla", "I hated data labelling but now I don't"]
)
response.content
[8]:
b'[{"labels":["NEGATIVE","POSITIVE"],"scores":[0.8717259168624878,0.128274068236351]},{"labels":["NEGATIVE","POSITIVE"],"scores":[0.9916356801986694,0.008364332839846611]}]'
如果一切顺利,你应该在你的 Argilla monitoring_transformers
数据集上看到两个新记录。
总结#
在本教程中,我们学习了如何自动将模型输入和输出记录到 Argilla 中。这可以用于持续和透明地监控 HTTP 推理端点。
附录:main.py
完整代码#
[31]:
import argilla
from fastapi import FastAPI
from typing import List
import spacy
from transformers import pipeline
from argilla.monitoring.asgi import (
ArgillaLogHTTPMiddleware,
text_classification_mapper,
token_classification_mapper,
)
spacy_pipeline = spacy.load("en_core_web_sm")
transformers_pipeline = pipeline("text-classification", return_all_scores=True)
app = FastAPI()
# prediction endpoint using transformers pipeline
@app.post("/sentiment/")
def predict_transformers(batch: List[str]):
predictions = transformers_pipeline(batch)
return [
{
"labels": [p["label"] for p in prediction],
"scores": [p["score"] for p in prediction],
}
for prediction in predictions
]
def text2records(batch: List[str], outputs: List[dict]):
return [
text_classification_mapper(data, prediction)
for data, prediction in zip(batch, outputs)
]
app.add_middleware(
ArgillaLogHTTPMiddleware,
api_endpoint="/transformers/", # the endpoint that will be logged
dataset="monitoring_transformers", # your dataset name
records_mapper=text2records, # your post-process func to adapt service inputs and outputs into an Argilla record
)
def token2records(batch: List[str], outputs: List[dict]):
return [
token_classification_mapper(data, prediction)
for data, prediction in zip(batch, outputs)
]
# prediction endpoint using spacy pipeline
@app.post("/ner/")
def predict_spacy(batch: List[str]):
predictions = []
for text in batch:
doc = spacy_pipeline(text) # spaCy Doc creation
# Entity annotations
entities = [
{"label": ent.label_, "start": ent.start_char, "end": ent.end_char}
for ent in doc.ents
]
prediction = {
"text": text,
"entities": entities,
}
predictions.append(prediction)
return predictions
app.add_middleware(
ArgillaLogHTTPMiddleware,
api_endpoint="/ner/",
dataset="monitoring_spacy",
records_mapper=token2records,
)
app.add_middleware(
ArgillaLogHTTPMiddleware,
api_endpoint="/sentiment/",
dataset="monitoring_transformers",
records_mapper=text2records,
)
@app.get("/")
def root():
return {"message": "alive"}
argilla.init(
api_url="https://#:6900",
api_key="admin.apikey"
)