Open In Colab  View Notebook on GitHub

🧱 使用 Sentence Transformers 增强弱监督规则#

在本教程中,我们将展示如何在 Argilla 中使用句子嵌入扩展弱监督工作流。我们从 Argilla 弱监督教程 中介绍的弱监督工作流开始,并通过扩展其规则的覆盖范围来改进其结果。

  • ✍️ 我们定义规则并为 ag_news 数据集生成弱标签。

  • 🧱 我们使用来自 Sentence Transformers 库的句子嵌入来扩展我们的弱标签。

  • 📰 最后,我们使用标签模型生成数据,用于训练下游模型作为新闻分类器。

  • 🚀 通过简单地扩展我们的弱标签,我们在准确率上实现了 4% 的提升

Original and extended coverage of the weak labels

上面的两个图表显示了在使用嵌入扩展弱标签之前和之后的覆盖率。每个点对应于 ag news 测试集中的一个示例。颜色表示示例的相应类别。透明圆圈中的点至少被一个规则覆盖。

简介#

标注函数通常具有高精度,但覆盖率较低。只有严格匹配给定函数所确定条件的记录才会被标注,而其他潜在的候选记录将被排除在外。

基于 Hazy Research 小组的发现,我们提出了一种通过使用句子嵌入扩展标注函数产生的弱标签来解决此问题的方法。

我们通过给未标注的记录赋予与其在嵌入空间中最接近的已标注邻居相同的标签来扩展标注函数的覆盖率,前提是它们之间的余弦相似度得分高于某个阈值。

在本教程中,我们将展示,通过调整这些相似度阈值并选择合适的句子嵌入,我们能够显着提高弱监督工作流产生的下游分类器的准确率。

运行 Argilla#

对于本教程,您需要运行 Argilla 服务器。部署和运行 Argilla 有两个主要选项

在 Hugging Face Spaces 上部署 Argilla:如果您想使用外部 notebook(例如,Google Colab)运行教程,并且您在 Hugging Face 上有一个帐户,您只需点击几下即可在 Spaces 上部署 Argilla

deploy on spaces

有关配置部署的详细信息,请查看 Hugging Face Hub 官方指南

使用 Argilla 的快速入门 Docker 镜像启动 Argilla:如果您想在 本地机器上运行 Argilla,这是推荐选项。请注意,此选项仅允许您在本地运行本教程,而不能使用外部 notebook 服务。

有关部署选项的更多信息,请查看文档的部署部分。

提示

本教程是一个 Jupyter Notebook。有两种选项可以运行它

  • 使用此页面顶部的“在 Colab 中打开”按钮。此选项允许您直接在 Google Colab 上运行 notebook。不要忘记将运行时类型更改为 GPU 以加快模型训练和推理速度。

  • 通过单击页面顶部的“查看源代码”链接下载 .ipynb 文件。此选项允许您下载 notebook 并在本地机器或您选择的 Jupyter notebook 工具上运行它。

设置#

对于本教程,您需要使用 pip 安装 Argilla 客户端和一些第三方库

[ ]:
%pip install argilla faiss-cpu sentence_transformers transformers datasets snorkel -qqq

让我们导入 Argilla 模块以进行数据读取和写入

[4]:
import argilla as rg

如果您使用 Docker 快速入门镜像或公共 Hugging Face Spaces 运行 Argilla,则需要使用 URLAPI_KEY 初始化 Argilla 客户端

[5]:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
rg.init(
    api_url="https://#:6900",
    api_key="admin.apikey"
)

如果您运行的是私有 Hugging Face Space,您还需要按如下方式设置 HF_TOKEN

[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"

# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# rg.init(
#     api_url="https://[your-owner-name]-[your_space_name].hf.space",
#     api_key="admin.apikey",
#     extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )

现在让我们添加所需的导入

[16]:
from datasets import load_dataset
from argilla.labeling.text_classification import Rule, add_rules, WeakLabels, Snorkel
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

启用遥测#

我们从您与我们的教程的互动中获得宝贵的见解。为了改进我们自己,为您提供最合适的内容,使用以下代码行将帮助我们了解本教程是否有效地为您服务。虽然这是完全匿名的,但如果您愿意,可以选择跳过此步骤。有关更多信息,请查看 遥测 页面。

[ ]:
try:
    from argilla.utils.telemetry import tutorial_running
    tutorial_running()
except ImportError:
    print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")

详细工作流#

使用句子嵌入执行弱监督的典型工作流程是

  1. 使用您的原始数据集创建 Argilla 数据集。如果您有一些已标注的数据,您可以将其记录到同一数据集中。

  2. 使用 UI 中的“规则定义”模式定义一组弱标注规则。

  3. 创建一个 WeakLabels 对象并应用规则。您可以从数据集中加载规则,并使用 Python 添加其他规则和标注函数。通常,您会在步骤 2 和此步骤之间迭代。

  4. 通过为每个记录(矩阵的行)提供句子嵌入和为每个规则(矩阵的列)提供相似度阈值来扩展 WeakLabels 对象。

  5. 对扩展的弱标签感到满意后,将 WeakLabels 实例的扩展矩阵与您选择的库/方法一起使用,以构建训练集,甚至训练下游文本分类模型。您可以在此步骤和步骤 4 之间迭代,尝试多个阈值和嵌入可能性,直到获得令人满意的结果。

本指南向您展示了一个使用 Snorkel 的端到端示例。您也可以选择使用 Argilla 中提供的任何其他标签模型。如果您有兴趣了解其他选项,请查看我们的 弱监督指南

数据集#

我们将使用 ag_news 数据集,这是一个著名的基准文本分类模型。

但是,为了保证公平的比较,我们将在验证集上优化阈值,并将测试集留作最终评估。

[17]:
agnews = load_dataset("ag_news")

agnews_train, agnews_valid = (
    agnews["train"].train_test_split(test_size=4000, seed=43).values()
)
WARNING:datasets.builder:Found cached dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)

1. 创建包含未标注数据和测试数据的 Argilla 数据集#

让我们将已标注和未标注的记录集加载到 Argilla 中。

[18]:
# build our labelled records to evaluate our heuristic rules and optimize the thresholds
records = [
    rg.TextClassificationRecord(
        text=record["text"],
        metadata={"split": "labelled"},
        annotation=agnews_valid.features["label"].int2str(record["label"]),
        id=f"valid_{idx}",
    )
    for idx, record in enumerate(agnews_valid)
]

# build our unlabelled records
records += [
    rg.TextClassificationRecord(
        text=record["text"],
        metadata={"split": "unlabelled"},
        id=f"train_{idx}",
    )
    for idx, record in enumerate(agnews_train.select(range(8000)))
]

# log the records to Argilla
rg.log(records, name="agnews")
12000 records logged to https://dvilasuero-argilla-space-52456aa.hf.space/datasets/team/agnews
[18]:
BulkResponse(dataset='agnews', processed=12000, failed=0)

完成此步骤后,您将拥有一个完全可浏览的数据集,可以通过 Argilla Web 应用程序 访问它。

2. 定义规则#

我们将使用以下规则。

[19]:
# define queries and patterns for each category (using ES DSL)
queries = [
    (["money", "financ*", "dollar*"], "Business"),
    (["war", "gov*", "minister*", "conflict"], "World"),
    (["footbal*", "sport*", "game", "play*"], "Sports"),
    (["sci*", "techno*", "computer*", "software", "web"], "Sci/Tech"),
]

# define rules
rules = [Rule(query=term, label=label) for terms, label in queries for term in terms]

现在我们可以按如下方式将它们添加到数据集中

[20]:
add_rules(dataset="agnews", rules=rules)

3. 构建和分析弱标签#

从我们的规则构建弱标签后,它们的摘要显示如下

[21]:
# apply the rules to the dataset to obtain the weak labels
weak_labels = WeakLabels(dataset="agnews")
weak_labels.summary()
[21]:
label coverage annotated_coverage overlaps conflicts correct incorrect precision
money {Business} 0.008000 0.00925 0.002750 0.002083 13 24 0.351351
financ* {Business} 0.020667 0.02100 0.005417 0.004667 56 28 0.666667
dollar* {Business} 0.016250 0.01550 0.003833 0.002750 42 20 0.677419
war {World} 0.013750 0.01175 0.003000 0.001333 34 13 0.723404
gov* {World} 0.045167 0.04000 0.011083 0.006000 76 84 0.475000
minister* {World} 0.028917 0.03175 0.007167 0.002583 114 13 0.897638
conflict {World} 0.003167 0.00300 0.001333 0.000250 10 2 0.833333
footbal* {Sports} 0.014333 0.01475 0.005583 0.000333 53 6 0.898305
sport* {Sports} 0.020750 0.02375 0.006250 0.001333 87 8 0.915789
game {Sports} 0.039917 0.04150 0.013417 0.001917 132 34 0.795181
play* {Sports} 0.055000 0.05875 0.016667 0.004500 168 67 0.714894
sci* {Sci/Tech} 0.015833 0.01700 0.002583 0.001250 55 13 0.808824
techno* {Sci/Tech} 0.028250 0.02900 0.008500 0.002667 82 34 0.706897
computer* {Sci/Tech} 0.027917 0.02925 0.011583 0.004167 97 20 0.829060
software {Sci/Tech} 0.031000 0.03225 0.009667 0.002500 104 25 0.806202
web {Sci/Tech} 0.018250 0.01975 0.004417 0.001500 70 9 0.886076
total {Business, World, Sports, Sci/Tech} 0.327583 0.33450 0.053667 0.017917 1193 400 0.748901

在接下来的步骤中,我们将尝试通过句子嵌入扩展我们的弱标签矩阵。通过这种方式,我们将增加规则的覆盖率,同时保持可接受的精度。

4. 使用弱标签#

使用 Snorkel 的标签模型#

Snorkel 的标签模型是迄今为止使用弱监督的最流行选项,Argilla 为其提供了内置支持。在这里,我们将我们的弱标签拟合到 Snorkel 标签模型,然后我们检查规则覆盖的记录的性能。

[22]:
# create the Snorkel label model
label_model = Snorkel(weak_labels)

# fit the model, for the learning rate and epochs we ran a quick grid search
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

# evaluate the label model
print(label_model.score(output_str=True))
              precision    recall  f1-score   support

      Sports       0.79      0.95      0.86       380
    Sci/Tech       0.80      0.76      0.78       454
       World       0.70      0.82      0.75       257
    Business       0.69      0.40      0.50       247

    accuracy                           0.76      1338
   macro avg       0.74      0.73      0.72      1338
weighted avg       0.76      0.76      0.75      1338

5. 扩展弱标签#

让我们扩展我们的弱标签,看看这如何影响 Snorkel 标签模型的评估。

生成句子嵌入#

让我们为弱标签矩阵的每个记录生成句子嵌入。通过强大的通用预训练嵌入,或通过专门为手头任务的领域预训练的嵌入,将获得最佳结果。

在这里,我们选择来自著名的 Sentence Transformers 库all-MiniLM-L6-v2 嵌入。Argilla 允许我们尝试来自任何来源的嵌入,只要它们以二维数组的形式提供给弱标签矩阵。

例如,除了 Sentence Transformers 之外,我们还可以使用 OpenAI 嵌入,或者来自 Tensorflow Hub 的文本嵌入。

[23]:
# instantiate the model for the sentence embeddings
# we strongly recommend using a GPU for the computation of the embeddings
model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")

# compute the embeddings and store them in a list
embeddings = []
for rec in tqdm(weak_labels.records()):
    embeddings.append(model.encode(rec.text))

设置阈值#

我们首先对哪些阈值适用于此特定弱标签矩阵做出有根据的猜测。我们将所有规则的阈值设置为 0.60。这意味着,对于每个规则,如果记录的余弦相似度高于此值,则该记录的标签将扩展到其最近的未标注邻居。

[24]:
thresholds = [0.6] * len(rules)

扩展弱标签矩阵#

我们通过提供阈值和句子嵌入来调用 extend_matrix 方法。

[25]:
weak_labels.extend_matrix(thresholds, embeddings)

随着弱标签矩阵的扩展,我们可以看到覆盖率上升了。

[26]:
weak_labels.summary()
[26]:
label coverage annotated_coverage overlaps conflicts correct incorrect precision
money {Business} 0.017667 0.02025 0.009750 0.008083 43 38 0.530864
financ* {Business} 0.037500 0.03800 0.016417 0.013917 99 53 0.651316
dollar* {Business} 0.039667 0.04050 0.020583 0.017750 118 44 0.728395
war {World} 0.031833 0.03125 0.015083 0.008250 81 44 0.648000
gov* {World} 0.096083 0.08750 0.042000 0.024417 188 162 0.537143
minister* {World} 0.053750 0.05350 0.023000 0.008083 197 17 0.920561
conflict {World} 0.010583 0.00925 0.007833 0.003917 24 13 0.648649
footbal* {Sports} 0.018333 0.01925 0.007833 0.000333 71 6 0.922078
sport* {Sports} 0.036667 0.03900 0.014417 0.004000 142 14 0.910256
game {Sports} 0.062417 0.06525 0.026750 0.004917 211 50 0.808429
play* {Sports} 0.082417 0.08650 0.033833 0.011583 248 98 0.716763
sci* {Sci/Tech} 0.023667 0.02500 0.003750 0.001917 80 20 0.800000
techno* {Sci/Tech} 0.059667 0.05850 0.029833 0.019167 130 104 0.555556
computer* {Sci/Tech} 0.052000 0.05250 0.029583 0.013750 165 45 0.785714
software {Sci/Tech} 0.051917 0.05125 0.025667 0.010417 162 43 0.790244
web {Sci/Tech} 0.036417 0.03650 0.015667 0.007167 121 25 0.828767
total {Business, World, Sports, Sci/Tech} 0.523500 0.52525 0.134917 0.057917 2080 776 0.728291

我们还看到,我们规则的平均精度下降了(从 0.75 降至 0.66)。然而,这种下降可以通过我们的标签模型部分补偿。如果我们再次将弱标签拟合到 Snorkel 标签模型,我们可以看到,支持度显着提高,正如预期的那样,而准确率的下降幅度很小。

[27]:
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(label_model.score(output_str=True))
              precision    recall  f1-score   support

      Sports       0.81      0.95      0.87       550
    Sci/Tech       0.77      0.78      0.77       655
       World       0.70      0.84      0.76       468
    Business       0.72      0.38      0.50       428

    accuracy                           0.76      2101
   macro avg       0.75      0.74      0.73      2101
weighted avg       0.75      0.76      0.74      2101

您可以查看 附录,以详细了解弱标签矩阵如何在后台扩展。

我们建议以某种方式优化阈值以获得最高的性能提升,而不是使用通用的固定阈值。我们在 附录 中详细描述的优化产生了以下阈值

[28]:
optimized_thresholds = [
    0.4,
    0.4,
    0.6,
    0.4,
    0.5,
    0.8,
    1.0,
    0.4,
    0.4,
    0.5,
    0.6,
    0.4,
    0.4,
    0.6,
    0.6,
    0.8,
]

每次使用阈值和嵌入调用 extend_matrix 都会构建一个 faiss 索引,该索引将被缓存在弱标签对象中。

如果我们在下次调用 extend_matrix 时不提供嵌入,则将重新利用此索引,并且新的扩展矩阵将替换当前的扩展矩阵。因此,使用新阈值扩展矩阵非常便宜。

[29]:
weak_labels.extend_matrix(optimized_thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(label_model.score(output_str=True))
              precision    recall  f1-score   support

      Sports       0.87      0.90      0.88       883
    Sci/Tech       0.69      0.72      0.70       880
       World       0.78      0.74      0.76       751
    Business       0.64      0.62      0.63       826

    accuracy                           0.75      3340
   macro avg       0.74      0.74      0.74      3340
weighted avg       0.75      0.75      0.75      3340

优化的阈值似乎进一步降低了标签模型的准确率,但也显着提高了覆盖率。

6. 训练下游模型#

现在,我们将训练与 之前的教程 中相同的下游模型,但基于由我们扩展的弱标签的标签模型生成的数据。

首先,让我们定义一个辅助函数,它基本上是先前教程中的复制粘贴。

[30]:
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn import metrics


def train_and_evaluate_downstream_model(label_model):
    """
    Train a downstream model with the predictions of a label model and
    evaluate it with the test split of the ag news dataset
    """
    # get records with the predictions from the label model
    records = label_model.predict()

    # turn str labels into integers
    label2int = label_model.weak_labels.label2int

    # extract training data
    X_train = [rec.text for rec in records]
    y_train = [label2int[rec.prediction[0][0]] for rec in records]

    # define our final classifier
    classifier = Pipeline([("vect", CountVectorizer()), ("clf", MultinomialNB())])

    # fit the classifier
    classifier.fit(
        X=X_train,
        y=y_train,
    )

    # extract text and labels
    X_test = [rec["text"] for rec in agnews["test"]]
    y_test = [
        label2int[agnews["test"].features["label"].int2str(rec["label"])]
        for rec in agnews["test"]
    ]

    # get predictions for the test set
    predicted = classifier.predict(X_test)

    return metrics.classification_report(
        y_test, predicted, target_names=[k for k in label2int.keys() if k]
    )

现在让我们看看我们的下游模型与 之前的教程 中的原始模型相比如何。请记住,我们实现了约 82% 的准确率。

[31]:
print(train_and_evaluate_downstream_model(label_model))

              precision    recall  f1-score   support

      Sports       0.88      0.96      0.92      1900
    Sci/Tech       0.77      0.82      0.79      1900
       World       0.86      0.84      0.85      1900
    Business       0.82      0.71      0.76      1900

    accuracy                           0.83      7600
   macro avg       0.83      0.83      0.83      7600
weighted avg       0.83      0.83      0.83      7600

现在,使用我们扩展的弱标签矩阵,我们能够实现 86% 的准确率,比我们最初的方法提高了 4%。

总结#

在本教程中,您已经了解了如何使用词嵌入改进 Argilla 中的弱监督工作流。通过对原始工作流进行非常小的更改,我们能够显着提高下游模型的准确率。这表明 Argilla 可以大大减少人工标注者在编写规则之前需要付出的努力,然后才能取得卓越的成果。

附录:可视化更改#

让我们可视化弱标签矩阵如何在单行中扩展。

[32]:
import pandas as pd


def get_transitions(weak_labels, idx):
    transitions = list(
        list(zip(row[0], row[1]))
        for row in zip(weak_labels._matrix, weak_labels._extended_matrix)
    )
    transitions = transitions[idx]
    label_dict = weak_labels.int2label
    rule_labels = weak_labels.summary().reset_index()["index"].values.tolist()[:-1]
    transitions_df = []
    for rule_idx, rule in enumerate(rule_labels):
        old_label = transitions[rule_idx][0]
        new_label = transitions[rule_idx][1]
        transitions_df.append(
            {
                "rule": rule,
                "old label": label_dict[old_label],
                "new label": label_dict[new_label],
            }
        )
    transitions_df = pd.DataFrame(transitions_df)
    text = weak_labels.records()[idx].text
    return transitions_df, text


transitions, text = get_transitions(weak_labels, 15)

通过阅读选定的记录,我们可以清楚地注意到这是一篇关于世界政治的新闻文章,因此应归类为 World

[33]:
text
[33]:
'Nicaragua tells US it will destroy its antiaircraft missiles  MANAGUA, Nicaragua -- President Enrique Bolanos told US Defense Secretary Donald H. Rumsfeld yesterday that Nicaragua would completely eliminate a stockpile of hundreds of surface-to-air missiles with no expectation of compensation from the United States.'

让我们将此记录的原始弱标签矩阵的行("old label" 行)和扩展后的同一行("new label" 行)并排放置。

我们看到,这篇新闻文章在原始矩阵中未被我们的任何规则标注。

但是,它是两个 Business 文章的最近未标注邻居,与规则 financ*dollar* 匹配,并且其与它们的相似度得分高于我们选择的阈值。对于两个 World 文章,与规则 warminister* 匹配,以及对于一个与规则 sci* 匹配的 Sci/Tech 文章,也发生了同样的情况。

[34]:
transitions.transpose()

[34]:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
rule money financ* dollar* war gov* minister* conflict footbal* sport* game play* sci* techno* computer* software web
old label None None None None None None None None None None None None None None None None
new label None None None None None None None None None None None None None None None None

附录:优化阈值#

每次使用阈值和嵌入调用 extend_matrix 都会构建一个 faiss 索引,该索引将被缓存在弱标签对象中。

如果我们在下次调用 extend_matrix 时不提供嵌入,则将重新利用此索引,并且新的扩展矩阵将替换当前的扩展矩阵。这个新矩阵是根据我们的新相似度阈值对原始弱标签矩阵进行的扩展。

[35]:
# Let's try to set all thresholds to 0.8 instead of 0.6.
thresholds = [0.8] * len(rules)

# As we have already generated the index in our first call, we just need to provide the thresholds.
weak_labels.extend_matrix(thresholds)

有几种不同的方法可以找到扩展弱标签矩阵的最佳相似度阈值:我们将从计算成本最低到最高的顺序进行列出。

1. 阻止低重叠规则的扩展#

在将所有相似度阈值设置为合理值后,优化单个级别的相似度阈值的一个好方法是阻止低重叠规则的扩展,因为它们在扩展后更可能产生不准确的结果。

[36]:
summary = weak_labels.summary(normalize_by_coverage=True).reset_index().head(len(rules))
summary = summary.rename(columns={"index": "rule"})
summary = summary.sort_values(by="overlaps", ascending=True)[["rule", "overlaps"]]
summary = summary.reset_index()
summary

[36]:
index rule overlaps
0 11 sci* 0.158974
1 3 war 0.208092
2 2 dollar* 0.235577
3 4 gov* 0.236427
4 15 web 0.240664
5 5 minister* 0.244382
6 1 financ* 0.257692
7 8 sport* 0.287823
8 12 techno* 0.295580
9 10 play* 0.299259
10 14 software 0.303030
11 9 game 0.334694
12 0 money 0.340000
13 7 footbal* 0.380682
14 13 computer* 0.401662
15 6 conflict 0.404762
[37]:
thresholds = [0.6] * len(rules)

# Let's block the extension of the top 5 rules with the least overlap.
turn_off_index = summary["index"][0:6]

# We block the extension of a rule by setting its similarity threshold to 1.0.
for rule_index in turn_off_index:
    thresholds[rule_index] = 1.0

weak_labels.extend_matrix(thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(train_and_evaluate_downstream_model(label_model))

              precision    recall  f1-score   support

      Sports       0.79      0.99      0.88      1900
    Sci/Tech       0.60      0.83      0.70      1900
       World       0.81      0.84      0.82      1900
    Business       0.91      0.31      0.46      1900

    accuracy                           0.74      7600
   macro avg       0.78      0.74      0.72      7600
weighted avg       0.78      0.74      0.72      7600

2. 暴力破解:在标签模型上进行网格搜索#

在这种方法中,我们将所有阈值设置为初始值,然后网格搜索每个阈值的最佳值。然后,我们优化开发集中标签模型的覆盖率和准确率之间的调和平均值。这将确保我们选择在两个指标之间具有最佳权衡的阈值。

我们获得了与先前方法相同的改进,在测试集上的最终准确率为 86%。

[38]:
def train_eval_labelmodel(ths):
    weak_labels.extend_matrix(ths)

    label_model = Snorkel(weak_labels)
    label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

    metrics = label_model.score()
    acc, sup, n = (
        metrics["accuracy"],
        metrics["macro avg"]["support"],
        len(weak_labels.annotation()),
    )
    coverage = sup / n
    return 2 * acc * coverage / (acc + coverage)

[39]:
import copy
from tqdm.auto import tqdm
import numpy as np

ths_range = np.arange(1, 0.3, -0.1)
n_ths = len(weak_labels.rules)

best_thresholds = [1.0] * n_ths
best_acc = 0.0
for i in tqdm(range(n_ths), total=n_ths):
    thresholds = best_thresholds.copy()
    for threshold in ths_range:
        thresholds[i] = threshold
        acc = train_eval_labelmodel(thresholds)
        if acc > best_acc:
            best_acc = acc
            best_thresholds = thresholds.copy()

[40]:
np.array(best_thresholds)

[40]:
array([0.4, 0.4, 0.4, 0.4, 0.4, 0.5, 1. , 0.4, 0.4, 0.4, 0.5, 0.4, 0.4,
       0.4, 0.5, 0.4])
[41]:
weak_labels.extend_matrix(best_thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(train_and_evaluate_downstream_model(label_model))

              precision    recall  f1-score   support

      Sports       0.89      0.97      0.93      1900
    Sci/Tech       0.62      0.88      0.73      1900
       World       0.79      0.88      0.83      1900
    Business       0.90      0.33      0.48      1900

    accuracy                           0.77      7600
   macro avg       0.80      0.77      0.74      7600
weighted avg       0.80      0.77      0.74      7600

3. 暴力破解:在下游模型上进行网格搜索#

在这里,我们再次将所有阈值设置为初始值,并网格搜索每个单独阈值的最佳值,但现在我们优化开发集中下游模型的准确率。我们在测试集上获得了 85% 的最终准确率,略低于我们通过先前方法获得的准确率。

[42]:
# retrieve records with annotations
test_ds = weak_labels.records(has_annotation=True)

# extract text and labels
X_test_for_grid_search = [rec.text for rec in test_ds]
y_test_for_grid_search = [weak_labels.label2int[rec.annotation] for rec in test_ds]


def train_eval_downstream(ths):
    weak_labels.extend_matrix(ths)

    label_model = Snorkel(weak_labels)
    label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

    records = label_model.predict()

    X_train = [rec.text for rec in records]
    y_train = [weak_labels.label2int[rec.prediction[0][0]] for rec in records]

    classifier = Pipeline([("vect", CountVectorizer()), ("clf", MultinomialNB())])

    classifier.fit(
        X=X_train,
        y=y_train,
    )

    accuracy = classifier.score(
        X=X_test_for_grid_search,
        y=y_test_for_grid_search,
    )

    return accuracy

[43]:
from copy import copy
from tqdm.auto import tqdm

best_thresholds, best_acc = [1.0] * len(weak_labels.rules), 0
ths_range = np.arange(1, 0.3, -0.1)
n_ths = len(weak_labels.rules)

for i in tqdm(range(n_ths), total=n_ths):
    thresholds = best_thresholds.copy()
    for threshold in ths_range:
        thresholds[i] = threshold
        acc = train_eval_downstream(thresholds)
        if acc > best_acc:
            best_acc = acc
            best_thresholds = thresholds.copy()

[44]:
np.array(best_thresholds)

[44]:
array([0.4, 0.6, 0.6, 0.5, 1. , 0.6, 0.6, 0.5, 0.5, 0.5, 1. , 0.4, 1. ,
       0.7, 1. , 0.9])
[45]:
weak_labels.extend_matrix(best_thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(train_and_evaluate_downstream_model(label_model))

              precision    recall  f1-score   support

      Sports       0.86      0.97      0.91      1900
    Sci/Tech       0.79      0.77      0.78      1900
       World       0.88      0.82      0.85      1900
    Business       0.77      0.75      0.76      1900

    accuracy                           0.83      7600
   macro avg       0.82      0.83      0.82      7600
weighted avg       0.82      0.83      0.82      7600

阈值优化技巧#

使用大型下游模型(如 Transformer)进行网格搜索可能非常昂贵。在这种情况下,我们可以考虑仅优化阈值的子集,或者在开发集的小样本上优化所有阈值。

尽管在本教程中我们按顺序执行网格搜索,但只要我们为每个进程或线程制作弱标签对象的深层副本,就可以并行执行网格搜索来加快速度。

附录:绘图扩展#

[ ]:
%pip uninstall umap
%pip install umap-learn
[ ]:
import umap.umap_ as umap
import matplotlib.pyplot as plt

umap_data = umap.UMAP(
    n_neighbors=15, n_components=2, min_dist=0.0, metric="cosine"
).fit_transform(embeddings)

df = rg.DatasetForTextClassification(weak_labels.records()).to_pandas()
df["x"], df["y"] = umap_data[:, 0], umap_data[:, 1]
df["wl"] = [em for em in weak_labels._matrix]
df["wl_ext"] = [em for em in weak_labels._extended_matrix]

cov_idx = df["wl"].map(lambda x: x.sum() != -16)
cov_ext_idx = df["wl_ext"].map(lambda x: x.sum() != -16)
test_idx = ~(df.annotation.isna())

df_test = df[test_idx]
df_cov, df_cov_ext = df[cov_idx & test_idx], df[cov_ext_idx & test_idx]

label2int = {
    label: i for i, label in enumerate(df_test.annotation.value_counts().index)
}

fig, ax = plt.subplots(
    1,
    2,
    figsize=(13, 6),
)

ax[0].scatter(
    df_test.x, df_test.y, c=df_test.annotation.map(lambda x: label2int[x]), s=10
)
ax[0].scatter(
    df_cov.x,
    df_cov.y,
    c=df_cov.annotation.map(lambda x: label2int[x]),
    s=100,
    alpha=0.2,
)

scatter = ax[1].scatter(
    df_test.x, df_test.y, c=df_test.annotation.map(lambda x: label2int[x]), s=10
)
ax[1].scatter(
    df_cov_ext.x,
    df_cov_ext.y,
    c=df_cov_ext.annotation.map(lambda x: label2int[x]),
    s=100,
    alpha=0.2,
)

ax[0].set_title("Original", {"fontsize": "xx-large"})
ax[0].set_xticks([]), ax[0].set_yticks([])

ax[1].set_title("Extended", {"fontsize": "xx-large"})
ax[1].set_xticks([]), ax[1].set_yticks([])

labels = list(scatter.legend_elements())
labels[1] = list(label2int.keys())
legend1 = ax[0].legend(*labels, loc="lower right", fontsize="xx-large")
ax[0].add_artist(legend1)

fig.tight_layout()
plt.savefig("extend_weak_labels.png", facecolor="white", transparent=False)