🤗 使用 SetFit 训练情感分类器#
在本教程中,我们将使用 SetFit 和 Argilla 为银行领域的用户请求构建情感分类器。
SetFit 是一个令人兴奋的开源软件包,由 Hugging Face 和 Intel Labs 的团队开发,用于少样本分类。你可以在项目仓库中阅读所有相关信息。
Argilla 使你能够快速构建和迭代 NLP 的训练数据。
让我们看看如何将它们结合起来从头开始构建情感分类器!
简介#
本教程将向你展示如何微调你自己的领域的情感分类器,从没有标记数据开始。
大多数关于微调模型的在线教程都假设你已经有一个训练数据集。你会找到许多关于使用广泛使用的数据集(例如,用于情感分析的 IMDB)微调预训练模型的教程。
但是,通常你想要的是为你自己的用例微调模型。众所周知,NLP 模型性能通常会随着“域外”数据而降低。例如,在电影评论(例如,IMDB)上预训练的情感分类器在客户请求方面的表现不会很好。
运行 Argilla#
在本教程中,你将需要运行 Argilla 服务器。 有两个主要的选项用于部署和运行 Argilla
在 Hugging Face Spaces 上部署 Argilla:如果你想使用外部 notebook(例如,Google Colab)运行教程,并且你在 Hugging Face 上有一个帐户,你可以在 Spaces 上点击几下部署 Argilla
有关配置部署的详细信息,请查看官方 Hugging Face Hub 指南。
使用 Argilla 的快速入门 Docker 镜像启动 Argilla: 如果你想在本地机器上运行 Argilla,这是推荐的选项。 请注意,此选项仅允许你在本地运行教程,而不能与外部 notebook 服务一起运行。
有关部署选项的更多信息,请查看文档的部署部分。
提示
本教程是一个 Jupyter Notebook。 有两个选项可以运行它
使用此页面顶部的“在 Colab 中打开”按钮。 此选项允许你直接在 Google Colab 上运行 notebook。 不要忘记将运行时类型更改为 GPU,以加快模型训练和推理速度。
通过单击页面顶部的“查看源代码”链接下载 .ipynb 文件。 此选项允许你下载 notebook 并在本地机器或你选择的 Jupyter Notebook 工具上运行它。
[ ]:
%pip install argilla setfit datasets -qqq
让我们导入 Argilla 模块以进行数据读取和写入
[ ]:
import argilla as rg
如果你正在使用 Docker 快速入门镜像或 Hugging Face Spaces 运行 Argilla,你需要使用 URL
和 API_KEY
初始化 Argilla 客户端
[ ]:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
# Replace workspace with the name of your workspace
rg.init(
api_url="https://#:6900",
api_key="owner.apikey",
workspace="admin"
)
如果你正在运行私有的 Hugging Face Space,你还需要如下设置 HF_TOKEN
[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"
# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# # Replace workspace with the name of your workspace
# rg.init(
# api_url="https://[your-owner-name]-[your_space_name].hf.space",
# api_key="owner.apikey",
# workspace="admin",
# extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )
最后,让我们包含我们需要的导入
[2]:
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
启用遥测#
我们从你与我们的教程互动的方式中获得宝贵的见解。 为了改进我们自己,为您提供最合适的内容,使用以下代码行将帮助我们了解本教程是否有效地为您服务。 虽然这是完全匿名的,但如果你愿意,可以选择跳过此步骤。 有关更多信息,请查看遥测页面。
[ ]:
try:
from argilla.utils.telemetry import tutorial_running
tutorial_running()
except ImportError:
print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")
源数据集:banking77
#
banking77,可在 Hugging Face Hub 上找到,包含在线银行用户查询,并使用其相应的意图进行注释。
在我们的例子中,我们将标记这些查询的情感。 这可能对数字助理和客户服务分析很有用。
让我们直接从 Hub 加载数据集。
[ ]:
banking_ds = load_dataset("banking77", split="train")
让我们使用 Pandas head
方法预览数据集的内容
[5]:
banking_ds.to_pandas().head(15)
[5]:
text | label | |
---|---|---|
0 | 我还在等我的卡吗? | 11 |
1 | 如果我的卡仍然没有到,我该怎么办... | 11 |
2 | 我已经等了一个多星期了。卡片是... | 11 |
3 | 我可以在我的卡在处理过程中跟踪它吗?... | 11 |
4 | 我怎么知道我是否会收到我的卡,或者它是否... | 11 |
5 | 你什么时候寄给我的新卡? | 11 |
6 | 你有关于卡片交付的信息吗? | 11 |
7 | 如果我仍然没有收到我的新卡,我该怎么办?... | 11 |
8 | 我的卡片包裹有跟踪吗? | 11 |
9 | 我订购了我的卡,但它仍然没有到这里 | 11 |
10 | 为什么我的新卡还没有到? | 11 |
11 | 两个星期后我仍然没有收到我的卡... | 11 |
12 | 你能为我跟踪我的卡吗? | 11 |
13 | 有没有办法跟踪我的卡的交付? | 11 |
14 | 自从我订购我的卡已经一个星期了,我... | 11 |
关于情感分析和数据注释的说明#
情感分析是 NLP 中最主观的任务之一。 我们对情感的理解会因一个应用而异,并取决于项目的业务目标。 此外,情感可以用不同的方式建模,从而导致不同的标记方案。
例如,情感可以建模为实值(从 -1 到 1,从 0 到 1.0 等)或使用 2 个或更多标签(包括不同的程度,例如正面、负面、中性等)
在本教程中,我们将使用以下标记方案:POSITIVE
、NEGATIVE
和 NEUTRAL
。
1. 加载数据集并标记一些示例#
[ ]:
argilla_ds = rg.read_datasets(banking_ds, task="TextClassification")
rg.log(argilla_ds, "banking_sentiment")
2. 手动标记#
在此步骤中,你可以使用 Argilla UI 标记一些示例(例如,50 个示例)。
标记了一些示例后,你可以读取并准备数据以训练你的 SetFit 模型。
注意
如果你现在没有时间进行标记,我们已经使用 Argilla 标记了一个小型数据集,并将其推送到 Hugging Face Hub。
要使用它,请将以下单元格替换为此代码
`labelled_ds = load_dataset("argilla/sentiment-banking-setfit")`
[8]:
labelled_ds = rg.load("banking_sentiment").prepare_for_training()
labelled_ds = labelled_ds.train_test_split()
labelled_ds
[8]:
DatasetDict({
train: Dataset({
features: ['text', 'label'],
num_rows: 108
})
test: Dataset({
features: ['text', 'label'],
num_rows: 36
})
})
3. 训练我们的 SetFit 情感分类器#
[9]:
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
# Create trainer
trainer = SetFitTrainer(
model=model,
train_dataset=labelled_ds["train"],
eval_dataset=labelled_ds["test"],
loss_class=CosineSimilarityLoss,
batch_size=8,
num_iterations=20,
)
trainer.train()
metrics = trainer.evaluate()
metrics
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights.
108 train samples in total, 540 train steps with batch size 8
[9]:
{'accuracy': 0.8611111111111112}
这里我们使用最简单的方法来训练我们的 SetFit 模型。 由于它与 Optuna 集成,你可以使用超参数调整来找到训练模型的最佳超参数。 但是,最好从一个简单的基线开始,验证模型的用例,并在专注于模型实验和调整之前迭代数据。
总结#
在本教程中,你学习了如何从头开始构建训练集,并为你自己的问题训练情感分类器。
虽然这是一个简单的示例,但你可以将相同的过程应用于你自己的用例。
在这里,我们介绍了一种构建训练集的方法:手动标记。
如果你对 SetFit 感兴趣,你可以查看我们的其他 SetFit + Argilla 教程
或者查看 GitHub 上的 SetFit 仓库。
如果想了解其他方法,如弱监督或主动学习,请查看以下教程