Open In Colab  View Notebook on GitHub

🥇 比较文本分类模型#

在本教程中,您将学习如何使用两个不同的模型对数据集进行文本分类,将模型预测上传到您的 Argilla 工作区,并通过计算每个模型的 F1 分数来比较模型。它将引导您完成以下步骤: - 💾 加载您想要使用的数据集。 - 💻 使用零样本分类模型计算预测。 - 🔄 将模型输出转换为 Argilla 格式并上传到 Argilla 工作区。 - 💻 使用零样本 SetFit 模型计算预测。 - 🧪 使用 F1 分数比较模型预测

简介#

在进行文本分类工作时,您可能想要比较两个模型以决定使用哪个模型。为此,我们使用模型的注释作为真实文本类别,计算训练模型的 F1 分数。F1 分数可以解释为精确率和召回率的调和平均值,其中 F1 分数的最佳值为 1,最差值为 0。(更多信息请参考此文档

Argilla 允许您部署和监控您喜欢的任何模型,但在本教程中,我们将重点关注 NLP 领域中最常用的两个框架:transformersSetFit。让我们开始吧!

运行 Argilla#

对于本教程,您需要运行 Argilla 服务器。部署和运行 Argilla 有两个主要选项

在 Hugging Face Spaces 上部署 Argilla:如果您想使用外部笔记本(例如,Google Colab)运行教程,并且您在 Hugging Face 上有帐户,则只需点击几下即可在 Spaces 上部署 Argilla

deploy on spaces

有关配置部署的详细信息,请查看官方 Hugging Face Hub 指南

使用 Argilla 的快速入门 Docker 镜像启动 Argilla:如果您想在本地机器上运行 Argilla,这是推荐的选项。请注意,此选项仅允许您在本地运行教程,而不能与外部笔记本服务一起运行。

有关部署选项的更多信息,请查看文档的部署部分。

提示

本教程是一个 Jupyter Notebook。有两种运行方式

  • 使用此页面顶部的“在 Colab 中打开”按钮。此选项允许您直接在 Google Colab 上运行 notebook。不要忘记将运行时类型更改为 GPU 以加快模型训练和推理速度。

  • 单击页面顶部的“查看源代码”链接下载 .ipynb 文件。此选项允许您下载 notebook 并在本地机器或您选择的 Jupyter Notebook 工具上运行它。

设置#

要完成本教程,您需要使用 pip 安装 Argilla 客户端和一些第三方库

[1]:
%pip install transformers argilla datasets torch setfit -qqqqqqq

所需的导入

[2]:
import argilla as rg
from datasets import load_dataset
from transformers import pipeline
from argilla.metrics.text_classification import f1
import pandas as pd

如果您正在使用 Docker 快速入门镜像或 Hugging Face Spaces 运行 Argilla,您需要使用 URLAPI_KEY 初始化 Argilla 客户端

[3]:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
# Replace workspace with the name of your workspace
rg.init(
    api_url="https://#:6900",
    api_key="owner.apikey",
    workspace="admin"
)

如果您正在运行私有的 Hugging Face Space,您还需要设置 HF_TOKEN,如下所示

[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"

# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# # Replace workspace with the name of your workspace
# rg.init(
#     api_url="https://[your-owner-name]-[your_space_name].hf.space",
#     api_key="owner.apikey",
#     workspace="admin",
#     extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )

本教程选择 HugginFace ag_news 数据集

[ ]:
news_dataset = load_dataset("ag_news", split="test")

此数据集由两列组成,一列是新闻文章的文本,另一列是与该文本文章关联的标签

对于本教程,我们将标签视为文本的注释

我们转换数据集以创建一个 argilla TextClassificationDataset

[ ]:
int_to_label = {
    0:"World",
    1:"Sports",
    2:"Business",
    3:"Sci/Tech",
}

news_dataset = news_dataset.map(lambda row: {"prediction": [{"label":int_to_label[row["label"]], "score":1}]})
[ ]:
ds_record = rg.read_datasets(dataset=news_dataset, task="TextClassification")

启用遥测#

我们从您与我们教程的互动中获得宝贵的见解。为了改进我们自己,为您提供最合适的内容,使用以下代码行将帮助我们了解本教程是否有效地为您服务。虽然这是完全匿名的,但如果您愿意,可以选择跳过此步骤。有关更多信息,请查看遥测页面。

[ ]:
try:
    from argilla.utils.telemetry import tutorial_running
    tutorial_running()
except ImportError:
    print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")

使用 transformers 进行零样本文本分类预测#

在 HugginFace 上,我们选择模型 cross-encoder/nli-distilroberta-base,该模型经过训练以执行零样本分类。我们使用此模型创建一个 pipeline,然后执行预测。

注意:``pipeline()`` 中的 ``device=0`` 允许使用 GPU,如果您没有可用的 GPU,请删除此参数

[7]:
labels =["Sports", "Sci/Tech", "Business", "World"]

pipe = pipeline("zero-shot-classification", model='cross-encoder/nli-distilroberta-base', device=0)
result = []
with pipe.device_placement():
  result = pipe(
    [data.text for data in ds_record],
    candidate_labels=labels,
  )

现在已经使用零样本模型成功进行了预测,我们可以将其转换为 argilla TextClassificationRecord 列表并上传到我们的 argilla 客户端

[ ]:
zero_shot_news_dataset = [
    rg.TextClassificationRecord(
        text=res["sequence"],
        prediction=list(zip(res['labels'],res['scores'])),
        annotation=record.prediction[0][0],
    ) for res, record in zip(result, ds_record)
]

rg.log(name="zero_shot_news_dataset", records=zero_shot_news_dataset)

您可以在 Argilla UI 中访问 zero_shot_news_dataset

Model comparison demo

最后,我们使用 argilla f1 函数测量模型的性能,该函数计算 F1 分数

[9]:
f1_zero_shot = f1("zero_shot_news_dataset")
f1_zero_shot.visualize()

使用训练好的 SetFit 分类器进行零样本文本分类#

所需的导入

[10]:
from setfit import SetFitModel, SetFitTrainer, get_templated_dataset

我们使用数据集的标签创建训练示例的合成数据集

[11]:
labels = ["Sports", "Sci/Tech", "Business", "World"]
train_dataset = get_templated_dataset(
    candidate_labels=labels,
    sample_size=8,
    template="The news article is about {}"
)

我们使用预训练模型 ‘all-MiniLM-L6-v2’ 训练 SetFitModel

[ ]:
model = SetFitModel.from_pretrained("all-MiniLM-L6-v2")
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset
)
trainer.train()

然后我们可以计算文本分类

[13]:
result = [{
    'sequence': data["text"],
    'scores': model.predict_proba([data["text"]]).squeeze().numpy(),
    'labels': labels
    } for data in news_dataset]

最后,我们可以将结果数据集记录到 Argilla 中并计算 F1 分数

[ ]:
setfit_zero_shot_news_dataset = [
    rg.TextClassificationRecord(
        text=res["sequence"],
        prediction=list(zip(res['labels'],res['scores'])),
        annotation=record.prediction[0][0],
    ) for res, record in zip(result, ds_record)
]
rg.log(name="setfit_zero_shot_news_dataset", records=setfit_zero_shot_news_dataset)

您可以在 Argilla UI 中访问 setfit_zero_shot_news_dataset

Model comparison demo

[15]:
f1_setfit_zero_shot = f1("setfit_zero_shot_news_dataset")
f1_setfit_zero_shot.visualize()

现在我们已经计算了每个模型的 F1 分数,我们可以创建一个比较表

[16]:
f_score = list(f1_setfit_zero_shot.data.keys())
f1_setfit_zero_shot_values = list(f1_setfit_zero_shot.data.values())
f1_zero_shot_values = list(f1_zero_shot.data.values())

unnecessary_labels = ["Sports_recall", "World_recall", ""]

df_results = pd.DataFrame({"f_score": f_score, "zero-shot classification": f1_zero_shot_values, "zero-shot SetFit classification": f1_setfit_zero_shot_values})
[17]:
df_results
[17]:
f_score 零样本分类 零样本 SetFit 分类
0 precision_macro 0.517754 0.663322
1 recall_macro 0.529605 0.668816
2 f1_macro 0.514483 0.663725
3 precision_micro 0.529605 0.668816
4 recall_micro 0.529605 0.668816
5 f1_micro 0.529605 0.668816
6 Sci/Tech_precision 0.476950 0.556291
7 Sci/Tech_recall 0.283158 0.530526
8 Sci/Tech_f1 0.355350 0.543103
9 Sci/Tech_support 11400.000000 7600.000000
10 World_precision 0.367909 0.663734
11 World_recall 0.358421 0.555789
12 World_f1 0.363103 0.604984
13 World_support 11400.000000 7600.000000
14 Business_precision 0.449227 0.620098
15 Business_recall 0.565789 0.665789
16 Business_f1 0.500815 0.642132
17 Business_support 11400.000000 7600.000000
18 Sports_precision 0.776930 0.813166
19 Sports_recall 0.911053 0.923158
20 Sports_f1 0.838663 0.864678
21 Sports_support 11400.000000 7600.000000

结果解释: 毫无疑问,使用 SetFit 模型的零样本分类更有效。每个类别的 F1 分数都更好。

对于两个分类器,最佳预测类别都是 Sports