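Training#
This section describes the trainers available in Argilla:
Base Trainer: internal mechanism for handling trainers
SetFit Trainer: internal mechanism for handling the training logic of SetFit models
OpenAI Trainer: internal mechanism for handling the training logic of OpenAI models
PEFT (LoRA) Trainer: internal mechanism for handling the training logic of PEFT (LoRA) models
spaCy Trainer: internal mechanism for handling the training logic of spaCy models
Transformers Trainer: internal mechanism for handling the training logic of Transformers models
SpanMarker Trainer: internal mechanism for handling the training logic of SpanMarker models
TRL Trainer: internal mechanism for handling the training logic of TRL models
SentenceTransformer Trainer: internal mechanism for handling the training logic of SentenceTransformer models
Base Trainer#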
- class argilla.training.base.ArgillaTrainerSkeleton(name, dataset, record_class, workspace=None, multi_label=False, settings=None, model=None, seed=None, *arg, **kwargs)#
- Parameters:
name (str) –
record_class (Union[TokenClassificationRecord, Text2TextRecord, TextClassificationRecord]) –
workspace (Optional[str]) –
multi_label (bool) –
settings (Union[TextClassificationSettings, TokenClassificationSettings]) –
model (str) –
seed (int) –
- get_model()#
Returns the model.
- get_model_card_data(card_data_kwargs)#
Generates a FrameworkCardData instance from which the model card is generated.
- Parameters:
card_data_kwargs (Dict[str, Any]) –
- Return type:
- get_model_kwargs()#
Returns the model kwargs.
- get_tokenizer()#
Returns the tokenizer.
- get_trainer()#
Returns the trainer.
- get_trainer_kwargs()#
Returns the training kwargs.
- abstract init_model()#
Initializes the model.
- abstract init_training_args()#
Initializes the training arguments.
- abstract predict(text, as_argilla_records=True, **kwargs)#
Predicts the label of the text.
- Parameters:
text (Union[List[str], str]) –
as_argilla_records (bool) –
- push_to_huggingface(repo_id, **kwargs)#
Uploads the model to the [Hugging Face Hub](https://hugging-face.cn/docs/hub/models-the-hub).
- Parameters:
repo_id (str) –
- Return type:
Optional[str]
- abstract save(output_dir)#
Saves the model to the specified path.
- Parameters:
output_dir (str) –
- abstract train(output_dir=None)#
Trains the model.
- Parameters:
output_dir (Optional[str]) –
- abstract update_config(*args, **kwargs)#
Updates the configuration of the trainer; the accepted parameters depend on the trainer subclass.
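The abstract methods listed above can be pictured with a minimal, self-contained sketch. It does not import Argilla: `TrainerSkeleton` and `KeywordTrainer` are hypothetical stand-ins that only mirror the documented method names, and the keyword-matching "model" is an assumption made purely for illustration.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union


class TrainerSkeleton(ABC):
    """Hypothetical stand-in that mirrors the abstract methods listed above."""

    def __init__(self, name: str, model: Optional[str] = None, seed: Optional[int] = None):
        self.name = name
        self.model_name = model
        self.seed = seed
        self.trainer_kwargs: Dict[str, Any] = {}
        self.init_training_args()
        self.init_model()

    @abstractmethod
    def init_model(self) -> None: ...

    @abstractmethod
    def init_training_args(self) -> None: ...

    @abstractmethod
    def update_config(self, *args, **kwargs) -> None: ...

    @abstractmethod
    def train(self, output_dir: Optional[str] = None) -> None: ...

    @abstractmethod
    def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, **kwargs): ...

    @abstractmethod
    def save(self, output_dir: str) -> None: ...


class KeywordTrainer(TrainerSkeleton):
    """Toy subclass: a text is 'positive' if it contains a known keyword."""

    def init_model(self) -> None:
        self.keywords: List[str] = []

    def init_training_args(self) -> None:
        self.trainer_kwargs = {"lowercase": True}

    def update_config(self, **kwargs) -> None:
        self.trainer_kwargs.update(kwargs)

    def train(self, output_dir: Optional[str] = None) -> None:
        self.keywords = ["good", "great"]  # pretend these were learned
        if output_dir is not None:
            self.save(output_dir)

    def predict(self, text, as_argilla_records=True, **kwargs):
        # as_argilla_records is accepted but ignored in this sketch.
        texts = [text] if isinstance(text, str) else list(text)
        if self.trainer_kwargs["lowercase"]:
            texts = [t.lower() for t in texts]
        labels = ["positive" if any(k in t for k in self.keywords) else "negative" for t in texts]
        return labels[0] if isinstance(text, str) else labels

    def save(self, output_dir: str) -> None:
        raise NotImplementedError("persistence is out of scope for this sketch")
```

Instantiating `TrainerSkeleton` directly raises `TypeError`, because its abstract methods are not implemented; only a subclass that defines all of them can be created.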
- class argilla.client.feedback.integrations.huggingface.model_card.FrameworkCardData(*args, **kwargs)#
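Parent class for generating the variables that are added to the ModelCard.
Each framework inherits from this class and updates it accordingly.
- Parameters: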
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Optional[Framework]) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
- to_dict()#
Main method for generating the variables that are written to the model card.
- Return type:
Dict[str, Any]
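As a rough illustration of a card-data container with a `to_dict()` method, the sketch below uses a plain dataclass. `CardDataSketch` is a hypothetical, heavily reduced stand-in, and skipping private and empty fields is an assumption of this example, not documented behaviour of `FrameworkCardData`. The `default_factory` fields correspond to the `<factory>` defaults that appear in the card-data class signatures on this page.

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class CardDataSketch:
    """Hypothetical, reduced stand-in for a card-data container."""

    model_name: Optional[str] = None
    dataset_name: Optional[str] = None
    # Mutable defaults need a factory so that instances do not share one object.
    tags: List[str] = field(default_factory=list)
    framework_kwargs: Dict[str, Any] = field(default_factory=dict)
    _is_on_huggingface: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the public, non-empty fields as template variables."""
        return {
            key: value
            for key, value in asdict(self).items()
            # Assumption of this sketch: private and empty fields are skipped.
            if not key.startswith("_") and value not in (None, [], {})
        }
```

For example, `CardDataSketch(model_name="my-model", tags=["setfit"]).to_dict()` contains only the `model_name` and `tags` keys.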
SetFit Trainer#
- class argilla.training.setfit.ArgillaSetFitTrainer(*args, **kwargs)#
- init_model()#
Initializes the model.
- init_training_args()#
Initializes the training arguments.
- Return type:
None
- predict(text, as_argilla_records=True, **kwargs)#
Takes a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to classify.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, the predictions are returned as strings. Defaults to True.
- Returns:
A list of predictions.
- save(output_dir)#
Saves the model to the specified path, and saves the label2id and id2label dictionaries to the same path.
- Parameters:
output_dir (str) – The path where the model is saved.
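The label-mapping part of `save` can be sketched as follows. This is not Argilla's implementation: the helper name `save_label_maps` and the file names `label2id.json` and `id2label.json` are assumptions chosen for illustration, and the model weights themselves are left out.

```python
import json
from pathlib import Path
from typing import Dict, List


def save_label_maps(labels: List[str], output_dir: str) -> Dict[str, int]:
    """Write label2id and id2label mappings as JSON files into output_dir."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    label2id = {label: idx for idx, label in enumerate(labels)}
    id2label = {idx: label for label, idx in label2id.items()}

    # JSON object keys are always strings, so integer ids are stored as "0", "1", ...
    (out / "label2id.json").write_text(json.dumps(label2id), encoding="utf-8")
    (out / "id2label.json").write_text(json.dumps(id2label), encoding="utf-8")
    return label2id
```

When the files are read back, the keys of `id2label` are strings and need converting to `int` before they can be used as class indices.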
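- train(output_dir=None)#
Creates a SetFitModel object from a pretrained model, then creates a SetFitTrainer object with that model, and then trains the model.
- Parameters:
output_dir (Optional[str]) –
- update_config(**kwargs)#
Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
- Return type:
None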
- class argilla.client.feedback.integrations.huggingface.model_card.SetFitModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.SETFIT: 'setfit'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, tokenizer: 'PreTrainedTokenizer' = '')#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
tokenizer (PreTrainedTokenizer) –
OpenAI Trainer#
- class argilla.training.openai.ArgillaOpenAITrainer(*args, **kwargs)#
- init_model()#
Initializes the model.
- Return type:
None
- init_training_args(training_file=None, validation_file=None, model='curie', n_epochs=None, batch_size=None, learning_rate_multiplier=0.1, prompt_loss_weight=0.1, compute_classification_metrics=False, classification_n_classes=None, classification_positive_class=None, classification_betas=None, suffix=None, hyperparameters=None)#
Initializes the training arguments.
- Parameters:
training_file (Optional[str]) –
validation_file (Optional[str]) –
model (str) –
n_epochs (Optional[int]) –
batch_size (Optional[int]) –
learning_rate_multiplier (float) –
prompt_loss_weight (float) –
compute_classification_metrics (bool) –
classification_n_classes (Optional[int]) –
classification_positive_class (Optional[str]) –
classification_betas (Optional[list]) –
suffix (Optional[str]) –
hyperparameters (Optional[dict]) –
- Return type:
None
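One plausible way to turn such a signature into a set of training arguments is to collect the values into a dictionary and leave out anything that was not set. The sketch below is hypothetical: `build_training_args` is not an Argilla function, it covers only a subset of the parameters listed above, and dropping `None` values is an assumption rather than documented behaviour. The defaults are copied from the signature shown above.

```python
from typing import Any, Dict, Optional


def build_training_args(
    training_file: Optional[str] = None,
    validation_file: Optional[str] = None,
    model: str = "curie",
    n_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    learning_rate_multiplier: float = 0.1,
    suffix: Optional[str] = None,
) -> Dict[str, Any]:
    """Collect the arguments into one dict, omitting those left at None."""
    candidate = {
        "training_file": training_file,
        "validation_file": validation_file,
        "model": model,
        "n_epochs": n_epochs,
        "batch_size": batch_size,
        "learning_rate_multiplier": learning_rate_multiplier,
        "suffix": suffix,
    }
    # Assumption of this sketch: unset (None) values are not forwarded.
    return {key: value for key, value in candidate.items() if value is not None}
```

Calling `build_training_args(training_file="file-abc", n_epochs=3)` yields a dictionary with four keys: the two arguments passed in plus the `model` and `learning_rate_multiplier` defaults.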
- predict(text, as_argilla_records=True, **kwargs)#
Takes a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to classify.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, the predictions are returned as strings. Defaults to True.
- Returns:
A list of predictions.
- Return type:
Union[List, str]
- save(*arg, **kwargs)#
Saves the model to the specified path, and saves the label2id and id2label dictionaries to the same path.
- Parameters:
output_dir (str) – The path where the model is saved.
- Return type:
None
- train(output_dir=None)#
Creates an openai.FineTune object from a pretrained model and sends the data to fine-tune it.
- Parameters:
output_dir (Optional[str]) –
- Return type:
None
- update_config(**kwargs)#
Updates the model_kwargs dictionary with the keyword arguments passed to the update_config function.
- class argilla.client.feedback.integrations.huggingface.model_card.OpenAIModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.OPENAI: 'openai'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>)#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
PEFT (LoRA) Trainer#
- class argilla.training.peft.ArgillaPeftTrainer(*args, **kwargs)#
- init_model(new=False)#
Initializes the model.
- Parameters:
new (bool) –
- init_training_args()#
Initializes the training arguments.
- predict(text, as_argilla_records=True, **kwargs)#
Takes a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to classify.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, the predictions are returned as strings. Defaults to True.
- Returns:
A list of predictions.
- save(output_dir)#
Saves the model to the specified path, and saves the label2id and id2label dictionaries to the same path.
- Parameters:
output_dir (str) – The path where the model is saved.
- update_config(**kwargs)#
Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
- class argilla.client.feedback.integrations.huggingface.model_card.PeftModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.PEFT: 'peft'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, tokenizer: 'PreTrainedTokenizer' = '')#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
tokenizer (PreTrainedTokenizer) –
spaCy Trainer#
- class argilla.training.spacy.ArgillaSpaCyTrainer(freeze_tok2vec=False, **kwargs)#
- Parameters:
freeze_tok2vec (bool) –
- init_training_args()#
Generates the spaCy config file that is used for training.
- Return type:
None
- class argilla.training.spacy.ArgillaSpaCyTransformersTrainer(update_transformer=True, **kwargs)#
- Parameters:
update_transformer (bool) –
- init_training_args()#
Generates the spaCy config file that is used for training.
- Return type:
None
- class argilla.training.spacy._ArgillaSpaCyTrainerBase(language=None, gpu_id=-1, model=None, optimize='efficiency', *args, **kwargs)#
- Parameters:
language (Optional[str]) –
gpu_id (Optional[int]) –
model (Optional[str]) –
optimize (Literal['efficiency', 'accuracy']) –
- init_model()#
Initializes the model.
- predict(text, as_argilla_records=True, **kwargs)#
Predicts the label of the given text using the trained pipeline.
- Parameters:
text (Union[List[str], str]) – A str or a List[str] with the text to predict the labels for.
as_argilla_records (bool) – A bool indicating whether to return the predictions as Argilla records or as dicts. Defaults to True.
- Returns:
Either a dict or a List[dict] with the predictions, or a BaseModel or a List[BaseModel] if as_argilla_records is True.
- Return type:
Union[Dict[str, Any], List[Dict[str, Any]], BaseModel, List[BaseModel]]
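The four possible return shapes can be illustrated with a small stand-alone sketch. It runs neither spaCy nor Argilla: `RecordSketch` is a hypothetical stand-in for a record class, the length-based "pipeline" is a placeholder, and the rule that a single string yields a single result while a list yields a list is an assumption of this example.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class RecordSketch:
    """Hypothetical stand-in for an Argilla record."""

    text: str
    prediction: str


def predict_sketch(
    text: Union[List[str], str], as_argilla_records: bool = True
) -> Union[Dict[str, Any], List[Dict[str, Any]], RecordSketch, List[RecordSketch]]:
    single = isinstance(text, str)
    texts = [text] if single else list(text)

    # Placeholder "pipeline": label by text length instead of running spaCy.
    raw = [{"text": t, "label": "long" if len(t) > 20 else "short"} for t in texts]

    results = [RecordSketch(r["text"], r["label"]) if as_argilla_records else r for r in raw]
    # Assumption: the output mirrors the input shape (one item or a list).
    return results[0] if single else results
```

With these assumptions, `predict_sketch("hi", as_argilla_records=False)` returns one dict, while `predict_sketch(["hi"])` returns a list containing one `RecordSketch`.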
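- save(output_dir)#
Saves the trained pipeline to disk.
- Parameters:
output_dir (str) – A str with the path where the pipeline is saved.
- Return type:
None
- train(output_dir=None)#
Trains the pipeline using spaCy.
- Parameters:
output_dir (Optional[str]) – A str with the path where the trained pipeline is saved. Defaults to None.
- Return type:
None
- update_config(**spacy_training_config)#
Updates the spaCy training config.
Disclaimer: currently only the training config is supported, but in the future all spaCy config values will be supported, allowing finer control over the training process. Also note that the arguments may differ between CPU and GPU training.
- Parameters:
**spacy_training_config – The spaCy training config.
- Return type:
None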
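Restricting updates to the training section can be pictured as a merge into one sub-dictionary of a nested config. The sketch below works on a plain dict and is not how Argilla or spaCy store their config internally: `update_training_config` is a hypothetical helper, and the keys `max_steps` and `dropout` serve only as examples.

```python
import copy
from typing import Any, Dict


def update_training_config(config: Dict[str, Any], **spacy_training_config: Any) -> Dict[str, Any]:
    """Return a copy of config with the 'training' section updated."""
    updated = copy.deepcopy(config)  # leave the caller's config untouched
    updated.setdefault("training", {}).update(spacy_training_config)
    return updated
```

Only the `training` section changes; other sections of the config are copied over as they are.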
- class argilla.client.feedback.integrations.huggingface.model_card.SpacyTransformersModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.SPACY_TRANSFORMERS: 'spacy-transformers'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, lang: Optional[ForwardRef('spacy.Language')] = None, gpu_id: Optional[int] = -1, optimize: Literal['efficiency', 'accuracy'] = 'efficiency', pipeline: List[str] = <factory>, update_transformer: bool = True)#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
lang (Optional[spacy.Language]) –
gpu_id (Optional[int]) –
optimize (Literal['efficiency', 'accuracy']) –
pipeline (List[str]) –
update_transformer (bool) –
Transformers Trainer#
- class argilla.training.transformers.ArgillaTransformersTrainer(*args, **kwargs)#
- init_model(new=False)#
Initializes the model.
- Parameters:
new (bool) –
- init_training_args()#
Initializes the training arguments.
- predict(text, as_argilla_records=True, **kwargs)#
Takes a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to classify.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, the predictions are returned as strings. Defaults to True.
- Returns:
A list of predictions.
- save(output_dir)#
Saves the model to the specified path, and saves the label2id and id2label dictionaries to the same path.
- Parameters:
output_dir (str) – The path where the model is saved.
- train(output_dir)#
Trains the model.
- Parameters:
output_dir (str) –
- update_config(**kwargs)#
Updates the setfit_model_kwargs and setfit_trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
- class argilla.client.feedback.integrations.huggingface.model_card.TransformersModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.TRANSFORMERS: 'transformers'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, tokenizer: 'PreTrainedTokenizer' = '')#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
tokenizer (PreTrainedTokenizer) –
SpanMarker Trainer#
- class argilla.training.span_marker.ArgillaSpanMarkerTrainer(*args, **kwargs)#
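- init_model()#
Initializes the model.
- Return type:
None
- init_training_args()#
Initializes the training arguments.
- Return type:
None
- predict(text, as_argilla_records=True, **kwargs)#
Takes a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to classify.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, the predictions are returned as strings. Defaults to True.
- Returns:
A list of predictions.
- save(output_dir)#
Saves the model to the specified path, and saves the label2id and id2label dictionaries to the same path.
- Parameters:
output_dir (str) – The path where the model is saved.
- train(output_dir)#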
Trains the model.
- Parameters:
output_dir (str) –
- update_config(**kwargs)#
Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
- Return type:
None
TRL Trainer#
- class argilla.client.feedback.integrations.huggingface.model_card.TRLModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.TRL: 'trl'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>)#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
SentenceTransformer Trainer#
- class argilla.client.feedback.training.frameworks.sentence_transformers.ArgillaSentenceTransformersTrainer(dataset, task, prepared_data=None, model=None, seed=None, train_size=1, cross_encoder=False)#
- Parameters:
dataset (FeedbackDataset) –
task (TrainingTaskForSentenceSimilarity) –
model (str) –
seed (int) –
train_size (Optional[float]) –
cross_encoder (bool) –
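- get_model_card_data(**card_data_kwargs)#
Generates the card data used for the ArgillaModelCard.
- Parameters:
card_data_kwargs – Extra arguments provided by the user when creating the ArgillaTrainer.
- Returns:
Container for the data to be written to the ArgillaModelCard.
- Return type:
- init_model()#
Initializes the model.
- Return type:
None
- init_training_args()#
Initializes the training arguments.
- Return type:
None
- predict(text, as_argilla_records=False, **kwargs)#
Predicts the similarity of the sentences.
- Parameters:
text (Union[List[List[str]], Tuple[str, List[str]]]) – The sentences to obtain the similarity from. Allowed inputs are: (1) a single sentence (as a string) together with a list of sentences to compare it against; (2) a list of sentence pairs.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, the predictions are returned as plain similarity scores. Defaults to False.
- Returns:
A list of predicted similarities.
- Return type:
List[float]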
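The two accepted input shapes can be normalised into explicit sentence pairs before scoring, as in the sketch below. It does not use sentence-transformers: `to_pairs`, `jaccard` and `predict_similarity` are hypothetical helpers, and the Jaccard token overlap is only a toy substitute for the similarity a trained model would produce.

```python
from typing import List, Tuple, Union

PairInput = Union[List[List[str]], Tuple[str, List[str]]]


def to_pairs(text: PairInput) -> List[Tuple[str, str]]:
    """Normalise both accepted input shapes into explicit sentence pairs."""
    # Shape 1: one sentence plus a list of sentences to compare it against.
    if len(text) == 2 and isinstance(text[0], str) and not isinstance(text[1], str):
        source, candidates = text
        return [(source, candidate) for candidate in candidates]
    # Shape 2: a list of sentence pairs.
    return [(first, second) for first, second in text]


def jaccard(a: str, b: str) -> float:
    """Toy token-overlap score standing in for a real model's similarity."""
    left, right = set(a.lower().split()), set(b.lower().split())
    union = left | right
    return len(left & right) / len(union) if union else 0.0


def predict_similarity(text: PairInput) -> List[float]:
    """Return one score per sentence pair, matching the List[float] return type."""
    return [jaccard(first, second) for first, second in to_pairs(text)]
```

Both `("a b", ["a b", "c d"])` and `[["a b", "a b"], ["a b", "c d"]]` describe the same two comparisons, so `predict_similarity` returns the same scores for either form.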
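- push_to_huggingface(repo_id, **kwargs)#
Uploads the model to [Hugging Face's model hub](https://hugging-face.cn/models).
The full list of parameters can be seen in the [sentence-transformers API documentation](https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.save_to_hub).
- Parameters:
repo_id (str) – The name of the repository you want to push your model and tokenizer to. It should contain your organization name when pushing to a given organization.
- Raises:
NotImplementedError – For CrossEncoder models, for which this is not currently implemented underneath.
- Return type:
None
- save(output_dir)#
Saves the model to the specified path.
- Parameters:
output_dir (str) –
- Return type:
None
- train(output_dir=None)#
Trains the model.
- Parameters:
output_dir (Optional[str]) –
- Return type:
None
- update_config(**kwargs)#
Updates the configuration of the trainer; the accepted parameters depend on the trainer subclass.
- Return type:
None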
- class argilla.client.feedback.integrations.huggingface.model_card.SentenceTransformerCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.SENTENCE_TRANSFORMERS: 'sentence-transformers'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, cross_encoder: bool = False, trainer_cls: Optional[Callable] = None)#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
cross_encoder (bool) –
trainer_cls (Optional[Callable]) –