Open In Colab  View Notebook on GitHub

🎛️ 使用 ArgillaTrainer 微调 SetFit 模型#

在本例中,我们将展示如何使用 ArgillaTrainer 在由分类问题(如 LabelQuestionMultiLabelQuestion)组成的 FeedbackDataset 上微调 SetFit 模型。

我们将从 setfit 教程 开始,使用相同的数据集并将其调整为新的 trainer 类。如果需要,可以在 ArgillaTrainer 指南 中找到一些背景知识。

让我们开始吧!

注意

本教程是一个 Jupyter Notebook。有两种运行方式

  • 使用本页面顶部的“在 Colab 中打开”按钮。此选项允许您直接在 Google Colab 上运行 notebook。不要忘记将运行时类型更改为 GPU,以加快模型训练和推理速度。

  • 通过单击页面顶部的“查看源代码”链接下载 .ipynb 文件。此选项允许您下载 notebook 并在本地计算机或您选择的 Jupyter notebook 工具上运行它。

设置#

对于本教程,您需要运行 Argilla 服务器。如果您还没有服务器,请查看我们的快速入门安装 页面。完成后,完成以下步骤

  1. 使用 pip 安装 Argilla 客户端和所需的第三方库

[3]:
!pip install argilla setfit
  1. 让我们进行必要的导入

[1]:
import argilla as rg
from datasets import load_dataset
from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTaskMapping
  1. 如果您使用 Docker 快速入门镜像或 Hugging Face Spaces 运行 Argilla,则需要使用 URLAPI_KEY init Argilla 客户端

[ ]:
# Replace api_url with the url to your local host if using Docker or your HF Spaces url
# Replace api_key if you configured a custom API key
rg.init(
    api_url=r"https://<YOUR-HF-SPACE>.hf.space",
    api_key="admin.apikey"
)

如果您运行的是私有 Hugging Face Space,您还需要按如下方式设置 HF_TOKEN

[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"

# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# rg.init(
#     api_url=r"https://<YOUR-HF-SPACE>.hf.space",
#     api_key="admin.apikey",
#     extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )

启用遥测#

我们从您与教程的互动中获得宝贵的见解。为了改进我们为您提供最合适内容的方式,使用以下代码行将帮助我们了解本教程是否有效地为您服务。虽然这是完全匿名的,但如果您愿意,可以选择跳过此步骤。有关更多信息,请查看遥测页面。

[ ]:
try:
    from argilla.utils.telemetry import tutorial_running
    tutorial_running()
except ImportError:
    print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")

配置数据集#

在本例中,我们将加载一个流行的开源数据集,其中包含银行业务领域的客户请求。

[3]:
data = load_dataset("PolyAI/banking77", split="test")

我们将使用两个不同的问题配置我们的数据集,以便我们可以同时处理两个文本分类任务。在本例中,我们将加载此数据集的原始标签,以对请求中提到的主题进行多标签分类,并且我们还将设置一个问题,将请求的情感分类为“正面”、“中性”或“负面”。

[4]:
dataset = rg.FeedbackDataset(
    guidelines="Add some guidelines for the annotation team here.",
    fields=[
        rg.TextField(name="text", title="banking topics")
    ],
    questions=[
        rg.MultiLabelQuestion(
            name="topics",
            title="Select the topic(s) of the request",
            labels=data.info.features['label'].names, #these are the original labels present in the dataset
            visible_labels=10
        ),
        rg.LabelQuestion(
            name="sentiment",
            title="What is the sentiment of the message?",
            labels=["positive", "neutral", "negative"]
        )
    ]
)

定义任务映射#

我们将训练两个不同的模型,一个用于带有情感的 LabelQuestion,第二个用于数据集原始主题的 MultiLabelQuestion。为此,我们将使用 TrainingTaskMapping.for_text_classification

[5]:
task_mapping_topics = TrainingTaskMapping.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("topics")
)
task_mapping_sentiment = TrainingTaskMapping.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("sentiment")
)

为我们的数据集创建记录#

此时,我们的数据集包含结构,但没有记录来训练我们的模型。示例已标记主题,但情感并非如此。在本例中,我们将使用 transformer 模型 预训练以获取情感。

让我们创建记录。

[ ]:
from transformers import pipeline

sentiment_classifier = pipeline(model="cardiffnlp/twitter-roberta-base-sentiment-latest")
[52]:
import random

random.seed(1234)

record_indices = random.choices(range(len(data)), k=8)


def get_sentiment(text: str) -> str:
    return sentiment_classifier(text)[0]["label"]


def topic_int2str(label_int: int) -> str:
    return data.features["label"].int2str(label_int)


records = [
    rg.FeedbackRecord(
        fields={"text": record['text']},
        responses=[
            {
                "values": {
                    # Get the sentiment from a pretrained transformer model
                    "sentiment": {"value": get_sentiment(record["text"])},
                    # Add the topics as a list
                    "topics": {"value": [topic_int2str(record["label"])]}
                }
            }
        ]
    )
    for record in data.select(record_indices)
]

让我们回顾一下之前的代码块。请注意,此 FeedbackDataset 仅用于演示目的。我们将仅选择 8 个示例(在 record_indices 中,我们存储将从原始数据集中抓取的记录的索引)以加快训练速度,并从我们的 sentiment_classifier 获取情感。我们在 hugging face hub 上有一个足够好的模型可供使用,但如果情况并非如此,我们可以使用“足够好”的模型来获取建议并正确标记它以训练我们的模型,请查看此 教程 以获取更多信息。

关于主题,原始数据集已经标记,所以我们只需要插入相应的名称。topic_int2str 为我们完成了映射。

让我们将记录添加到我们的数据集

[55]:
dataset.add_records(records=records)
[68]:
dataset.format_as("datasets").to_pandas()[-3:][["text", "topics", "sentiment"]]
[68]:
文本 主题 情感
5 我的退款不见了 [{'user_id': None, 'value': ['Refund_not_showi... [{'user_id': None, 'value': 'negative', 'statu...
6 处理从 ... 转账需要多长时间 [{'user_id': None, 'value': ['transfer_timing'... [{'user_id': None, 'value': 'neutral', 'status...
7 你们会处理欧元吗? [{'user_id': None, 'value': ['fiat_currency_su... [{'user_id': None, 'value': 'neutral', 'status...
[ ]:
ds = dataset.push_to_argilla("setfit_training_tutorial", workspace="admin")

screenshot-demo-feedback-dataset.png

训练模型#

现在我们将为每个任务训练两个不同的模型,只需使用适当的 task_mapping,并将它们保存到它们各自的文件夹中

[ ]:
trainer_sentiment = ArgillaTrainer(
    dataset=dataset,
    task=task_mapping_sentiment,
    framework="setfit",
)
trainer_sentiment.update_config(
    num_train_epochs=1,
)
trainer_sentiment.train(output_dir="sentiment_model")
[ ]:
trainer_topics = ArgillaTrainer(
    dataset=dataset,
    task=task_mapping_topics,
    framework="setfit",
)
trainer_topics.update_config(num_train_epochs=1)
trainer_topics.train(output_dir="topics_model")

进行预测#

按照之前的教程,让我们使用我们新训练的模型获取预测。首先,我们调整之前教程中的函数以使用我们的 ArgillaTrainer.predict 方法。

[384]:
def get_predictions(texts, model):
    return model.predict(texts)

将预测应用于我们的子集,并查看生成的数据集。

[ ]:
subset = data.select(record_indices)
subset = subset.map(
    lambda batch: {
        "topics": get_predictions(batch["text"], trainer_topics),
        "sentiment": get_predictions(batch["text"], trainer_sentiment),
    },
    batched=True,
)
[387]:
subset.to_pandas().set_index("text").head()
[387]:
标签 主题 情感
文本
我在国外货币的 ATM 上使用了 ATM,但应用的汇率是错误的! 76 {'activate_my_card': 0.11339724732947638, 'age... {'negative': 0.30721275251644486, 'neutral': 0...
我不明白为什么它说我必须验证充值。 71 {'activate_my_card': 0.11338630825105621, 'age... {'negative': 0.19467522210149565, 'neutral': 0...
我提交了一笔交易到错误的帐户。 8 {'activate_my_card': 0.2284524559678986, 'age_... {'negative': 0.20868444521020138, 'neutral': 0...
汇率以什么货币计算? 32 {'activate_my_card': 0.10659121602923811, 'age... {'negative': 0.433266446947898, 'neutral': 0.1...
你知道我的卡在哪里可以被接受吗? 10 {'activate_my_card': 0.10908837526993907, 'age... {'negative': 0.30172024225859906, 'neutral': 0...

结论#

在本教程中,我们介绍了如何使用新的 ArgillaTrainerFeedbackDataset 训练 SetFit 模型。在“使用 SetFit 添加零样本建议”教程的基础上,我们学习了如何使用新的 Argilla API 在我们的数据集上训练模型,而无需离开 Argilla。

要了解有关 SetFit 的更多信息,请查看以下链接