🤯 使用 SetFit 进行少样本分类#
SetFit 是一个令人兴奋的开源软件包,用于少样本分类,由 Hugging Face 和 Intel Labs 的团队开发。您可以在项目仓库上阅读所有相关信息。
为了展示 SetFit 和 Argilla 结合使用 的强大之处
我们手动从 IMDb 数据集的未标记拆分中 标注了 55 个示例,
我们在 5 分钟 内训练了一个模型,
并且在没有使用原始 IMDb 训练集中的任何示例的情况下,我们在完整的测试集上实现了 0.9 的准确率!
摘要#
在本教程中,您将学习如何
在 Argilla 中加载未标记的数据集。我们将使用
imdb
电影评论情感数据集中的未标记拆分。相同的工作流程可以应用于任何自定义数据集、问题和语言!使用 UI 手动 标注少量示例。
训练 SetFit 模型 以获得极具竞争力的结果。对于此示例,仅使用 55 个示例,我们在测试集上获得了 0.9 的准确率,这与在 3K 个示例上微调的模型相当。这意味着以
50x
更少的示例获得相似的性能 🤯。
有关参考,请参阅 Hugging Face Hub 和 PapersWithCode 排行榜。
让我们开始吧!
运行 Argilla#
对于本教程,您需要运行 Argilla 服务器。部署和运行 Argilla 有两个主要选项
在 Hugging Face Spaces 上部署 Argilla:如果您想使用外部笔记本(例如,Google Colab)运行教程,并且您在 Hugging Face 上有一个帐户,您只需点击几下即可在 Spaces 上部署 Argilla
有关配置部署的详细信息,请查看 Hugging Face Hub 官方指南。
使用 Argilla 的快速入门 Docker 镜像启动 Argilla:如果您想在 本地机器上运行 Argilla,这是推荐选项。请注意,此选项仅允许您在本地运行教程,而不能与外部笔记本服务一起运行。
有关部署选项的更多信息,请查看文档的部署部分。
提示
本教程是一个 Jupyter Notebook。有两种运行它的选项
使用此页面顶部的“在 Colab 中打开”按钮。此选项允许您直接在 Google Colab 上运行笔记本。不要忘记将运行时类型更改为 GPU 以加快模型训练和推理速度。
单击页面顶部的“查看源代码”链接下载 .ipynb 文件。此选项允许您下载笔记本并在本地机器或您选择的 Jupyter Notebook 工具上运行它。
[ ]:
%pip install argilla "setfit~=0.2.0" "datasets~=2.3.0" -qqq
让我们导入 Argilla 模块以进行数据读取和写入
[ ]:
import argilla as rg
如果您使用 Docker 快速入门镜像或 Hugging Face Spaces 运行 Argilla,您需要使用 URL
和 API_KEY
初始化 Argilla 客户端
[ ]:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
# Replace workspace with the name of your workspace
rg.init(
api_url="https://#:6900",
api_key="owner.apikey",
workspace="admin"
)
如果您正在运行私有的 Hugging Face Space,您还需要按如下方式设置 HF_TOKEN
[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"
# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# # Replace workspace with the name of your workspace
# rg.init(
# api_url="https://[your-owner-name]-[your_space_name].hf.space",
# api_key="owner.apikey",
# workspace="admin",
# extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )
让我们导入所需的模块
[ ]:
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
启用遥测#
我们从您与教程的互动中获得宝贵的见解。为了改进自身,为您提供最合适的内容,使用以下代码行将帮助我们了解本教程是否有效地为您服务。虽然这是完全匿名的,但如果您愿意,可以选择跳过此步骤。有关更多信息,请查看遥测页面。
[ ]:
try:
from argilla.utils.telemetry import tutorial_running
tutorial_running()
except ImportError:
print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")
在 Argilla 中加载未标记的数据集#
首先,我们从 imdb
数据集中加载 unsupervised
拆分,并创建一个包含 100 个随机示例的新 Argilla 数据集
[ ]:
unlabelled = (
load_dataset("imdb", split="unsupervised").shuffle(seed=42).select(range(100))
)
unlabelled = rg.DatasetForTextClassification.from_datasets(unlabelled)
rg.log(unlabelled, "imdb_unlabelled")
手动标注#
在此步骤中,我们使用与原始数据集相同的标签方案创建标签 pos
和 neg
。然后我们使用 UI 顺序标注一些示例。对于这个例子,我们实际上只花了 15 分钟。
在训练之前,您可以使用 push_to_hub
方法轻松共享数据集。如果您在机器上没有 GPU,并且想要使用训练服务或 Colab 等,这可能很有用。更多信息请访问 此处
[ ]:
rg.load("imdb_unlabelled").prepare_for_training().push_to_hub("mini-imdb")
训练和评估 SetFit 模型#
最后,我们准备好测试 SetFit 了!
感谢 Argilla 与 datasets
和 Hub 的集成,如果您没有本地 GPU,您可以使用此 Google Colab 使用标注的数据集重现训练过程。如果您使用 GPU 运行时,则训练只需 5 分钟。
下面我们从 Argilla 加载数据集,将其格式化为使用 transformers 进行训练,加载完整的 IMDb 测试数据集,加载预训练的 sentence transformers 模型,训练 SetFit 模型并对其进行评估!
[ ]:
# Load the hand-labeled dataset from Argilla
train_ds = rg.load("imdb_unlabelled").prepare_for_training()
# Load the full IMDb test dataset
test_ds = load_dataset("imdb", split="test")
# Load SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
# Create trainer
trainer = SetFitTrainer(
model=model,
train_dataset=train_ds,
eval_dataset=test_ds,
loss_class=CosineSimilarityLoss,
batch_size=16,
num_iterations=20, # The number of text pairs to generate
)
# Train and evaluate
trainer.train()
metrics = trainer.evaluate()
可选地,您可以与世界分享您出色的模型!
[ ]:
trainer.push_to_hub("setfit-mini-imdb")
结论#
指标对象应该在完整测试集上为您提供大约 0.9 的准确率 🎉
请记住
我们手动标注了 55 个示例,
我们没有使用原始训练集中的任何示例,
并且我们在 5 分钟内训练了模型!
现在,我认为您没有任何借口不花一些时间标注一些高质量的示例了!
如果您对 SetFit 感兴趣,您可以查看我们的其他 SetFit + Argilla 教程
或者查看 GitHub 上的 SetFit 仓库。