🦾 Fine-tuning LLMs and other language models#
Feedback Dataset#
Note
The dataset class covered in this section is the FeedbackDataset. This fully configurable dataset will replace the DatasetForTextClassification, DatasetForTokenClassification and DatasetForText2Text in Argilla 2.0. Not sure which dataset to use? Check out our section on choosing a dataset.
After collecting the responses from our FeedbackDataset, we can start fine-tuning our LLMs and other models. Because of the customizability of the task, this might require setting up a custom post-processing workflow, but we provide some good examples for the LLM approaches: supervised fine-tuning and reinforcement learning through human feedback (RLHF). Nevertheless, we still offer support for other NLP tasks like text classification.
ArgillaTrainer#
The ArgillaTrainer is a wrapper around many of our favorite NLP libraries. It provides a very intuitive abstract representation to facilitate simple training workflows using decent default pre-set configurations, without having to worry about any data transformations from Argilla.
Using the ArgillaTrainer is straightforward, but it slightly differs per task.
1. First, we define a TrainingTask. This can be done using a custom formatting_func. However, tasks like text classification can also be defined using default definitions based on the FeedbackDataset fields and questions. These tasks are then used for retrieving data from a dataset and initializing the training. We also offer some ideas for unifying data out of the box.
2. Next, we initialize the ArgillaTrainer and forward the task and training framework. Internally, this uses the FeedbackDataset.prepare_for_training method to format the data according to the expectations of the framework. Some other interesting methods are ArgillaTrainer.update_config to change framework-specific training parameters, ArgillaTrainer.train to start training, and ArgillaTrainer.predict to run inference.
Underneath, you can see the happy flow for using the ArgillaTrainer.
from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)
task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("label"),
)
trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="setfit"
)
trainer.update_config(num_iterations=1)
trainer.train(output_dir="my_setfit_model")
trainer.predict("This is awesome!")
Supported Frameworks#
We plan on adding support for other tasks and frameworks, so feel free to reach out on our Discord channel or GitHub to help us prioritize each task.
| Task/Framework | TRL | OpenAI | SetFit | spaCy | Transformers | PEFT | SentenceTransformers |
|---|---|---|---|---|---|---|---|
| Text Classification |  |  | ✔️ | ✔️ | ✔️ | ✔️ |  |
| Question Answering |  |  |  |  | ✔️ |  |  |
| Sentence Similarity |  |  |  |  |  |  | ✔️ |
| Supervised Fine-tuning | ✔️ |  |  |  |  |  |  |
| Reward Modeling | ✔️ |  |  |  |  |  |  |
| Proximal Policy Optimization | ✔️ |  |  |  |  |  |  |
| Direct Preference Optimization | ✔️ |  |  |  |  |  |  |
| Chat Completion |  | ✔️ |  |  |  |  |  |
Training Configs#
The trainer also has an ArgillaTrainer.update_config() method, which maps a dict with **kwargs to the respective framework. So, these parameters can be derived from the underlying framework that was used to initialize the trainer. Underneath, you can find an overview of these variables for the supported frameworks.
Note
Note that you don't need to pass all of these variables directly, and that the values below are their default configurations.
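For instance, only the keys you actually want to override need to be passed; everything else keeps its framework default. A minimal sketch for a transformers-based trainer:

# Override a subset of the framework-specific defaults;
# any keyword argument not passed keeps its default value.
trainer.update_config(
    learning_rate=1e-4,
    num_train_epochs=5,
)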
# `OpenAI.FineTune`
trainer.update_config(
training_file = None,
validation_file = None,
model = "gpt-3.5-turbo-0613",
hyperparameters = {"n_epochs": 1},
suffix = None
)
# `OpenAI.FineTune` (legacy)
trainer.update_config(
training_file = None,
validation_file = None,
model = "curie",
n_epochs = 2,
batch_size = None,
learning_rate_multiplier = 0.1,
prompt_loss_weight = 0.1,
compute_classification_metrics = False,
classification_n_classes = None,
classification_positive_class = None,
classification_betas = None,
suffix = None
)
# `AutoTrain.autotrain_advanced`
trainer.update_config(
model = "autotrain", # hub models like roberta-base
autotrain = [{
"source_language": "en",
"num_models": 5
}],
hub_model = [{
"learning_rate": 0.001,
"optimizer": "adam",
"scheduler": "linear",
"train_batch_size": 8,
"epochs": 10,
"percentage_warmup": 0.1,
"gradient_accumulation_steps": 1,
"weight_decay": 0.1,
"tasks": "text_binary_classification", # this is inferred from the dataset
}]
)
# `setfit.SetFitModel`
trainer.update_config(
pretrained_model_name_or_path = "all-MiniLM-L6-v2",
force_download = False,
resume_download = False,
proxies = None,
token = None,
cache_dir = None,
local_files_only = False
)
# `setfit.SetFitTrainer`
trainer.update_config(
metric = "accuracy",
num_iterations = 20,
num_epochs = 1,
learning_rate = 2e-5,
batch_size = 16,
seed = 42,
use_amp = True,
warmup_proportion = 0.1,
distance_metric = "BatchHardTripletLossDistanceFunction.cosine_distance",
margin = 0.25,
samples_per_label = 2
)
# `spacy.training`
trainer.update_config(
dev_corpus = "corpora.dev",
train_corpus = "corpora.train",
seed = 42,
gpu_allocator = 0,
accumulate_gradient = 1,
patience = 1600,
max_epochs = 0,
max_steps = 20000,
eval_frequency = 200,
frozen_components = [],
annotating_components = [],
before_to_disk = None,
before_update = None
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
pretrained_model_name_or_path = "distilbert-base-uncased",
force_download = False,
resume_download = False,
proxies = None,
token = None,
cache_dir = None,
local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.9,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
# `peft.LoraConfig`
trainer.update_config(
r=8,
target_modules=None,
lora_alpha=16,
lora_dropout=0.1,
fan_in_fan_out=False,
bias="none",
inference_mode=False,
modules_to_save=None,
init_lora_weights=True,
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
pretrained_model_name_or_path = "distilbert-base-uncased",
force_download = False,
resume_download = False,
proxies = None,
token = None,
cache_dir = None,
local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.9,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
# `SpanMarkerConfig`
trainer.update_config(
pretrained_model_name_or_path = "distilbert-base-cased",
model_max_length = 256,
marker_max_length = 128,
entity_max_length = 8,
)
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.9,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
# Parameters from `trl.RewardTrainer`, `trl.SFTTrainer`, `trl.PPOTrainer` or `trl.DPOTrainer`.
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.9,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
# Parameters related to the model initialization from `sentence_transformers.SentenceTransformer`
trainer.update_config(
model="sentence-transformers/all-MiniLM-L6-v2",
modules = False,
device="cuda",
cache_folder="dir/folder",
use_auth_token=True
)
# and from `sentence_transformers.CrossEncoder`
trainer.update_config(
model="cross-encoder/ms-marco-MiniLM-L-6-v2",
num_labels=2,
max_length=128,
device="cpu",
tokenizer_args={},
automodel_args={},
default_activation_function=None
)
# Related to the training procedure from `sentence_transformers.SentenceTransformer`
trainer.update_config(
steps_per_epoch = 2,
checkpoint_path: str = None,
checkpoint_save_steps: int = 500,
checkpoint_save_total_limit: int = 0
)
# and from `sentence_transformers.CrossEncoder`
trainer.update_config(
loss_fct = None,
activation_fct = nn.Identity(),
)
# The remaining arguments are common for both procedures
trainer.update_config(
evaluator: SentenceEvaluator = evaluation.EmbeddingSimilarityEvaluator,
epochs: int = 1,
scheduler: str = 'WarmupLinear',
warmup_steps: int = 10000,
optimizer_class: Type[Optimizer] = torch.optim.AdamW,
optimizer_params : Dict[str, object]= {'lr': 2e-5},
weight_decay: float = 0.01,
evaluation_steps: int = 0,
output_path: str = None,
save_best_model: bool = True,
max_grad_norm: float = 1,
use_amp: bool = False,
callback: Callable[[float, int, int], None] = None,
show_progress_bar: bool = True,
)
# Other parameters that don't correspond to the initialization or the trainer, but
# can be set externally.
trainer.update_config(
batch_size=8, # It will be passed to the DataLoader to generate batches during training.
loss_cls=losses.BatchAllTripletLoss
)
TrainingTask#
A TrainingTask is used to define how the data should be processed and formatted according to the associated task and framework. Each task has its own TrainingTask.for_*-classmethod, and the data formatting can always be customized via a formatting_func. However, simpler tasks like text classification can also be defined using default definitions. These directly use the fields and questions from the FeedbackDataset configuration to infer how to prepare the data. Underneath you can find an overview of the TrainingTask requirements.
| Method | Content | Default |
|---|---|---|
| for_text_classification | text-label | ✔️ |
| for_question_answering | question-context-answer | ✔️ |
| for_sentence_similarity | sentence-1, sentence-2, (sentence-3), (label) | ✔️ |
| for_supervised_fine_tuning | text | ✗ |
| for_reward_modeling | chosen-rejected | ✗ |
| for_proximal_policy_optimization | prompt | ✗ |
| for_direct_preference_optimization | prompt-chosen-rejected | ✗ |
| for_chat_completion | chat-turn-role-text | ✗ |
Filter and sort the dataset for training#
Let's say you want to filter a part of your dataset, keeping only submitted records, or sort it by date so that you only train on the latest additions to your dataset. You can easily do this using the filter_by, sort_by and max_records arguments of the ArgillaTrainer:
from argilla import SortBy
trainer = ArgillaTrainer(
dataset=dataset,
task=task,
framework="setfit",
filter_by={"response_status": ["submitted"]},
sort_by=[SortBy(field="metadata.my-metadata", order="asc")],
max_records=1000
)
Note
You can check the page on filtering and querying datasets in our documentation to learn more about how to filter and sort your dataset.
Huggingface Hub integration#
This section covers some integrations with the Hugging Face 🤗 model hub: the easiest way of sharing your Argilla models, as well as the possibility of generating automated model cards.
Note
Take a look at the following sample models with automatically generated model cards in the 🤗 huggingface hub, and check https://hugging-face.cn/models?other=argilla for upcoming models shared with Argilla.
Model card generation#
The ArgillaTrainer automatically generates a model card when saving the model. After calling trainer.train(output_dir="my_model"), you should see the model card under the same output dir you passed through the train method: ./my_model/README.md. Most of the fields in the card are automatically generated when possible, but the following fields can be (optionally) updated via the framework_kwargs variable of the ArgillaTrainer, like so:
model_card_kwargs = {
    "language": ["en", "es"],
    "license": "Apache-2.0",
    "dataset_name": "argilla/emotion",
    "tags": ["nlp", "few-shot-learning", "argilla", "setfit"],
    "model_summary": "Small summary of what the model does",
    "model_description": "An extended explanation of the model",
    "model_type": "A 1.3B parameter embedding model fine-tuned on an awesome dataset",
    "finetuned_from": "all-MiniLM-L6-v2",
    "repo": "https://github.com/...",
    "developers": "",
    "shared_by": "",
}
trainer = ArgillaTrainer(
dataset=dataset,
task=task,
framework="setfit",
framework_kwargs={"model_card_kwargs": model_card_kwargs}
)
trainer.train(output_dir="my_model")
Even though it's generated internally, you can also obtain the card by calling the generate_model_card method:
argilla_model_card = trainer.generate_model_card("my_model")
Upload your model to the Huggingface Hub#
If you haven't installed the huggingface hub yet, you can do so with the following command:
pip install huggingface_hub
Note
If the framework you chose is spacy or spacy-transformers, you should also install the following dependency:
pip install spacy-huggingface-hub
Then choose the environment depending on whether you are working in a script or a Jupyter notebook.
From a console window, run the following command and insert your 🤗 huggingface hub token:
huggingface-cli login
From a notebook cell, run the following and insert your 🤗 huggingface hub token:
from huggingface_hub import notebook_login
notebook_login()
Internally, the token will be used when calling push_to_huggingface on your model.
If you need more help publishing your models, be sure to check the huggingface hub requirements.
Once your model is trained, you just have to call push_to_huggingface and wait for your model to be pushed to the hub (by default a model card will be generated; set the argument to False if you don't want it):
# spaCy based models:
repo_id = output_dir
# Every other framework:
repo_id = "organization/model-name" # for example: argilla/newest-model
trainer.push_to_huggingface(repo_id, generate_card=True)
Due to the way spaCy behaves when pushing models, the repo_id is generated internally and you need to pass the path where the model was saved (the same output_dir variable you probably passed to the train method); otherwise it works the same way.
Tasks#
Text Classification#
Background#
Text classification is a widely used NLP task where labels are assigned to text. Major companies rely on it for various applications. Sentiment analysis, a popular form of text classification, assigns labels like 🙂 positive, 🙁 negative, or 😐 neutral to text. Additionally, we distinguish between single- and multi-label text classification.
Single-label text classification refers to the task of assigning a single category or label to a given text sample. Each text is associated with only one predefined class or category. For example, in sentiment analysis, a single-label text classification task would involve assigning labels such as "positive", "negative", or "neutral" to texts based on their sentiment.
"The help for my application of a new card and mortgage was great", "positive"
Multi-label text classification is generally more complicated than single-label classification due to the challenge of determining and predicting multiple relevant labels for each text. It finds applications in various domains, including document tagging, topic labeling, and content recommendation systems. For example, in customer care, a multi-label text classification task would involve assigning topics such as "new_card", "mortgage", or "opening_hours" to texts based on their content.
Tip
For a multi-label scenario it is recommended to add some examples without any labels to improve model performance.
"The help for my application of a new card and mortgage was great", ["new_card", "mortgage"]
We then use the text-label-pairs to further fine-tune the model.
Training#
Text classification is one of the most widely supported training tasks within NLP. As an example, we will use our emotion demo dataset.
Data Preparation
from argilla.feedback import FeedbackDataset
dataset = FeedbackDataset.from_huggingface(
repo_id="argilla/emotion"
)
For this task, we assume we need a text-label-pair or a formatting_func for defining the TrainingTask.for_text_classification.
We offer the option to use default unification strategies and formatting based on a text-label-pair. Here we infer formatting information based on a TextField and a LabelQuestion, MultiLabelQuestion, RatingQuestion or RankingQuestion from the dataset. This is the easiest way to define a TrainingTask for text classification, but if you need a custom workflow, you can use a formatting_func.
Note
An overview of the unification measures can be found here. RatingQuestion and RankingQuestion can be unified using a "majority"-, "min"-, "max"- or "disagreement"-strategy. Both LabelQuestion and MultiLabelQuestion can be resolved using a "majority"- or "disagreement"-strategy.
from argilla.feedback import FeedbackDataset, TrainingTask
dataset = FeedbackDataset.from_huggingface(
repo_id="argilla/emotion"
)
task = TrainingTask.for_text_classification(
text=dataset.field_by_name("text"),
label=dataset.question_by_name("label"),
label_strategy=None # defaults presets
)
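Instead of the default presets, one of the strategies listed in the note above can also be passed explicitly. A small sketch, reusing the question from the example above:

# Explicitly unify annotator responses with a majority vote instead of the default preset.
task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("label"),
    label_strategy="majority",
)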
We also offer the option of providing a formatting_func to TrainingTask.for_text_classification. This function is applied to each sample in the dataset and can be used for more advanced preprocessing and data formatting. The function should return a tuple of (text, label) as Tuple[str, str] or Tuple[str, List[str]].
import random
from collections import Counter

from argilla.feedback import FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)

def formatting_func(sample):
    text = sample["text"]
    # Choose the most common label
    values = [resp["value"] for resp in sample["label"]]
    counter = Counter(values)
    if counter:
        most_common = counter.most_common()
        max_frequency = most_common[0][1]
        most_common_elements = [
            element for element, frequency in most_common if frequency == max_frequency
        ]
        label = random.choice(most_common_elements)
        return (text, label)
    else:
        return None

task = TrainingTask.for_text_classification(formatting_func=formatting_func)
We can then define our ArgillaTrainer for any of the supported frameworks and customize the training config using ArgillaTrainer.update_config.
from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="spacy",
    train_size=0.8,
    model="en_core_web_sm",
)
trainer.train(output_dir="textcat_model")
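After training, ArgillaTrainer.predict gives a quick sanity check of the fine-tuned classifier; the exact output format depends on the chosen framework.

# Run inference with the fine-tuned text classifier.
trainer.predict("The help for my application of a new card was great")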
Question Answering#
Background#
The extractive Question Answering (QnA) task involves answering questions posed by users based on a given context. It is a challenging task that requires the model to understand the context of the question and to provide an accurate answer. The model must be able to comprehend both the question and the context in which it is asked, as well as the relationship between the two. Additionally, it must be able to extract the relevant information from the context and provide an answer that is both accurate and relevant to the question.
Underneath you can find an example of an extractive QnA dataset:
{
'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
'answers': 'Saint Bernadette Soubirous',
}
Note
Officially, the answers need to be passed as a list of {'answer_start': int, 'text': str}-dicts. However, we only support strings, where the answer_start is inferred from the context and text fields.
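As a rough, purely illustrative sketch of that inference (not Argilla's internal code), the start index can be recovered by locating the answer string inside the context:

# Turn a plain string answer into the official dict format by locating
# its character offset inside the context.
context = (
    "It is a replica of the grotto at Lourdes, France where the Virgin Mary "
    "reputedly appeared to Saint Bernadette Soubirous in 1858."
)
answer_text = "Saint Bernadette Soubirous"
answer = {"answer_start": context.find(answer_text), "text": answer_text}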
We then use the question-context-answer-sets or a formatting_func to further fine-tune the model.
Training#
Data Preparation
import argilla as rg
from datasets import Dataset
feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/squad")
We can use the default configuration, where we initialize TrainingTask.for_question_answering with the question-context-answer-sets from the dataset. We also offer the option of providing a formatting_func to TrainingTask.for_question_answering. This function is applied to each sample in the dataset and can be used for advanced preprocessing and data formatting. The function should return the question-context-answer-sets as str-str-str.
from argilla.feedback import TrainingTask
task = TrainingTask.for_question_answering(
question=feedback_dataset.field_by_name("question"),
context=feedback_dataset.field_by_name("context"),
answer=feedback_dataset.question_by_name("answer"),
)
from argilla.feedback import TrainingTask

def formatting_func(sample):
    question = sample["question"]
    context = sample["context"]
    for answer in sample["answer"]:
        if not all([question, context, answer["value"]]):
            continue
        yield question, context, answer["value"]

task = TrainingTask.for_question_answering(formatting_func=formatting_func)
ArgillaTrainer
Next, we can define our ArgillaTrainer for any of the supported frameworks and customize the training config using ArgillaTrainer.update_config.
from argilla.feedback import ArgillaTrainer
trainer = ArgillaTrainer(
dataset=feedback_dataset,
task=task,
framework="transformers",
train_size=0.8,
)
trainer.train(output_dir="qna_model")
Inference
Lastly, this model can be used for inference using the pipeline method from the Transformers library, with the question-answering pipeline for this task.
from transformers import pipeline
qa_model = pipeline("question-answering", model="qna_model")
question = "Where do I live?"
context = "My name is Merve and I live in İstanbul."
qa_model(question = question, context = context)
## {'answer': 'İstanbul', 'end': 39, 'score': 0.953, 'start': 31}
Sentence Similarity#
Background#
Sentence similarity is the task of determining how similar two texts are. By transforming the texts into embeddings (vectors that represent their semantic information), we can compute the similarity between them by calculating the distance between their vectors. The Sentence-Transformers library makes it easy to compute these sentence embeddings and use them for information retrieval and clustering. Beyond those tasks, it is also commonly used to optimize Retrieval Augmented Generation (RAG) and re-ranking tasks. Generally, two types of models can be fine-tuned.
Bi-encoders consist of two separate neural network models, each responsible for encoding a single sentence or text. These encoders work independently and do not share weights. The main objective of a bi-encoder is to encode individual sentences or texts into fixed-length vectors in a way that preserves the semantics of the input. These fixed-length vectors can later be used for various tasks such as retrieval or classification. Bi-encoders are commonly used in tasks where a large corpus of text needs to be encoded into vectors (e.g., creating embeddings for documents in a corpus). These embeddings can then be used for tasks like information retrieval, clustering, and classification.
Cross-encoders consist of a single neural network model that takes multiple input sentences or texts simultaneously, processing pairs of sentences or texts in a single forward pass. The main objective of a cross-encoder is to provide a single scalar score or similarity measure for a pair of input sentences or texts. This score represents the similarity or relevance between the two inputs. Cross-encoders are often used in applications like text matching, question answering, document retrieval, and recommendation systems, where you need to compare two pieces of text and assess their similarity or relevance.
In this blog post from Hugging Face you can see the different types of datasets that can be used to train sentence-transformers models.
Training#
Note
We can easily switch between Bi-Encoder and Cross-Encoder based models using framework_kwargs={"cross_encoder": True}. Also, the data can be provided in several different formats, sketched after this list. Keep in mind that Cross-Encoder based models don't allow training with sentence triplets.
- The example is a pair of positive (similar) sentences without a label, for example pairs of paraphrases, pairs of full texts and their summaries, pairs of duplicate questions, pairs of (query, response), or pairs of (source_language, target_language). Natural Language Inference datasets can also be formatted this way by pairing entailing sentences.
- The example is a pair of sentences and a label indicating how similar they are. The label can be either an integer or a float. This case applies to datasets originally prepared for Natural Language Inference (NLI), since they contain pairs of sentences with a label indicating whether they infer each other or not.
- Only for Bi-Encoders: the example is a triplet (anchor, positive, negative) without classes or labels for the sentences.
- Only for Bi-Encoders: the example is a sentence with an integer label. This data format is easily converted by loss functions into three sentences (a triplet), where the first is an "anchor", the second a "positive" of the same class as the anchor, and the third a "negative" of a different class. Each sentence has an integer label indicating the class to which it belongs.
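As a small, hypothetical sketch (these sentences are not taken from the demo dataset), the pair, labeled-pair and triplet formats could look like this:

# 1. Positive (similar) pairs without a label -- also valid for Cross-Encoders
positive_pairs = [
    ("A man is eating food.", "A man is eating a meal."),
]
# 2. Pairs with a similarity label (int or float)
labeled_pairs = [
    ("A man is eating food.", "A man is eating pasta.", 0.8),
]
# 3. Triplets (anchor, positive, negative) -- Bi-Encoders only
triplets = [
    ("A man is eating food.", "A man is eating a meal.", "A plane is taking off."),
]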
Data Preparation
Let's use a small version of the snli dataset as an example, prepared to be used with Argilla: snli-small.
import argilla as rg
dataset = rg.FeedbackDataset.from_huggingface("plaguss/snli-small")
We offer the option to use default unification strategies and formatting based on sentence-pairs or sentence-triplets, with or without a label. Here we infer formatting information based on two TextFields and a LabelQuestion or RankingQuestion from the dataset. This is the easiest way to define a TrainingTask for sentence similarity, but if you need a custom workflow, you can use a formatting_func.
Note
An overview of the unification measures can be found here. For this type of task only LabelQuestion or RankingQuestion apply.
from argilla.feedback import TrainingTask
task = TrainingTask.for_sentence_similarity(
texts=[dataset.field_by_name("premise"), dataset.field_by_name("hypothesis")],
label=dataset.question_by_name("label")
)
For datasets annotated with numeric values we can also pass the label strategy we want to use (let's assume we have another question in our dataset named "other-question" with values coming from rating answers):
task = TrainingTask.for_sentence_similarity(
texts=[dataset.field_by_name("premise"), dataset.field_by_name("hypothesis")],
label=dataset.question_by_name("other-question"),
label_strategy="majority" # or "mean" for RankingQuestion
)
We also offer the option of providing a formatting_func to TrainingTask.for_sentence_similarity. This function is applied to each sample in the dataset and can be used for more advanced preprocessing and data formatting. The function can return a dict with sentence-1, sentence-2 and optionally sentence-3 as keys and the corresponding sentences as values, and it can also include a label, which can be either an int (representing the class) or a float. Lists of these elements can be returned as well.
import random
from collections import Counter

def formatting_func(sample):
    record = {"sentence-1": sample["premise"], "sentence-2": sample["hypothesis"]}
    # Choose the most common label
    values = [resp["value"] for resp in sample["label"]]
    counter = Counter(values)
    if counter:
        most_common = counter.most_common()
        max_frequency = most_common[0][1]
        most_common_elements = [
            element for element, frequency in most_common if frequency == max_frequency
        ]
        label = random.choice(most_common_elements)
        record["label"] = label
        return record
    else:
        return None

task = TrainingTask.for_sentence_similarity(formatting_func=formatting_func)
ArgillaTrainer
We will use the task directly with our FeedbackDataset in the ArgillaTrainer. For this case we use the default SentenceTransformer model; to fine-tune a Cross-Encoder based model, pass framework_kwargs={"cross_encoder": True}.
from argilla.feedback import ArgillaTrainer
trainer = ArgillaTrainer(
dataset=dataset,
task=task,
framework="sentence-transformers",
framework_kwargs={"cross_encoder": False}
)
trainer.train(output_dir="my_sentence_transformer_model")
Inference
These models can be loaded with sentence-transformers (or transformers); the reader can check each type of model in the corresponding links.
However, the ArgillaTrainer also offers the possibility of predicting sentence similarity from its API. Let's see how it works using the same example sentences as the Sentence Similarity task at Hugging Face:
from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask
trainer.predict(
[
"Machine learning is so easy.",
["Deep learning is so straightforward.", "This is so difficult, like rocket science.", "I can't believe how much I struggled with this."]
]
)
# [0.77857256, 0.4587626, 0.29062212]
Just to see another format that can be passed to obtain the sentence similarities (a list of sentence pairs), let's look at the following example (these pairs don't need to share the first sentence; this is just an example to check that both options return the same values).
trainer.predict(
[
["Machine learning is so easy.", "Deep learning is so straightforward."],
["Machine learning is so easy.", "This is so difficult, like rocket science."],
["Machine learning is so easy.", "I can't believe how much I struggled with this."]
]
)
# [0.77857256, 0.4587626, 0.29062212]
The previous results were obtained assuming the trained model is a SentenceTransformer. If instead of a SentenceTransformer model (a Bi-Encoder based model) we had chosen a Cross-Encoder, we would obtain different results, but with the same interpretation.
trainer = ArgillaTrainer(
dataset=dataset,
task=task,
framework="sentence-transformers",
framework_kwargs={"cross_encoder": True}
)
trainer.predict(
[
"Machine learning is so easy.",
["Deep learning is so straightforward.", "This is so difficult, like rocket science.", "I can't believe how much I struggled with this."]
]
)
# [2.2006402, -6.2634926, -10.251489]
Supervised Fine-tuning#
Background#
The goal of Supervised Fine-Tuning (SFT) is to optimize a pre-trained model to generate the responses that users are looking for. A causal language model can generate feasible human text, but it will not be able to give proper answers to question phrases posed by the user in a conversational or instruction set. Therefore, we need to collect and curate data tailored to this use case to teach the model to mimic this data. There is a section in our docs about collecting data for this task, and there are many good pre-trained causal language models available on Hugging Face.
Data for the training phase is generally divided into two different types: generic for domain-like fine-tuning, or chat for fine-tuning an instruction set.
Generic
In the generic fine-tuning setting, the objective is to make the model more proficient at generating coherent and contextually appropriate text within a particular domain. For example, if we want the model to generate text related to medical research, we would fine-tune it using a dataset consisting of medical literature, research papers, or related documents. By exposing the model to domain-specific data during training, it becomes more knowledgeable about the terminology, concepts, and writing style prevalent in that domain. This enables the model to generate more accurate and contextually appropriate responses when prompted with queries or tasks related to that domain. An example of this format is the PubMed data below, but it might be smarter to add some nuance with a generic instruction phrase indicating the scope of the data, like Generate a medical paper abstract: ....
# Five distinct ester hydrolases (EC 3-1) have been characterized in guinea-pig epidermis. These are carboxylic esterase, acid phosphatase, pyrophosphatase, and arylsulphatase A and B. Their properties are consistent with those of lysosomal enzymes.
Chat
On the other hand, instruction-based fine-tuning involves training the model to understand and respond to specific instructions or prompts given by the user. This approach allows for greater control and specificity in the generated output. For example, if we want the model to summarize a given text, we can fine-tune it using a dataset that consists of pairs of text passages and their corresponding summaries. The model can then be instructed to generate a summary based on a given input text. By fine-tuning the model in this manner, it becomes more adept at following instructions and producing output that aligns with the desired task or objective. An example of this format is our curated Dolly dataset with instruction, context and response fields. However, we can also have simpler datasets with only question and answer fields.
### Instruction
{instruction}
### Context
{context}
### Response:
{response}
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
Ultimately, the choice between these two approaches for the text-field depends on the specific requirements of the application and the desired level of control over the model's output. By employing the appropriate fine-tuning strategy, we can enhance the model's performance and make it more suitable for a wide range of applications and use cases.
Training#
There are many good libraries to help with this step, however, we are a fan of the Transformer Reinforcement Learning (TRL) package, Transformer Reinforcement Learning X (TRLX), and the no-code fine-tuning of Hugging Face AutoTrain. In any of these cases, we need a backbone model, and for the sake of this example we will use our curated Dolly dataset.
Note
This dataset only contains a single annotator response per record. We gave some suggestions on dealing with responses from multiple annotators.
The Transformer Reinforcement Learning (TRL) package offers a flexible and customizable framework for fine-tuning models. It allows users fine-grained control over the training process, enabling them to define their own functions and further specify the desired behavior of the model. This approach requires a deeper understanding of reinforcement learning concepts and techniques, as well as more careful experimentation. It is best suited for users who have experience with reinforcement learning and want fine-grained control over the training process. Additionally, it directly integrates with Parameter-Efficient Fine-Tuning (PEFT), decreasing the computational complexity of this step of training an LLM.
Data Preparation
import argilla as rg
from datasets import Dataset
feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")
We offer the option of providing a formatting_func to TrainingTask.for_supervised_fine_tuning. This function is applied to each sample in the dataset and can be used for advanced preprocessing and data formatting. The function should return a text as str.
from typing import Dict, Any

from argilla.feedback import TrainingTask

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> str:
    # What `sample` looks like depends a lot on your FeedbackDataset fields and questions
    return template.format(
        instruction=sample["new-instruction"][0]["value"],
        context=sample["new-context"][0]["value"],
        response=sample["new-response"][0]["value"],
    )

task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func)
You can observe the resulting dataset by calling FeedbackDataset.prepare_for_training. We can use "trl" as an example framework:
dataset = feedback_dataset.prepare_for_training(
framework="trl",
task=task
)
"""
>>> dataset
Dataset({
features: ['id', 'text'],
num_rows: 15015
})
>>> dataset[0]["text"]
### Instruction: When did Virgin Australia start operating?
### Context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
### Response: Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
"""
ArgillaTrainer
from argilla.feedback import ArgillaTrainer
trainer = ArgillaTrainer(
dataset=feedback_dataset,
task=task,
framework="trl",
train_size=0.8,
model="gpt2",
)
# e.g. using LoRA:
# from peft import LoraConfig
# trainer.update_config(peft_config=LoraConfig())
trainer.train(output_dir="sft_model")
Note
You can also initialize the ArgillaTrainer with an already initialized model and tokenizer for additional fine-grained control. This can be useful if you want to make sure that the tokenizer adds the EOS token; missing this token may result in the model generating endlessly.
If the trained model still generates endlessly, it is recommended to 1) pass a tokenizer that certainly adds the EOS token and 2) pass a custom data collator that does not set the label of the EOS token to -100.
Inference
Let's observe if it worked to train the model to respond within our template. We'll create a quick helper method for this.
from transformers import GenerationConfig, AutoTokenizer, GPT2LMHeadModel

def generate(model_id: str, instruction: str, context: str = "") -> str:
    model = GPT2LMHeadModel.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = template.format(
        instruction=instruction,
        context=context,
        response="",
    ).strip()
    encoding = tokenizer([inputs], return_tensors="pt")
    outputs = model.generate(
        **encoding,
        generation_config=GenerationConfig(
            max_new_tokens=32,
            min_new_tokens=12,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    return tokenizer.decode(outputs[0])
>>> generate("sft_model", "Is a toad a frog?")
### Instruction: Is a toad a frog?
### Context:
### Response: A frog is a small, round, black-eyed, frog with a long, black-winged head. It is a member of the family Pter
Much better! This model follows the template like we intended.
Reward Modeling#
Background#
A Reward Model (RM) is used to rate responses in alignment with human preferences, and afterwards this RM is used to fine-tune the LLM with the associated scores. Fine-tuning with a reward model can be done in different ways. We can either get annotators to rate the outputs fully manually, use a simple heuristic, or use a stochastic preference model. Both TRL and TRLX provide decent options for integrating rewards. Microsoft's DeepSpeed library is also worth mentioning but will not be covered in our docs.
The data required for these steps needs to be comparison data that showcases the preference for the generated prompts. Our curated Dolly dataset is a good example, where we assume that the updated response is preferred over the older one. Another good example is the Anthropic RLHF dataset.
Note
The original Dolly dataset contains a lot of reference indicators such as "[1]", which cause the model to hallucinate and incorrectly create references.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]
### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
In the case of training an RM, we then use these chosen-rejected-pairs and train a classifier to distinguish between them.
Training#
The data required for this step needs to be comparison data that showcases the preference for the generated prompts. Our curated Dolly dataset is a good example, where we assume that the updated response is preferred over the older one. Another good example is the Anthropic RLHF dataset.
Note
The original Dolly dataset contains a lot of reference indicators such as "[1]", which cause the model to hallucinate and incorrectly create references.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]
### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
TRL implements reward modeling, which can be used via the ArgillaTrainer class. We offer the option of providing a formatting_func to TrainingTask.for_reward_modeling. This function is applied to each sample in the dataset and can be used for preprocessing and data formatting. The function should return the chosen-rejected-pairs as tuples of type Tuple[str, str]. To determine which response of the FeedbackDataset is better, we can use the user annotations.
Note
The formatting function can also return None or a list of tuples. None can be used if the annotations indicate that the text is of low quality or harmful, and the latter can be used if multiple annotators provide additional written responses, resulting in multiple good chosen-rejected pairs.
Data Preparation
What the arguments of the formatting_func look like depends a lot on your FeedbackDataset fields and questions. However, fields (i.e. the left side of the Argilla annotation view) are provided as their values, e.g.:
>>> sample
{
...
'original-response': 'Virgin Australia commenced services on 31 August 2000 '
'as Virgin Blue, with two aircraft on a single route.',
...
}
And all questions (i.e. the right side of the Argilla annotation view) are provided like so:
>>> sample
{
...
'new-response': [{'status': 'submitted',
'value': 'Virgin Australia commenced services on 31 August '
'2000 as Virgin Blue, with two aircraft on a '
'single route.',
'user-id': ...}],
'new-response-suggestion': None,
'new-response-suggestion-metadata': {'agent': None,
'score': None,
'type': None},
...
}
We can now define our formatting function, which should return the chosen-rejected-pairs as tuples.
from typing import Any, Dict, Iterator, Tuple

from argilla.feedback import TrainingTask

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> Iterator[Tuple[str, str]]:
    # Our annotators were asked to provide new responses, which we assume are better than the originals
    og_instruction = sample["original-instruction"]
    og_context = sample["original-context"]
    og_response = sample["original-response"]
    rejected = template.format(instruction=og_instruction, context=og_context, response=og_response)

    for instruction, context, response in zip(sample["new-instruction"], sample["new-context"], sample["new-response"]):
        if response["status"] == "submitted":
            chosen = template.format(
                instruction=instruction["value"],
                context=context["value"],
                response=response["value"],
            )
            if chosen != rejected:
                yield chosen, rejected

task = TrainingTask.for_reward_modeling(formatting_func=formatting_func)
You can use FeedbackDataset.prepare_for_training to observe the dataset created with this task, for example with the "trl" framework:
dataset = feedback_dataset.prepare_for_training(framework="trl", task=task)
"""
>>> dataset
Dataset({
features: ['chosen', 'rejected'],
num_rows: 2872
})
>>> dataset[2772]
{
'chosen': '### Instruction: Answer based on the text: Is Leucascidae a sponge\n\n'
'### Context: Leucascidae is a family of calcareous sponges in the order Clathrinida.\n\n'
'### Response: Yes',
'rejected': '### Instruction: Is Leucascidae a sponge\n\n'
'### Context: Leucascidae is a family of calcareous sponges in the order Clathrinida.[1]\n\n'
'### Response: Leucascidae is a family of calcareous sponges in the order Clathrinida.'}
"""
Looks great!
ArgillaTrainer
Let's now use the ArgillaTrainer to train a reward model with this task.
from argilla.feedback import ArgillaTrainer
trainer = ArgillaTrainer(
dataset=feedback_dataset,
task=task,
framework="trl",
model="distilroberta-base",
)
trainer.train(output_dir="reward_model")
Inference
Let's try out the trained model in practice.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("reward_model")
tokenizer = AutoTokenizer.from_pretrained("reward_model")

def get_score(model, tokenizer, text):
    # Tokenize the input sequences
    inputs = tokenizer(text, truncation=True, padding="max_length", max_length=512, return_tensors="pt")

    # Perform forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Extract the logits
    return outputs.logits[0, 0].item()
# Example usage
prompt = "Is a toad a frog?"
context = "Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads."
good_response = "Yes"
bad_response = "Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\""
example_good = template.format(instruction=prompt, context=context, response=good_response)
example_bad = template.format(instruction=prompt, context=context, response=bad_response)
score = get_score(model, tokenizer, example_good)
print(score)
# >> 5.478324890136719
score = get_score(model, tokenizer, example_bad)
print(score)
# >> 2.2948970794677734
As expected, the good response has a higher score than the worse one.
Proximal Policy Optimization#
Background#
The TRL library implements the last step of RLHF: Proximal Policy Optimization (PPO). It requires prompts, which are then fed to the model that is being fine-tuned. Its results are passed to a reward model. Lastly, the prompts, responses and rewards are used to update the model through reinforcement learning.
Note
PPO requires a trained supervised fine-tuned model and a reward model to work. Have a look at the task outlines above to train your own.
The data required for this step needs to be comparison data that showcases the preference for the generated prompts. Our curated Dolly dataset is a good example, where we assume that the updated response is preferred over the older one. Another good example is the Anthropic RLHF dataset.
Note
The original Dolly dataset contains a lot of reference indicators such as "[1]", which cause the model to hallucinate and incorrectly create references.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]
### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
In the case of training with PPO, we then use the prompt and context data and correct the generated responses from the SFT model by using the reward model. Hence, we will need to format the following text.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
### Response:
{to be generated by SFT model}
Training#
We will use our curated Dolly dataset, as introduced in the background section above.
import argilla as rg
feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")
Data Preparation
As usual, we start with a task with a formatting function. For PPO, the formatting function only returns the prompts as text, formatted according to the template.
from typing import Dict, Any, Iterator

from argilla.feedback import TrainingTask

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> Iterator[str]:
    for instruction, context in zip(sample["new-instruction"], sample["new-context"]):
        if instruction["status"] == "submitted":
            yield template.format(
                instruction=instruction["value"],
                context=context["value"][:500],
                response=""
            ).strip()

task = TrainingTask.for_proximal_policy_optimization(formatting_func=formatting_func)
Like before, we can observe the resulting dataset:
dataset = feedback_dataset.prepare_for_training(framework="trl", task=task)
"""
>>> dataset
Dataset({
features: ['id', 'query'],
num_rows: 15015
})
>>> dataset[922]
{'id': 922, 'query': '### Instruction: Is beauty objective or subjective?\n\n### Context: \n\n### Response:'}
"""
ArgillaTrainer
Rather than using this dataset, we will use the task directly with our FeedbackDataset in the ArgillaTrainer. PPO requires us to specify the reward_model, and allows us to specify some other useful values:
- reward_model: A sentiment-analysis pipeline with the reward model. This generates the reward for a prompt + response.
- length_sampler_kwargs: A dict with min_value and max_value keys, indicating the lower and upper bound on the number of tokens the fine-tuned model should generate while fine-tuning.
- generation_kwargs: The keyword arguments passed to the generate method of the fine-tuned model.
- config: A trl.PPOConfig instance with many useful parameters such as learning_rate and batch_size.
from argilla.feedback import ArgillaTrainer
from transformers import pipeline
from trl import PPOConfig
trainer = ArgillaTrainer(
dataset=feedback_dataset,
task=task,
framework="trl",
model="gpt2",
)
reward_model = pipeline("sentiment-analysis", model="reward_model")
trainer.update_config(
reward_model=reward_model,
length_sampler_kwargs={"min_value": 32, "max_value": 256},
generation_kwargs={
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
},
config=PPOConfig(batch_size=16)
)
trainer.train(output_dir="ppo_model")
Inference
Once we're done training, we can load this model and generate with it!
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ppo_model")
tokenizer = AutoTokenizer.from_pretrained("ppo_model")
tokenizer.pad_token = tokenizer.eos_token
inputs = template.format(
instruction="Is a toad a frog?",
context="Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads.",
response=""
).strip()
encoding = tokenizer([inputs], return_tensors="pt")
outputs = model.generate(**encoding, max_new_tokens=30)
output_text = tokenizer.decode(outputs[0])
print(output_text)
# Yes it is, toads are a sub-classification of frogs.
Direct Preference Optimization#
Background#
The TRL library implements another method of incorporating human feedback into LLMs, called Direct Preference Optimization (DPO). This approach skips the step of training a separate reward model and instead uses the preference data directly during training as a measure for optimizing human feedback.
Note
DPO requires a trained supervised fine-tuned model to function. Have a look at the task outlines above to train your own.
The data required for this step needs to be comparison data that showcases the preference for the generated prompts. Our curated Dolly dataset is a good example, where we assume that the updated response is preferred over the older one. Another good example is the Anthropic RLHF dataset.
Note
The original Dolly dataset contains a lot of reference indicators such as "[1]", which cause the model to hallucinate and incorrectly create references.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]
### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
In the case of training with PPO, we would use the prompt and context data and correct the generated responses from the SFT model by using a reward model, which requires formatting the following text.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
### Response:
{to be generated by SFT model}
In the DPO approach, we instead infer the reward from the formatted prompt and the preference data provided in the form of prompt-chosen-rejected-pairs.
Training#
We will use our curated Dolly dataset, as introduced in the background section above.
import argilla as rg
feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")
Data Preparation
We will start with a basic example of a formatting function. For DPO, it should return prompt-chosen-rejected-pairs, where the prompt is formatted according to the template.
from typing import Dict, Any, Iterator, Tuple

from argilla.feedback import TrainingTask

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> Iterator[Tuple[str, str, str]]:
    # Our annotators were asked to provide new responses, which we assume are better than the originals
    og_instruction = sample["original-instruction"]
    og_context = sample["original-context"]
    rejected = sample["original-response"]
    prompt = template.format(instruction=og_instruction, context=og_context, response="")

    for instruction, context, response in zip(sample["new-instruction"], sample["new-context"], sample["new-response"]):
        if response["status"] == "submitted":
            chosen = response["value"]
            if chosen != rejected:
                yield prompt, chosen, rejected

task = TrainingTask.for_direct_preference_optimization(formatting_func=formatting_func)
ArgillaTrainer
We will use the task directly with our FeedbackDataset in the ArgillaTrainer. In contrast to PPO, we don't need to specify any reward model, because this preference modeling is inferred internally by the DPO algorithm.
from argilla.feedback import ArgillaTrainer
trainer = ArgillaTrainer(
dataset=feedback_dataset,
task=task,
framework="trl",
model="gpt2",
)
trainer.train(output_dir="dpo_model")
Inference
Once we're done training, we can load this model and generate with it!
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("dpo_model")
tokenizer = AutoTokenizer.from_pretrained("dpo_model")
tokenizer.pad_token = tokenizer.eos_token
inputs = template.format(
instruction="Is a toad a frog?",
context="Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads.",
response=""
).strip()
encoding = tokenizer([inputs], return_tensors="pt")
outputs = model.generate(**encoding, max_new_tokens=30)
output_text = tokenizer.decode(outputs[0])
print(output_text)
# Yes it is, toads are a sub-classification of frogs.
Chat Completion#
Background#
With the rise of chat-oriented models under OpenAI's ChatGPT, we have seen a lot of interest in using LLMs for chat-oriented tasks. The main difference between chat-oriented models and other LLMs is that they are trained on datasets in a different format. Instead of using a dataset of prompts and responses, they are trained on a dataset of conversations. This allows them to generate more conversational responses. And OpenAI does support fine-tuning LLMs for chat completion use cases. More information can be found at https://openai.com/blog/gpt-3-5-turbo-fine-tuning-and-api-updates.
User: Hello, how are you?
Agent: I am doing great!
User: When did Virgin Australia start operating?
Agent: Virgin Australia commenced services on 31 August 2000 as Virgin Blue.
User: That is incorrect. I believe it was 2001.
Agent: You are right, it was 2001.
### Instruction
When did Virgin Australia start operating?
### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
### Response:
{to be generated by SFT model}
Training#
We will use our curated Dolly dataset, as introduced in the background section above.
import argilla as rg
feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")
Data Preparation
dataset = rg.FeedbackDataset.from_huggingface("argilla/customer_assistant")
We will start with a basic example of a formatting function. For chat completion, it should return chat-turn-role-text values, where the prompt is formatted according to the template. We need this split because each conversation chain needs to be traceable in the correct order and according to the user role that might be speaking.
Note
We infer so-called messages because OpenAI expects this output format, but this might differ for other scenarios.
from typing import List, Tuple, Union
from uuid import uuid4

from argilla.feedback import TrainingTask

# adaptation from LlamaIndex's TEXT_QA_PROMPT_TMPL_MSGS[1].content
user_message_prompt = """Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge but keeping your Argilla Cloud assistant style, answer the query.
Query: {query_str}
Answer:
"""
# adaptation from LlamaIndex's TEXT_QA_SYSTEM_PROMPT
system_prompt = """You are an expert customer service assistant for the Argilla Cloud product that is trusted around the world.
Always answer the query using the provided context information, and not prior knowledge.
Some rules to follow:
1. Never directly reference the given context in your answer.
2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines.
"""

def formatting_func(sample: dict) -> Union[Tuple[str, str, str, str], List[Tuple[str, str, str, str]]]:
    if sample["response"]:
        chat = str(uuid4())
        user_message = user_message_prompt.format(context_str=sample["context"], query_str=sample["user-message"])
        yield [
            (chat, "0", "system", system_prompt),
            (chat, "1", "user", user_message),
            (chat, "2", "assistant", sample["response"][0]["value"])
        ]

task = TrainingTask.for_chat_completion(formatting_func=formatting_func)
ArgillaTrainer
We will use the task directly with our FeedbackDataset in the ArgillaTrainer. The only configurable parameter is n_epochs, but this will also be optimized internally.
from argilla.feedback import ArgillaTrainer
trainer = ArgillaTrainer(
dataset=feedback_dataset,
task=task,
framework="openai",
)
trainer.train(output_dir="chat-completion")
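If you do want to set the number of epochs yourself, it can be overridden before training via update_config, mirroring the OpenAI.FineTune defaults shown in the training configs overview above (a sketch):

# Override only the hyperparameter you care about; everything else keeps its default.
trainer.update_config(hyperparameters={"n_epochs": 3})
trainer.train(output_dir="chat-completion")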
Inference
Once the training is done, we can use the model directly, but to do so we need to use the openai framework. Hence, we recommend taking a look at their docs.
import openai
completion = openai.ChatCompletion.create(
model="ft:gpt-3.5-turbo:my-org:custom_suffix:id",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
]
)
Other datasets#
Note
The record classes covered in this section correspond to three datasets: DatasetForTextClassification, DatasetForTokenClassification, and DatasetForText2Text. These will be deprecated in Argilla 2.0 and replaced by the fully configurable FeedbackDataset class. Not sure which dataset to use? Check out our section on choosing a dataset.
ArgillaTrainer#
The ArgillaTrainer is a wrapper around many of our favorite NLP libraries. It provides a very intuitive abstract representation to facilitate simple training workflows using decent default pre-set configurations, without having to worry about any data transformations from Argilla.
Supported Frameworks#

| Framework/Task | Text Classification | Token Classification | Text2Text |
|---|---|---|---|
| OpenAI | ✔️ |  | ✔️ |
| SetFit | ✔️ |  |  |
| spaCy | ✔️ | ✔️ |  |
| Transformers | ✔️ | ✔️ |  |
| PEFT | ✔️ | ✔️ |  |
| SpanMarker |  | ✔️ |  |
Training Configs#
The trainer also has an ArgillaTrainer.update_config() method, which maps a dict with **kwargs to the respective framework. So, these parameters can be derived from the underlying framework that was used to initialize the trainer. Underneath, you can find an overview of these variables for the supported frameworks.
Note
Note that you don't need to pass all of these variables directly, and that the values below are their default configurations.
# `OpenAI.FineTune`
trainer.update_config(
training_file = None,
validation_file = None,
model = "gpt-3.5-turbo-0613",
hyperparameters = {"n_epochs": 1},
suffix = None
)
# `OpenAI.FineTune` (legacy)
trainer.update_config(
training_file = None,
validation_file = None,
model = "curie",
n_epochs = 2,
batch_size = None,
learning_rate_multiplier = 0.1,
prompt_loss_weight = 0.1,
compute_classification_metrics = False,
classification_n_classes = None,
classification_positive_class = None,
classification_betas = None,
suffix = None
)
# `setfit.SetFitModel`
trainer.update_config(
pretrained_model_name_or_path = "all-MiniLM-L6-v2",
force_download = False,
resume_download = False,
proxies = None,
token = None,
cache_dir = None,
local_files_only = False
)
# `setfit.SetFitTrainer`
trainer.update_config(
metric = "accuracy",
num_iterations = 20,
num_epochs = 1,
learning_rate = 2e-5,
batch_size = 16,
seed = 42,
use_amp = True,
warmup_proportion = 0.1,
distance_metric = "BatchHardTripletLossDistanceFunction.cosine_distance",
margin = 0.25,
samples_per_label = 2
)
# `spacy.training`
trainer.update_config(
dev_corpus = "corpora.dev",
train_corpus = "corpora.train",
seed = 42,
gpu_allocator = 0,
accumulate_gradient = 1,
patience = 1600,
max_epochs = 0,
max_steps = 20000,
eval_frequency = 200,
frozen_components = [],
annotating_components = [],
before_to_disk = None,
before_update = None
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
pretrained_model_name_or_path = "distilbert-base-uncased",
force_download = False,
resume_download = False,
proxies = None,
token = None,
cache_dir = None,
local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.9,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
# `peft.LoraConfig`
trainer.update_config(
r=8,
target_modules=None,
lora_alpha=16,
lora_dropout=0.1,
fan_in_fan_out=False,
bias="none",
inference_mode=False,
modules_to_save=None,
init_lora_weights=True,
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
pretrained_model_name_or_path = "distilbert-base-uncased",
force_download = False,
resume_download = False,
proxies = None,
token = None,
cache_dir = None,
local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.9,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
# `SpanMarkerConfig`
trainer.update_config(
pretrained_model_name_or_path = "distilbert-base-cased",
model_max_length = 256,
marker_max_length = 128,
entity_max_length = 8,
)
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.9,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
Tasks#
In this section, we will explore the tasks of Text Classification, Token Classification, and Text2Text. We will briefly describe what each task involves and the steps to train a model and make predictions.
Text Classification#
Background#
Text classification is a widely used NLP task where labels are assigned to text. Major companies rely on it for various applications. Sentiment analysis, a popular form of text classification, assigns labels like 🙂 positive, 🙁 negative, or 😐 neutral to text. Additionally, we distinguish between single- and multi-label text classification.
Single-label text classification refers to the task of assigning a single category or label to a given text sample. Each text is associated with only one predefined class or category. For example, in sentiment analysis, a single-label text classification task would involve assigning labels such as "positive", "negative", or "neutral" to texts based on their sentiment.
"The help for my application of a new card and mortgage was great", "positive"
Multi-label text classification is generally more complicated than single-label classification due to the challenge of determining and predicting multiple relevant labels for each text. It finds applications in various domains, including document tagging, topic labeling, and content recommendation systems. For example, in customer care, a multi-label text classification task would involve assigning topics such as "new_card", "mortgage", or "opening_hours" to texts based on their content.
Tip
For a multi-label scenario it is recommended to add some examples without any labels to improve model performance.
"The help for my application of a new card and mortgage was great", ["new_card", "mortgage"]
Training#
from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask
dataset = FeedbackDataset.from_huggingface(
repo_id="argilla/emotion"
)
task = TrainingTask.for_text_classification(
text=dataset.field_by_name("text"),
label=dataset.question_by_name("label"),
)
trainer = ArgillaTrainer(
dataset=dataset,
task=task,
framework="setfit"
)
trainer.update_config(num_iterations=1)
trainer.train(output_dir="my_setfit_model")
trainer.predict("This is awesome!")
Token Classification#
Background#
Token classification is a pivotal concept within NLP, entailing the assignment of specific labels to individual words or tokens in a given text. These labels can encompass a variety of linguistic or semantic attributes, such as part-of-speech tags, named entities (including names of people, organizations, or locations), or sentiment indicators (expressing positivity, negativity, or neutrality). This process is integral to numerous NLP applications, facilitating the extraction of valuable insights from textual data.
Training#
import argilla as rg
from datasets import load_dataset
from argilla.training import ArgillaTrainer
dataset_rg = rg.DatasetForTokenClassification.from_datasets(
dataset=load_dataset("conll2003", split="train[:100]"),
tags="ner_tags",
)
rg.log(dataset_rg, name="conll2003", workspace="admin")
trainer = ArgillaTrainer(
name="conll2003",
workspace="admin",
framework="spacy",
train_size=0.8
)
trainer.update_config(num_train_epochs=2)
trainer.train(output_dir="my_spacy_model")
records = trainer.predict("The ArgillaTrainer is great!", as_argilla_records=True)
rg.log(records=records, name="conll2003", workspace="admin")
Text2Text#
Background#
The Text2Text task in NLP represents a framework that takes a piece of text as input and transforms it into another piece of text. Rather than treating different NLP challenges as isolated problems, Text2Text seeks to create versatile solutions by framing them as sequence-to-sequence transformations. In this approach, both the input and output are treated as textual sequences, and their lengths may vary.
Training#
import argilla as rg
from datasets import load_dataset
from argilla.training import ArgillaTrainer
dataset_rg = rg.DatasetForText2Text.from_datasets(
dataset=load_dataset("opus_books", "en-fr", split="train[:100]"),
tags="ner_tags",
)
rg.log(dataset_rg, name="opus_books", workspace="admin")
trainer = ArgillaTrainer(
name="opus_books",
workspace="admin",
framework="openAI",
train_size=0.8
)
trainer.update_config(max_epochs=2)
trainer.train(output_dir="my_openAI_model")
records = trainer.predict("The ArgillaTrainer is great!", as_argilla_records=True)
rg.log(records=records, name="opus_books", workspace="admin")
Other options#
Prepare for training#
If you want to train a model we provide a handy method to prepare your dataset: DatasetFor*.prepare_for_training(). It will return a Hugging Face dataset, a spaCy DocBin or a SparkNLP-formatted DataFrame, optimized for the training process with the Hugging Face Trainer, the spaCy CLI or the SparkNLP API.
A train-test split can be included directly in prepare_for_training by passing the train_size and test_size parameters.
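For example, an 80/20 split might look like the sketch below; the exact return type (e.g. a DatasetDict with train and test splits for Hugging Face frameworks) depends on the chosen framework:

import argilla as rg

dataset_rg = rg.load("<my_dataset>")
# Assumption for illustration: with both sizes set, two splits are returned.
train_test = dataset_rg.prepare_for_training(
    framework="transformers", train_size=0.8, test_size=0.2
)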
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="openai", train_size=1)
# [{'prompt': 'My title', 'completion': ' My content'}]
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="autotrain", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="setfit", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
import spacy
nlp = spacy.blank("en")
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="spacy", lang=nlp, train_size=1)
# <spacy.tokens._serialize.DocBin object at 0x280613af0>
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="transformers", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="peft", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="span_marker", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="spark-nlp", train_size=1)
# <pd.DataFrame>
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="trl", task=..., train_size=1)
CLI support#
We also provide CLI support for the ArgillaTrainer. This can be used when, for example, executing training on an external machine. Note that --update-config-kwargs always uses the update_config() method of the corresponding class. Hence, you should take this into account when configuring training via the CLI command, by passing a JSON-serializable string.
Usage: python -m argilla train [OPTIONS] COMMAND [ARGS]...
Starts the ArgillaTrainer.
Options:
--name TEXT The name of the dataset to be used for training. [default: None]
--framework [transformers|peft|setfit|spacy|spacy-transformers|span_marker|spark-nlp|openai|trl|trlx|sentence-transformers]
                                The framework to be used for training. [default: None]
--workspace TEXT The workspace to be used for training. [default: None]
--limit INTEGER The number of record to be used. [default: None]
--query TEXT The query to be used. [default: None]
--model TEXT The modelname or path to be used for training. [default: None]
--train-size FLOAT The train split to be used. [default: 1.0]
--seed INTEGER The random seed number. [default: 42]
--device INTEGER The GPU id to be used for training. [default: -1]
--output-dir TEXT Output directory for the saved model. [default: model]
--update-config-kwargs TEXT update_config() kwargs to be passed as a dictionary. [default: {}]
--help Show this message and exit.