Open In Colab  View Notebook on GitHub

🧐 使用 cleanlab 查找标签错误#

在本教程中,我们将利用 Argillacleanlab 来查找、发现和纠正潜在的标签错误。您可以通过遵循 4 个基本的 MLOps 步骤来完成此操作

  • 💾 加载包含潜在标签错误的数据集,这里我们使用 ag_news 数据集;

  • 💻 训练模型以对测试集进行预测,这里我们使用轻量级的 sklearn 模型;

  • 🧐 通过 Argilla 使用 cleanlab,并在测试集中获取潜在的标签错误候选对象;

  • 🖍 使用 Argilla Web 应用程序快速舒适地发现和纠正标签错误;

monitoring-textclassification-cleanlab-explainability

简介#

正如 Curtis G. Northcutt et al. 最近表明的那样,即使在最常被引用的、用于衡量机器学习领域进展的测试集中,标签错误也普遍存在。他们引入了一个新的原则性框架,称为置信学习,以“识别标签错误、表征标签噪声以及使用噪声标签进行学习”。它以 cleanlab Python 包 的形式开源,该软件包支持在数据集中查找、量化和学习标签错误。

Argillacleanlab 提供了内置支持,并使在数据集中查找潜在的标签错误变得轻而易举。在本教程中,我们将尝试发现和纠正著名的 ag_news 数据集中的标签错误,该数据集通常用于衡量 NLP 中分类模型的性能。

运行 Argilla#

对于本教程,您需要运行 Argilla 服务器。部署和运行 Argilla 有两个主要选项

在 Hugging Face Spaces 上部署 Argilla:如果您想使用外部 Notebook(例如 Google Colab)运行教程,并且您在 Hugging Face 上有一个帐户,则只需点击几下即可在 Spaces 上部署 Argilla

deploy on spaces

有关配置部署的详细信息,请查看 官方 Hugging Face Hub 指南

使用 Argilla 的快速入门 Docker 镜像启动 Argilla:如果您想在 本地计算机上运行 Argilla,建议使用此选项。请注意,此选项仅允许您在本地运行教程,而不能与外部 Notebook 服务一起运行。

有关部署选项的更多信息,请查看文档的部署部分。

提示

本教程是一个 Jupyter Notebook。有两种运行它的选项

  • 使用此页面顶部的“在 Colab 中打开”按钮。此选项允许您直接在 Google Colab 上运行 Notebook。不要忘记将运行时类型更改为 GPU,以加快模型训练和推理速度。

  • 单击页面顶部的“查看源代码”链接下载 .ipynb 文件。此选项允许您下载 Notebook 并在本地计算机或您选择的 Jupyter Notebook 工具上运行它。

[ ]:
%pip install argilla datasets scikit-learn cleanlab -qqq

让我们导入 Argilla 模块以进行数据读取和写入

[ ]:
import argilla as rg

如果您使用 Docker 快速入门镜像或 Hugging Face Spaces 运行 Argilla,则需要使用 URLAPI_KEY 初始化 Argilla 客户端

[ ]:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
# Replace workspace with the name of your workspace
rg.init(
    api_url="https://#:6900",
    api_key="owner.apikey",
    workspace="admin"
)

如果您正在运行私有的 Hugging Face Space,您还需要按如下方式设置 HF_TOKEN

[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"

# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# # Replace workspace with the name of your workspace
# rg.init(
#     api_url="https://[your-owner-name]-[your_space_name].hf.space",
#     api_key="owner.apikey",
#     workspace="admin",
#     extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )

最后,让我们包含所需的导入

[ ]:
from datasets import load_dataset

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

from argilla.labeling.text_classification import find_label_errors

启用遥测#

我们从您与我们教程的互动中获得宝贵的见解。为了改进我们自己,为您提供最合适的内容,使用以下代码行将帮助我们了解本教程是否有效地为您服务。虽然这是完全匿名的,但如果您愿意,可以选择跳过此步骤。有关更多信息,请查看 遥测 页面。

[ ]:
try:
    from argilla.utils.telemetry import tutorial_running
    tutorial_running()
except ImportError:
    print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")

注意

如果您想跳过本教程的前三个部分,并且只想发现和纠正 Argilla Web 应用程序中的标签错误,您可以直接从 Hugging Face Hub 加载记录

records_with_label_errors = rg.read_datasets(
    load_dataset("argilla/cleanlab-label_errors", split="train"),
    task="TextClassification",
)

1. 加载数据集#

我们首先通过非常方便的 datasets 库下载 ag_news 数据集。然后,我们提取训练集和测试集,以及此分类任务的标签。我们还对训练集进行洗牌,因为默认情况下它是按分类标签排序的。

[ ]:
# Download the data
dataset = load_dataset("ag_news")

# Get the train set and shuffle
ds_train = dataset["train"].shuffle(seed=43)

# Get the test set
ds_test = dataset["test"]

# Get the classification labels
labels = ds_train.features["label"].names

2. 训练模型#

在本教程中,我们将使用多项式朴素贝叶斯分类器,这是一个轻量级且易于训练的 sklearn 模型。但是,您可以使用任何您选择的模型,只要它在其预测中包含所有标签的概率即可。

我们的分类器的特征将只是输入文本的标记计数。

定义分类器后,我们可以使用训练集对其进行拟合。由于我们使用的是相当轻量级的模型,因此这不应花费太长时间。

[ ]:
# Define our classifier as a pipeline of token counts + naive bayes model
classifier = Pipeline([("vect", CountVectorizer()), ("clf", MultinomialNB())])

# Fit the classifier
classifier.fit(X=ds_train["text"], y=ds_train["label"])

让我们检查一下我们的模型在测试集上的表现如何。

[ ]:
# Compute the test accuracy
classifier.score(
    X=ds_test["text"],
    y=ds_test["label"],
)

我们应该获得 0.90 的良好准确率,特别是因为我们仅使用标记计数作为输入特征。

3. 获取标签错误候选对象#

作为在测试集中获取标签错误候选对象的第一步,我们必须预测所有标签的概率。

[ ]:
# Get predicted probabilities for all labels
probabilities = classifier.predict_proba(ds_test["text"])

有了预测,我们创建包含文本输入、模型预测、潜在的错误注释以及您选择的一些元数据的 Argilla 记录。

[ ]:
# Create records for the test set
records = [
    rg.TextClassificationRecord(
        text=data["text"],
        prediction=list(zip(labels, prediction)),
        annotation=labels[data["label"]],
        metadata={"split": "test"},
    )
    for data, prediction in zip(ds_test, probabilities)
]

我们可以将这些记录直接记录到 Argilla 并方便地用肉眼检查它们,检查每个文本输入的注释。但是,在这里我们将使用更快的方法,即利用 Argilla 对 cleanlab 的内置支持。您只需从 Argilla 导入 find_label_errors 函数,然后传入记录列表即可。就是这样。

[ ]:
# Get records with potential label errors
records_with_label_error = find_label_errors(records)

records_with_label_error 列表包含大约 600 个潜在标签错误的候选对象,这超过了我们测试数据的 8%。

4. 发现和纠正标签错误#

现在让我们将这些记录记录到 Argilla Web 应用程序中,以便方便地用肉眼检查它们,并同时快速纠正潜在的标签错误。

[ ]:
# Uncover label errors in the Argilla web app
rg.log(records_with_label_error, "label_errors")

默认情况下,records_with_label_error 列表中的记录按其包含标签错误的 likelihood 排序。默认情况下,它们还将包含一个名为 “label_error_candidate” 的元数据,该元数据反映了列表中的顺序。您可以使用 Argilla Web 应用程序中的此字段对记录进行排序。

我们可以确认,最有可能的候选对象确实是明显的标签错误。在候选列表的末尾,示例变得更加模棱两可,并且黄金注释是否错误并不立即显而易见。

总结#

借助 Argilla,您可以快速方便地在数据中查找标签错误。对 cleanlab 的内置支持,以及 Argilla Web 应用程序优化的用户体验,使该过程变得轻而易举,并允许您高效地动态纠正标签错误。

只需几个步骤,您就可以快速检查您的测试数据集是否受到标签错误的严重影响,以及您的基准在实践中是否真正有意义。也许您不太复杂的模型最终会击败您资源匮乏的超级模型,并且部署过程也变得更加容易 😀。

尽管在本教程中我们仅使用了 sklearn 模型,但 Argilla 并不关心模型架构或您正在使用的框架。它只关心底层数据,并允许您将更多人纳入您的人工智能生命周期循环中。

附录 I:使用交叉验证在训练数据中查找标签错误#

为了检查训练数据中的标签错误,您可以回退到交叉验证技术以获得样本外预测。使用 sklearn 中的分类器,交叉验证非常容易,您只需一行代码即可方便地完成它。之后,创建 Argilla 记录、查找标签错误候选对象和发现它们的步骤与上面教程中显示的步骤相同。

[ ]:
from sklearn.model_selection import cross_val_predict

# Get predicted probabilities for the whole dataset via cross-validation
cv_probs = cross_val_predict(
    classifier,
    X=ds_train["text"] + ds_test["text"],
    y=ds_train["label"] + ds_test["label"],
    cv=int(len(ds_train) / len(ds_test)),
    method="predict_proba",
    n_jobs=-1,
)

[ ]:
# Create records for the training set
records = [
    rg.TextClassificationRecord(
        text=data["text"],
        prediction=list(zip(labels, prediction)),
        annotation=labels[data["label"]],
        metadata={"split": "train"},
    )
    for data, prediction in zip(ds_train, cv_probs)
]

# Uncover label errors for the train set in the Argilla web app
rg.log(find_label_errors(records), "label_errors_in_train")

在这里,我们找到大约 9400 条记录,其中包含潜在的标签错误,这也大约占训练数据的 8%。

附录 II:将数据集记录到 Hugging Face Hub#

在这里,我们将向您展示一个示例,说明如何将 Argilla 数据集(记录)推送到 Hugging Face Hub。通过这种方式,您可以有效地版本化您的任何 Argilla 数据集。

[ ]:
records = rg.load("label_errors")
records.to_datasets().push_to_hub("<name of the dataset on the HF Hub>")