🔫 使用 SetFit 进行零样本和少样本分类#
在本教程中,您将学习如何使用 Sentence Transformer 嵌入和 SetFit 的零样本和少样本能力,以显着加快数据标注速度。本教程将引导您完成以下步骤
💾 使用 sentence transformers 生成具有银行客户请求的数据集的嵌入。
🔫 使用 SetFit 的零样本分类器并上传其预测以及嵌入。
🏷 利用相似性搜索和预先标注的示例,高效地标注批量语义相关、高影响力的示例。
🦾 训练一个少样本 SetFit 模型,以提高零样本模型的结果。
简介#
在本教程中,我们将利用嵌入的强大功能来提高数据标注(和管理)的效率。结合 SetFit 的零样本和少样本能力,这种方法将大大减少使用您自己的数据获得高质量模型的时间。
开始吧!
运行 Argilla#
在本教程中,您需要运行 Argilla 服务器。部署和运行 Argilla 主要有两种选择
在 Hugging Face Spaces 上部署 Argilla:如果您想使用外部 notebook(例如,Google Colab)运行教程,并且您在 Hugging Face 上有一个帐户,您只需点击几下即可在 Spaces 上部署 Argilla
有关配置部署的详细信息,请查看 Hugging Face Hub 官方指南。
使用 Argilla 的快速入门 Docker 镜像启动 Argilla:如果您想在 本地机器上运行 Argilla,这是推荐的选项。请注意,此选项仅允许您在本地运行教程,而不能与外部 notebook 服务一起运行。
有关部署选项的更多信息,请查看文档的部署部分。
提示
本教程是一个 Jupyter Notebook。有两种运行方式
使用此页面顶部的“在 Colab 中打开”按钮。此选项允许您直接在 Google Colab 上运行 notebook。不要忘记将运行时类型更改为 GPU,以加快模型训练和推理速度。
单击页面顶部的“查看源代码”链接下载 .ipynb 文件。此选项允许您下载 notebook 并在本地机器或您选择的 Jupyter notebook 工具上运行它。
设置#
在本教程中,您需要 Argilla 的 Python 客户端和一些可以通过 pip
安装的第三方库
[ ]:
%pip install argilla datasets==2.8.0 sentence-transformers==2.2.2 setfit==0.6.0 plotly==4.1.0 -qqq
让我们导入 Argilla 模块以进行数据读取和写入
[1]:
import argilla as rg
如果您使用 Docker 快速入门镜像或 Hugging Face Spaces 运行 Argilla,您需要使用 URL
和 API_KEY
初始化 Argilla 客户端
[2]:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
rg.init(
api_url="https://#:6900",
api_key="admin.apikey"
)
如果您正在运行私有的 Hugging Face Space,您还需要设置 HF_TOKEN,如下所示
[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"
# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# rg.init(
# api_url="https://[your-owner-name]-[your_space_name].hf.space",
# api_key="admin.apikey",
# extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )
让我们添加所需的导入
[ ]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CosineSimilarityLoss
from datasets import load_dataset
from setfit import get_templated_dataset
from setfit import SetFitModel, SetFitTrainer
启用遥测#
我们从您与教程的互动中获得宝贵的见解。为了改进自身,为您提供最合适的内容,使用以下代码行将帮助我们了解本教程是否有效地为您服务。虽然这是完全匿名的,但如果您愿意,可以选择跳过此步骤。有关更多信息,请查看 遥测 页面。
[ ]:
try:
from argilla.utils.telemetry import tutorial_running
tutorial_running()
except ImportError:
print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")
💾 嵌入您的数据集#
下面的代码将从 Hub 加载 banking customer requests dataset,编码 text
字段,并创建 vectors
字段,该字段将仅包含一个键 (mini-lm-sentence-transformers
)。为了从头开始标注数据集,它还将删除包含原始意图标签的 label
字段。
[ ]:
# Define fast version of sentence transformers, change to cuda if available
encoder = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
# Load dataset with banking
dataset = load_dataset("banking77", split="test")
# Encode text field using batched computation
dataset = dataset.map(
lambda batch: {"vectors": encoder.encode(batch["text"])},
batch_size=32,
batched=True
)
# Removes the original labels because you'll be labeling from scratch
dataset = dataset.remove_columns("label")
# Turn vectors into a dictionary
dataset = dataset.map(
lambda r: {"vectors": {"mini-lm-sentence-transformers": r["vectors"]}}
)
我们的数据集现在包含一个 vectors
字段,其中包含由 sentence transformer 模型生成的嵌入向量。
[ ]:
dataset.to_pandas().head()
文本 | 向量 | |
---|---|---|
0 | 如何找到我的卡? | {'mini-lm-sentence-transformers': [-0.01016701... |
1 | 我仍然没有收到我的新卡,我订购了... | {'mini-lm-sentence-transformers': [-0.04284121... |
2 | 我订购了一张卡,但尚未到达。请帮忙 ... | {'mini-lm-sentence-transformers': [-0.03365556... |
3 | 有没有办法知道我的卡什么时候到? | {'mini-lm-sentence-transformers': [0.012195922... |
4 | 我的卡还没到。 | {'mini-lm-sentence-transformers': [-0.04361867... |
🔫 使用 SetFit 进行零样本预测#
原始的 banking77
数据集是一个意图分类数据集,包含数十个标签(lost_card
、card_arrival
等)。为了使本教程简单易懂,我们定义了一个简化的标签方案,其中包含更高级别的类别。
让我们设置并训练我们的零样本 SetFit 模型。请注意,SetFit 的零样本方法是创建合成的训练示例数据集,这与其他方法(例如,transformers 零样本 pipelines)不同,在其他方法中,在推理时使用带有标签名称的“模板化”示例。
[ ]:
labels = ["change details", "card", "atm", "top up", "balance", "transfer", "exchange rate", "pin"]
train_dataset = get_templated_dataset(
candidate_labels=labels,
sample_size=8,
template="The customer request is about {}"
)
model = SetFitModel.from_pretrained("all-MiniLM-L6-v2")
trainer = SetFitTrainer(
model=model,
train_dataset=train_dataset
)
trainer.train()
我们可以使用我们训练的零样本模型来预测数据集。稍后,我们将这些预测加载到 Argilla 中,并使用它们来加速标注过程。
[ ]:
def get_predictions(texts):
probas = model.predict_proba(texts, as_numpy=True)
for pred in probas:
yield [{"label": label, "score": score} for label, score in zip(labels, pred)]
dataset = dataset.map(lambda batch: {"prediction": list(get_predictions(batch["text"]))}, batched=True)
让我们上传包含向量和预测的数据集。
[ ]:
rg_ds = rg.DatasetForTextClassification.from_datasets(dataset)
rg.log(
name="banking77-topics-setfit",
records=rg_ds,
chunk_size=50,
)
🏷 使用 find similar
和零样本预测进行批量标注#
现在,我们的 banking77-topics-setfit
可以从 Argilla UI 中获得。您可以开始利用相似性搜索和我们的零样本预测来注释数据。转到您的 Argilla UI URL 后,工作流程如下
标注一条记录(例如,将“更改我的信息”标注为
change details
),然后单击记录右上角的“查找相似记录”。结果,您将获得按相似性排序的最相似记录列表及其相应的预测。
您现在可以查看预测,验证它们或更正它们。
在标注大约 200 条记录后,我们准备评估我们的零样本模型,让我们看看如何操作!
📏 评估零样本模型#
我们可以使用 Argilla 的内置指标来计算 f1
,基于 (1) 我们在本教程开始时存储的零样本模型的预测,以及 (2) 手动注释。请注意,在标注过程中,我们添加了一个新的标签 Other
,以解释不属于我们预定义类别的示例。这突出了在项目定义早期进行迭代的重要性。Argilla 为用户提供了很大的灵活性,预测和相似性搜索等功能可以帮助比传统数据注释工具更快地发现潜在问题和改进之处。
[3]:
from argilla.metrics.text_classification import f1
f1(name="banking77-topics-setfit").visualize()
🦾 训练少样本 SetFit 模型#
即使零样本方法给出了不错的结果(约 0.86 F1),我们也可以使用标注的数据集来训练少样本模型,以获得约 0.95 的准确率。
[ ]:
# Load the hand-labelled dataset from Argilla
ds = rg.load("banking77-topics-setfit").prepare_for_training(train_size=0.8)
# Load SetFit model from Hub
# Feel free to experiment with other larger models, e.g. "sentence-transformers/paraphrase-mpnet-base-v2"
model = SetFitModel.from_pretrained("all-MiniLM-L6-v2")
# Create trainer
trainer = SetFitTrainer(
model=model,
train_dataset=ds["train"],
eval_dataset=ds["test"],
loss_class=CosineSimilarityLoss,
batch_size=16,
num_iterations=20,
)
# Train and evaluate
trainer.train()
metrics = trainer.evaluate()
print(metrics)
总结#
在本教程中,您学习了如何利用嵌入和 SetFit 的零样本预测来构建训练数据集。之后,您学习了如何训练 SetFit 模型以改进零样本模型的结果。
如果您对 SetFit 感兴趣,可以查看我们其他的 SetFit 与 Argilla 教程
或查看 GitHub 上的 SetFit 存储库。