🧼 使用模型的损失清理标签#
在本教程中,我们将学习介绍一种简单的错误分析技术,使用模型损失来查找潜在的训练数据错误。
🤗 此技术使用 来自 Hugging Face Hub 的微调文本分类器在 AG News 数据集上展示。
✅ 使用 Argilla,我们将验证在这个著名的 NLP 基准测试的训练集中超过 50 个错误标记的示例。
💥 这个技巧对于使用小型和嘈杂的数据集进行模型训练非常有用。
👥 这个技巧与其他“以数据为中心”的 ML 方法(如
cleanlab
)互补(请参阅此 Argilla 教程)。
简介#
本教程解释了一个你可以与 Argilla 一起使用的简单技巧,用于查找训练数据中的潜在错误:使用你的模型损失来识别标签错误或模棱两可的示例。这个技巧并不新鲜(那些使用过 fastai 的人知道 plot_top_losses
方法有多么有用)。甚至 Andrej Karpathy 在一段时间前也 发推文 提到过
当你按损失降序对数据集进行排序时,你肯定会发现一些意想不到、奇怪且有用的东西。
— Andrej Karpathy (@karpathy) 2020 年 10 月 2 日
该技术非常简单:如果你正在使用训练集训练模型,请训练你的模型,并将你的模型应用于训练集,以计算训练集中每个示例的损失。如果你按损失对数据集示例进行排序,损失最高的示例是最模糊和最难学习的。
此技术可用于模型开发期间的错误分析(例如,识别分词问题),但事实证明,它也是一种非常简单的技术,用于在模型开发期间或训练数据收集活动后清理训练数据。
在本教程中,我们将对著名的文本分类基准 AG News 数据集 使用此技术。计算损失后,我们将使用 Argilla 分析损失最高的示例。在不到 5 分钟的时间内,我们手动检查并重新标记了前 50 个示例。事实上,损失最高的前 50 个示例在原始训练集中都是不正确的。如果我们进一步目视检查示例,我们仍然在前 500 个示例中发现标签错误。
为什么这很重要#
机器学习模型的好坏取决于它们训练的数据。几乎所有训练数据源都可以被认为是“嘈杂的”(例如,众包工人、注释员错误、弱监督源、数据增强等)
使用这个简单的技术,我们能够在不到 5 分钟的时间内在广泛使用的基准测试中找到超过 50 个标签错误(你的数据集可能会更嘈杂!)。
随着先进的模型架构被广泛使用,管理、清理和整理数据正成为构建稳健的 ML 应用程序的关键步骤。关于当前情况的良好总结可以在 以数据为中心的 AI NeurIPS 工作坊 网站上找到。
这个简单的技巧可以用于整个 ML 生命周期,而不仅仅是查找标签错误。通过这个技巧,你可以改进数据预处理、分词,甚至你的模型架构。
运行 Argilla#
在本教程中,你将需要运行一个 Argilla 服务器。部署和运行 Argilla 有两个主要选项
在 Hugging Face Spaces 上部署 Argilla:如果你想使用外部笔记本(例如,Google Colab)运行教程,并且你拥有 Hugging Face 帐户,则只需点击几下即可在 Spaces 上部署 Argilla
有关配置部署的详细信息,请查看 官方 Hugging Face Hub 指南。
使用 Argilla 的快速入门 Docker 镜像启动 Argilla:如果你想在 你的本地机器上运行 Argilla,这是推荐的选项。请注意,此选项仅允许你在本地运行教程,而不能与外部笔记本服务一起运行。
有关部署选项的更多信息,请查看文档的部署部分。
提示
本教程是一个 Jupyter Notebook。有两种运行它的选项
使用此页面顶部的“在 Colab 中打开”按钮。此选项允许你直接在 Google Colab 上运行笔记本。不要忘记将运行时类型更改为 GPU 以加快模型训练和推理速度。
通过单击页面顶部的“查看源代码”链接下载 .ipynb 文件。此选项允许你下载笔记本并在本地机器或你选择的 Jupyter Notebook 工具上运行它。
[ ]:
%pip install argilla transformers datasets torch -qqq
让我们导入 Argilla 模块以进行数据读取和写入
[ ]:
import argilla as rg
如果你正在使用 Docker 快速入门镜像或 Hugging Face Spaces 运行 Argilla,你需要使用 URL
和 API_KEY
初始化 Argilla 客户端
[ ]:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
# Replace workspace with the name of your workspace
rg.init(
api_url="https://#:6900",
api_key="owner.apikey",
workspace="admin"
)
如果你正在运行私有的 Hugging Face Space,你还需要按如下方式设置 HF_TOKEN
[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"
# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# # Replace workspace with the name of your workspace
# rg.init(
# api_url="https://[your-owner-name]-[your_space_name].hf.space",
# api_key="owner.apikey",
# workspace="admin",
# extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )
最后,让我们包含我们需要导入的内容
[ ]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset, Dataset, Features, Value, ClassLabel
from transformers.data.data_collator import DataCollatorWithPadding
import pandas as pd
启用遥测#
我们从你与我们的教程互动的方式中获得了宝贵的见解。为了改进我们自己,为您提供最合适的内容,使用以下代码行将帮助我们了解本教程是否有效地为您服务。虽然这是完全匿名的,但如果您愿意,可以选择跳过此步骤。有关更多信息,请查看遥测页面。
[ ]:
try:
from argilla.utils.telemetry import tutorial_running
tutorial_running()
except ImportError:
print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")
预备知识#
使用 AG News 数据集微调的模型(如果你愿意,你可以训练自己的模型)。
AG News 训练集拆分(相同的技巧可以并且应该应用于验证和测试拆分)。
Argilla 用于记录、探索和重新标记错误示例(我们提供预先计算的数据集,因此可以随意跳过此步骤)
1. 加载微调模型和训练数据集#
现在,我们将加载 AG News 数据集。但首先,我们需要定义和设置设备、模型和分词器
[ ]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("andi611/distilbert-base-uncased-ner-agnews")
model = AutoModelForSequenceClassification.from_pretrained(
"andi611/distilbert-base-uncased-ner-agnews"
)
model.to(device)
# Load the training split
ds = load_dataset("ag_news", split="train")
# Tokenize and encode the training set
def tokenize_and_encode(batch):
return tokenizer(batch["text"], truncation=True)
ds_enc = ds.map(tokenize_and_encode, batched=True)
2. 计算损失#
以下代码将使用我们训练的模型计算每个示例的损失。此过程取自 Lewis Tunstall 的博客文章,该文章解释得非常好:“使用数据整理器进行训练和错误分析”,他在其中解释了模型训练期间进行错误分析的此过程。
在我们的例子中,我们直接实例化一个数据整理器,而他直接使用 Trainer 中的数据整理器
[ ]:
# Create the data collator for inference
data_collator = DataCollatorWithPadding(tokenizer, padding=True)
# Function to compute the loss example-wise
def loss_per_example(batch):
batch = data_collator(batch)
input_ids = torch.tensor(batch["input_ids"], device=device)
attention_mask = torch.tensor(batch["attention_mask"], device=device)
labels = torch.tensor(batch["labels"], device=device)
with torch.no_grad():
output = model(input_ids, attention_mask)
batch["predicted_label"] = torch.argmax(output.logits, axis=1)
# compute the probabilities for logging them into Argilla
batch["predicted_probas"] = torch.nn.functional.softmax(output.logits, dim=0)
# Don't reduce the loss (return the loss for each example)
loss = torch.nn.functional.cross_entropy(output.logits, labels, reduction="none")
batch["loss"] = loss
# Datasets complains with numpy dtypes, let's use Python lists
for k, v in batch.items():
batch[k] = v.cpu().numpy().tolist()
return batch
现在,是时候将数据集转换为 Pandas 数据帧,并按损失降序对数据集进行排序了
警告
计算此信息的模型计算量很大,如果你的 GPU 不够强大,可能会很慢。即使你拥有 GPU,为完整数据集计算此信息也可能需要大约 5 分钟。尝试在 load_dataset
期间减少数据集中的记录数,或使用 .select(range(number_of_record))
。
[ ]:
losses_ds = ds_enc.remove_columns("text").map(
loss_per_example, batched=True, batch_size=32
)
# Turn the dataset into a Pandas dataframe, sort by descending loss and visualize the top examples.
pd.set_option("display.max_colwidth", None)
losses_ds.set_format("pandas")
losses_df = losses_ds[:][["label", "predicted_label", "loss", "predicted_probas"]]
# Add the text column removed by the trainer
losses_df["text"] = ds_enc["text"]
losses_df.sort_values("loss", ascending=False).head(10)
label | predicted_label | loss | predicted_probas | text | |
---|---|---|---|---|---|
44984 | 1 | 0 | 8.833023 | [0.06412869691848755, 7.090532017173246e-05, 0.00019675122166518122, 0.0002370826987316832] | 巴格达爆炸事件造成至少 16 人死亡,武装分子在巴格达南部美国军用车辆车队附近引爆了两枚炸弹,伊拉克警方称,至少造成 16 人死亡。 |
101562 | 1 | 0 | 8.781285 | [0.12395327538251877, 9.289286026614718e-06, 0.0001785584754543379, 0.0007945793331600726] | 不道德、不公正、压迫性的独裁统治。。。然后是 #39;s <b>...</b> 罗伯特·穆加贝政府正在推行旨在阻止人权组织在津巴布韦运作的立法。 |
31564 | 1 | 2 | 8.772168 | [0.00016983140085358173, 8.863882612786256e-06, 0.18702593445777893, 0.00025946463574655354] | 福特将在英国捷豹部门裁员 1150 人 福特汽车公司周五宣布,将在英格兰裁员 1150 人,以精简其捷豹汽车有限公司部门,该部门的销售疲软未能抵消新产品和业务其他部分的支出。 |
41247 | 1 | 0 | 8.751480 | [0.2929899990558624, 7.849136454751715e-05, 0.00034211069578304887, 4.463219011086039e-05] | 巴勒斯坦枪手绑架 CNN 制片人 加沙城,加沙地带 - 该网络称,巴勒斯坦枪手周一在加沙城绑架了一名 CNN 制片人。该网络称,利雅得·阿里在枪口下从 CNN 面包车上被带走。 |
44961 | 1 | 0 | 8.740394 | [0.06420651078224182, 7.788064249325544e-05, 0.0001824614155339077, 0.0002348265261389315] | 巴格达炸弹爆炸事件造成至少 35 人死亡,120 人受伤 武装分子周四在巴格达南部美国军用车辆车队附近引爆了三枚汽车炸弹,造成至少 35 人死亡,约 120 人受伤,其中许多是儿童,官员和医生说。 |
75216 | 1 | 0 | 8.735966 | [0.13383473455905914, 1.837693343986757e-05, 0.00017987379396799952, 0.00036031895433552563] | 海军陆战队妻子集会 一群海军陆战队妻子正在为一名在伊拉克阵亡的海军陆战队军官的家人奔走。 |
31229 | 1 | 2 | 8.729340 | [5.088283069198951e-05, 2.4471093638567254e-05, 0.18256260454654694, 0.00033902408904396] | 尽管福特前景乐观,汽车股仍下跌 尽管福特汽车公司发布了强劲的盈利前景,但由于担心该行业的销售额可能不如先前预期的那么强劲,汽车股周五大多走低。 |
19737 | 3 | 1 | 8.545797 | [4.129256194573827e-05, 0.1872873306274414, 4.638762402464636e-05, 0.00010757221753010526] | Mladin 从 Road Atlanta 获释 澳大利亚 #39;s Mat Mladin 在今年美国 AMA 雪佛兰超级摩托车锦标赛的倒数第二轮比赛中完成了双冠王,此前他夺得了 |
60726 | 2 | 0 | 8.437369 | [0.5235446095466614, 4.4463453377829865e-05, 3.5171411582268775e-05, 8.480428368784487e-05] | 绿区自杀式炸弹袭击造成 10 人死亡 武装分子周四将手提炸药带入巴格达防御最严密的区域,并在几秒钟内引爆,造成 10 人死亡,20 人受伤。 |
28307 | 3 | 1 | 8.386065 | [0.00018589739920571446, 0.42903241515159607, 2.5073826691368595e-05, 3.97983385482803e-05] | 德克萨斯州田野遭雷击,40 人受伤 (美联社) 美联社 - 当闪电击中德克萨斯州东部格拉普兰高中橄榄球队的训练场附近时,约 40 名球员和教练受伤,其中两人伤势危重,当局周二晚上说。 |
[2]:
# Save this to a file for further analysis
# losses_df.to_json("agnews_train_loss.json", orient="records", lines=True)
虽然使用 Pandas 和 Jupyter 笔记本对于初始检查和程序化分析很有用。如果你想快速浏览示例、重新标记它们并与其他项目成员共享,Argilla 为你提供了一种直接的方法。让我们看看如何操作。
3. 将高损失示例记录到 Argilla 中#
使用令人惊叹的 Hugging Face Hub,我们共享了结果数据集,你可以在 这里 找到并直接使用 datasets 库加载
现在,我们将前 500 个示例记录到 Argilla 数据集中
[ ]:
# If you have skipped the first two steps you can load the dataset here:
dataset = load_dataset("dvilasuero/ag_news_training_set_losses", split="train")
losses_df = dataset.to_pandas()
ds = load_dataset("ag_news", split="test") # only for getting the label names
[7]:
# Create a Text classification record for logging into Argilla
def make_record(row):
return rg.TextClassificationRecord(
text=row.text,
# This is the "gold" label in the original dataset
annotation=[(ds.features["label"].names[row.label])],
# This is the prediction together with its probability
prediction=[
(
ds.features["label"].names[row.predicted_label],
row.predicted_probas[row.predicted_label],
)
],
# Metadata fields can be used for sorting and filtering, here we log the loss
metadata={"loss": row.loss},
# Who makes the prediction
prediction_agent="andi611/distilbert-base-uncased-ner-agnews",
# Source of the gold label
annotation_agent="ag_news_benchmark",
)
# If you want to log the full dataset remove the indexing
top_losses = losses_df.sort_values("loss", ascending=False)[0:499]
# Build Argilla records
records = top_losses.apply(make_record, axis=1)
rg.log(records, name="ag_news_error_analysis")
4. 使用 Argilla UI 进行检查和重新标记#
在此步骤中,我们有一个 Argilla 数据集可用于探索和注释。对于此用例,一个有用的功能是排序。使用 Argilla,你可以通过组合来自标准字段(例如 score
)和自定义字段(通过元数据字段)的不同字段对示例进行排序。在本例中,我们记录了损失,因此我们可以按损失降序(首先显示损失较高的示例)对训练示例进行排序。
为了准备本教程,我们手动检查并重新标记了前 50 个示例。此外,我们在 Hugging Face Hub 中共享了这个重新注释的数据集。在下一节中,我们将向你展示在 Hub 中共享 Argilla 数据集有多么容易。
5. 在 Hugging Face Hub 中共享数据集#
首先,让我们加载重新注释的示例。重新标记的示例被标记为由用户 argilla
annotated_by
,这是使用 Docker 启动 Argilla 时的默认用户。我们可以使用 query
参数检索仅这些记录,如下所示
[11]:
dataset = rg.load("ag_news_error_analysis", query="annotated_by:argilla").to_pandas()
# Let's do some transformations before uploading the dataset
dataset["loss"] = dataset.metadata.transform(lambda r: r["loss"])
dataset = dataset.rename(columns={"annotation": "corrected_label"})
dataset.head()
[11]:
inputs | prediction | corrected_label | prediction_agent | annotation_agent | multi_label | explanation | id | metadata | status | event_timestamp | metrics | text | loss | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | {'text': 'Top nuclear official briefs Majlis c...' | [(World, 0.1832696944)] | World | andi611/distilbert-base-uncased-ner-agnews | argilla | False | None | 071a1014-71e7-41f4-83e4-553ba47610cf | {'loss': 7.6656146049} | Validated | None | {} | Top nuclear official briefs Majlis committee T... | 7.665615 |
1 | {'text': 'Fischer Delivers Strong Message in S...' | [(World, 0.0695228428)] | World | andi611/distilbert-base-uncased-ner-agnews | argilla | False | None | 07c8c4f6-3288-46f4-a618-3da4a537e605 | {'loss': 7.9892320633} | Validated | None | {} | Fischer Delivers Strong Message in Syria Germa... | 7.989232 |
2 | {'text': 'The Politics of Time and Dispossessi...' | [(Sci/Tech, 0.100481838)] | Sci/Tech | andi611/distilbert-base-uncased-ner-agnews | argilla | False | None | 0965a0d1-4886-432a-826a-58e99dfd9972 | {'loss': 7.133708477} | Validated | None | {} | The Politics of Time and Dispossession Make a ... | 7.133708 |
3 | {'text': 'Hadash Party joins prisoners #39; st...' | [(World, 0.1749624908)] | World | andi611/distilbert-base-uncased-ner-agnews | argilla | False | None | 09fc7065-a2c8-4041-adf8-34e029a7fde0 | {'loss': 7.339015007} | Validated | None | {} | Hadash Party joins prisoners #39; strike for 2... | 7.339015 |
4 | {'text': 'China May Join \$10Bln Sakhalin-2 Ru...' | [(Business, 0.1370282918)] | Business | andi611/distilbert-base-uncased-ner-agnews | argilla | False | None | 1ef97c49-2f0f-43be-9b28-80a291cb3b1d | {'loss': 7.321100235} | Validated | None | {} | China May Join \$10Bln Sakhalin-2 Russia said ... | 7.321100 |
[12]:
# Let's add the original dataset labels to share them together with the corrected ones
# We sort by ascending loss our corrected dataset
dataset = dataset.sort_values("loss", ascending=False)
# we add original labels in string form
id2label = list(dataset.corrected_label.unique())
original_labels = [id2label[i] for i in top_losses[0:50].label.values]
dataset["original_label"] = original_labels
现在让我们将其转换为 Dataset
并定义特征架构
[13]:
ds = dataset[["text", "corrected_label", "original_label"]].to_dict(orient="list")
hf_ds = Dataset.from_dict(
ds,
features=Features(
{
"text": Value("string"),
"corrected_label": ClassLabel(names=list(dataset.corrected_label.unique())),
"original_label": ClassLabel(names=list(dataset.corrected_label.unique())),
}
),
)
使用 push_to_hub
方法上传数据集就像
[ ]:
hf_ds.push_to_hub("argilla/ag_news_corrected_labels")
现在数据集在 Hub 上公开 可用!
总结#
在本教程中,我们学习了使用模型损失来查找训练数据集中的标签错误。Argilla 的 UI 可以轻松地按损失对数据进行排序,快速浏览数据集以更正标签错误。