Open In Colab  View Notebook on GitHub

🤯 使用 SetFit 进行少样本分类#

SetFit 是一个令人兴奋的开源软件包,用于少样本分类,由 Hugging Face 和 Intel Labs 的团队开发。您可以在项目仓库上阅读所有相关信息。

为了展示 SetFit 和 Argilla 结合使用 的强大之处

  • 我们手动从 IMDb 数据集的未标记拆分中 标注了 55 个示例

  • 我们在 5 分钟 内训练了一个模型,

  • 并且在没有使用原始 IMDb 训练集中的任何示例的情况下,我们在完整的测试集上实现了 0.9 的准确率!

摘要#

在本教程中,您将学习如何

  1. 在 Argilla 中加载未标记的数据集。我们将使用 imdb 电影评论情感数据集中的未标记拆分。相同的工作流程可以应用于任何自定义数据集、问题和语言!

  2. 使用 UI 手动 标注少量示例

  3. 训练 SetFit 模型 以获得极具竞争力的结果。对于此示例,仅使用 55 个示例,我们在测试集上获得了 0.9 的准确率,这与在 3K 个示例上微调的模型相当。这意味着以 50x 更少的示例获得相似的性能 🤯。

有关参考,请参阅 Hugging Face HubPapersWithCode 排行榜。

让我们开始吧!

运行 Argilla#

对于本教程,您需要运行 Argilla 服务器。部署和运行 Argilla 有两个主要选项

在 Hugging Face Spaces 上部署 Argilla:如果您想使用外部笔记本(例如,Google Colab)运行教程,并且您在 Hugging Face 上有一个帐户,您只需点击几下即可在 Spaces 上部署 Argilla

deploy on spaces

有关配置部署的详细信息,请查看 Hugging Face Hub 官方指南

使用 Argilla 的快速入门 Docker 镜像启动 Argilla:如果您想在 本地机器上运行 Argilla,这是推荐选项。请注意,此选项仅允许您在本地运行教程,而不能与外部笔记本服务一起运行。

有关部署选项的更多信息,请查看文档的部署部分。

提示

本教程是一个 Jupyter Notebook。有两种运行它的选项

  • 使用此页面顶部的“在 Colab 中打开”按钮。此选项允许您直接在 Google Colab 上运行笔记本。不要忘记将运行时类型更改为 GPU 以加快模型训练和推理速度。

  • 单击页面顶部的“查看源代码”链接下载 .ipynb 文件。此选项允许您下载笔记本并在本地机器或您选择的 Jupyter Notebook 工具上运行它。

[ ]:
%pip install argilla "setfit~=0.2.0" "datasets~=2.3.0" -qqq

让我们导入 Argilla 模块以进行数据读取和写入

[ ]:
import argilla as rg

如果您使用 Docker 快速入门镜像或 Hugging Face Spaces 运行 Argilla,您需要使用 URLAPI_KEY 初始化 Argilla 客户端

[ ]:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
# Replace workspace with the name of your workspace
rg.init(
    api_url="https://#:6900",
    api_key="owner.apikey",
    workspace="admin"
)

如果您正在运行私有的 Hugging Face Space,您还需要按如下方式设置 HF_TOKEN

[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"

# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# # Replace workspace with the name of your workspace
# rg.init(
#     api_url="https://[your-owner-name]-[your_space_name].hf.space",
#     api_key="owner.apikey",
#     workspace="admin",
#     extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )

让我们导入所需的模块

[ ]:
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer

启用遥测#

我们从您与教程的互动中获得宝贵的见解。为了改进自身,为您提供最合适的内容,使用以下代码行将帮助我们了解本教程是否有效地为您服务。虽然这是完全匿名的,但如果您愿意,可以选择跳过此步骤。有关更多信息,请查看遥测页面。

[ ]:
try:
    from argilla.utils.telemetry import tutorial_running
    tutorial_running()
except ImportError:
    print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")

在 Argilla 中加载未标记的数据集#

首先,我们从 imdb 数据集中加载 unsupervised 拆分,并创建一个包含 100 个随机示例的新 Argilla 数据集

[ ]:
unlabelled = (
    load_dataset("imdb", split="unsupervised").shuffle(seed=42).select(range(100))
)

unlabelled = rg.DatasetForTextClassification.from_datasets(unlabelled)

rg.log(unlabelled, "imdb_unlabelled")

手动标注#

在此步骤中,我们使用与原始数据集相同的标签方案创建标签 posneg。然后我们使用 UI 顺序标注一些示例。对于这个例子,我们实际上只花了 15 分钟。

在训练之前,您可以使用 push_to_hub 方法轻松共享数据集。如果您在机器上没有 GPU,并且想要使用训练服务或 Colab 等,这可能很有用。更多信息请访问 此处

[ ]:
rg.load("imdb_unlabelled").prepare_for_training().push_to_hub("mini-imdb")

训练和评估 SetFit 模型#

最后,我们准备好测试 SetFit 了!

感谢 Argilla 与 datasets 和 Hub 的集成,如果您没有本地 GPU,您可以使用此 Google Colab 使用标注的数据集重现训练过程。如果您使用 GPU 运行时,则训练只需 5 分钟。

下面我们从 Argilla 加载数据集,将其格式化为使用 transformers 进行训练,加载完整的 IMDb 测试数据集,加载预训练的 sentence transformers 模型,训练 SetFit 模型并对其进行评估!

[ ]:
# Load the hand-labeled dataset from Argilla
train_ds = rg.load("imdb_unlabelled").prepare_for_training()

# Load the full IMDb test dataset
test_ds = load_dataset("imdb", split="test")


# Load SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    loss_class=CosineSimilarityLoss,
    batch_size=16,
    num_iterations=20,  # The number of text pairs to generate
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()

可选地,您可以与世界分享您出色的模型!

[ ]:
trainer.push_to_hub("setfit-mini-imdb")

结论#

指标对象应该在完整测试集上为您提供大约 0.9 的准确率 🎉

请记住

  • 我们手动标注了 55 个示例,

  • 我们没有使用原始训练集中的任何示例,

  • 并且我们在 5 分钟内训练了模型!

现在,我认为您没有任何借口不花一些时间标注一些高质量的示例了!

如果您对 SetFit 感兴趣,您可以查看我们的其他 SetFit + Argilla 教程

或者查看 GitHub 上的 SetFit 仓库