🎛️ 使用 ArgillaTrainer 微调 SetFit 模型#
在本例中,我们将展示如何使用 ArgillaTrainer
在由分类问题(如 LabelQuestion
和 MultiLabelQuestion
)组成的 FeedbackDataset
上微调 SetFit 模型。
我们将从 setfit 教程 开始,使用相同的数据集并将其调整为新的 trainer 类。如果需要,可以在 ArgillaTrainer 指南 中找到一些背景知识。
让我们开始吧!
注意
本教程是一个 Jupyter Notebook。有两种运行方式
使用本页面顶部的“在 Colab 中打开”按钮。此选项允许您直接在 Google Colab 上运行 notebook。不要忘记将运行时类型更改为 GPU,以加快模型训练和推理速度。
通过单击页面顶部的“查看源代码”链接下载 .ipynb 文件。此选项允许您下载 notebook 并在本地计算机或您选择的 Jupyter notebook 工具上运行它。
设置#
对于本教程,您需要运行 Argilla 服务器。如果您还没有服务器,请查看我们的快速入门 或 安装 页面。完成后,完成以下步骤
使用
pip
安装 Argilla 客户端和所需的第三方库
[3]:
!pip install argilla setfit
让我们进行必要的导入
[1]:
import argilla as rg
from datasets import load_dataset
from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTaskMapping
如果您使用 Docker 快速入门镜像或 Hugging Face Spaces 运行 Argilla,则需要使用
URL
和API_KEY
init
Argilla 客户端
[ ]:
# Replace api_url with the url to your local host if using Docker or your HF Spaces url
# Replace api_key if you configured a custom API key
rg.init(
api_url=r"https://<YOUR-HF-SPACE>.hf.space",
api_key="admin.apikey"
)
如果您运行的是私有 Hugging Face Space,您还需要按如下方式设置 HF_TOKEN
[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"
# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# rg.init(
# api_url=r"https://<YOUR-HF-SPACE>.hf.space",
# api_key="admin.apikey",
# extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )
启用遥测#
我们从您与教程的互动中获得宝贵的见解。为了改进我们为您提供最合适内容的方式,使用以下代码行将帮助我们了解本教程是否有效地为您服务。虽然这是完全匿名的,但如果您愿意,可以选择跳过此步骤。有关更多信息,请查看遥测页面。
[ ]:
try:
from argilla.utils.telemetry import tutorial_running
tutorial_running()
except ImportError:
print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")
配置数据集#
在本例中,我们将加载一个流行的开源数据集,其中包含银行业务领域的客户请求。
[3]:
data = load_dataset("PolyAI/banking77", split="test")
我们将使用两个不同的问题配置我们的数据集,以便我们可以同时处理两个文本分类任务。在本例中,我们将加载此数据集的原始标签,以对请求中提到的主题进行多标签分类,并且我们还将设置一个问题,将请求的情感分类为“正面”、“中性”或“负面”。
[4]:
dataset = rg.FeedbackDataset(
guidelines="Add some guidelines for the annotation team here.",
fields=[
rg.TextField(name="text", title="banking topics")
],
questions=[
rg.MultiLabelQuestion(
name="topics",
title="Select the topic(s) of the request",
labels=data.info.features['label'].names, #these are the original labels present in the dataset
visible_labels=10
),
rg.LabelQuestion(
name="sentiment",
title="What is the sentiment of the message?",
labels=["positive", "neutral", "negative"]
)
]
)
定义任务映射#
我们将训练两个不同的模型,一个用于带有情感的 LabelQuestion
,第二个用于数据集原始主题的 MultiLabelQuestion
。为此,我们将使用 TrainingTaskMapping.for_text_classification
。
[5]:
task_mapping_topics = TrainingTaskMapping.for_text_classification(
text=dataset.field_by_name("text"),
label=dataset.question_by_name("topics")
)
task_mapping_sentiment = TrainingTaskMapping.for_text_classification(
text=dataset.field_by_name("text"),
label=dataset.question_by_name("sentiment")
)
为我们的数据集创建记录#
此时,我们的数据集包含结构,但没有记录来训练我们的模型。示例已标记主题,但情感并非如此。在本例中,我们将使用 transformer 模型 预训练以获取情感。
让我们创建记录。
[ ]:
from transformers import pipeline
sentiment_classifier = pipeline(model="cardiffnlp/twitter-roberta-base-sentiment-latest")
[52]:
import random
random.seed(1234)
record_indices = random.choices(range(len(data)), k=8)
def get_sentiment(text: str) -> str:
return sentiment_classifier(text)[0]["label"]
def topic_int2str(label_int: int) -> str:
return data.features["label"].int2str(label_int)
records = [
rg.FeedbackRecord(
fields={"text": record['text']},
responses=[
{
"values": {
# Get the sentiment from a pretrained transformer model
"sentiment": {"value": get_sentiment(record["text"])},
# Add the topics as a list
"topics": {"value": [topic_int2str(record["label"])]}
}
}
]
)
for record in data.select(record_indices)
]
让我们回顾一下之前的代码块。请注意,此 FeedbackDataset
仅用于演示目的。我们将仅选择 8 个示例(在 record_indices
中,我们存储将从原始数据集中抓取的记录的索引)以加快训练速度,并从我们的 sentiment_classifier
获取情感。我们在 hugging face hub 上有一个足够好的模型可供使用,但如果情况并非如此,我们可以使用“足够好”的模型来获取建议并正确标记它以训练我们的模型,请查看此 教程 以获取更多信息。
关于主题,原始数据集已经标记,所以我们只需要插入相应的名称。topic_int2str
为我们完成了映射。
让我们将记录添加到我们的数据集
[55]:
dataset.add_records(records=records)
[68]:
dataset.format_as("datasets").to_pandas()[-3:][["text", "topics", "sentiment"]]
[68]:
文本 | 主题 | 情感 | |
---|---|---|---|
5 | 我的退款不见了 | [{'user_id': None, 'value': ['Refund_not_showi... | [{'user_id': None, 'value': 'negative', 'statu... |
6 | 处理从 ... 转账需要多长时间 | [{'user_id': None, 'value': ['transfer_timing'... | [{'user_id': None, 'value': 'neutral', 'status... |
7 | 你们会处理欧元吗? | [{'user_id': None, 'value': ['fiat_currency_su... | [{'user_id': None, 'value': 'neutral', 'status... |
[ ]:
ds = dataset.push_to_argilla("setfit_training_tutorial", workspace="admin")
训练模型#
现在我们将为每个任务训练两个不同的模型,只需使用适当的 task_mapping,并将它们保存到它们各自的文件夹中
[ ]:
trainer_sentiment = ArgillaTrainer(
dataset=dataset,
task=task_mapping_sentiment,
framework="setfit",
)
trainer_sentiment.update_config(
num_train_epochs=1,
)
trainer_sentiment.train(output_dir="sentiment_model")
[ ]:
trainer_topics = ArgillaTrainer(
dataset=dataset,
task=task_mapping_topics,
framework="setfit",
)
trainer_topics.update_config(num_train_epochs=1)
trainer_topics.train(output_dir="topics_model")
进行预测#
按照之前的教程,让我们使用我们新训练的模型获取预测。首先,我们调整之前教程中的函数以使用我们的 ArgillaTrainer.predict
方法。
[384]:
def get_predictions(texts, model):
return model.predict(texts)
将预测应用于我们的子集,并查看生成的数据集。
[ ]:
subset = data.select(record_indices)
subset = subset.map(
lambda batch: {
"topics": get_predictions(batch["text"], trainer_topics),
"sentiment": get_predictions(batch["text"], trainer_sentiment),
},
batched=True,
)
[387]:
subset.to_pandas().set_index("text").head()
[387]:
标签 | 主题 | 情感 | |
---|---|---|---|
文本 | |||
我在国外货币的 ATM 上使用了 ATM,但应用的汇率是错误的! | 76 | {'activate_my_card': 0.11339724732947638, 'age... | {'negative': 0.30721275251644486, 'neutral': 0... |
我不明白为什么它说我必须验证充值。 | 71 | {'activate_my_card': 0.11338630825105621, 'age... | {'negative': 0.19467522210149565, 'neutral': 0... |
我提交了一笔交易到错误的帐户。 | 8 | {'activate_my_card': 0.2284524559678986, 'age_... | {'negative': 0.20868444521020138, 'neutral': 0... |
汇率以什么货币计算? | 32 | {'activate_my_card': 0.10659121602923811, 'age... | {'negative': 0.433266446947898, 'neutral': 0.1... |
你知道我的卡在哪里可以被接受吗? | 10 | {'activate_my_card': 0.10908837526993907, 'age... | {'negative': 0.30172024225859906, 'neutral': 0... |
结论#
在本教程中,我们介绍了如何使用新的 ArgillaTrainer
和 FeedbackDataset
训练 SetFit 模型。在“使用 SetFit 添加零样本建议”教程的基础上,我们学习了如何使用新的 Argilla API 在我们的数据集上训练模型,而无需离开 Argilla。
要了解有关 SetFit 的更多信息,请查看以下链接