Training#

Here we describe the trainers available in Argilla.

Base Trainer#

class argilla.training.base.ArgillaTrainerSkeleton(name, dataset, record_class, workspace=None, multi_label=False, settings=None, model=None, seed=None, *arg, **kwargs)#
Parameters:
get_model()#

Returns the model.

get_model_card_data(card_data_kwargs)#

Generates a FrameworkCardData instance from which the model card is generated.

Parameters:

card_data_kwargs (Dict[str, Any]) –

Return type:

FrameworkCardData

get_model_kwargs()#

Returns the model kwargs.

get_tokenizer()#

Returns the tokenizer.

get_trainer()#

Returns the trainer.

get_trainer_kwargs()#

Returns the training kwargs.

abstract init_model()#

Initializes the model.

abstract init_training_args()#

Initializes the training arguments.

abstract predict(text, as_argilla_records=True, **kwargs)#

Predicts the label(s) of the text.

Parameters:
  • text (Union[List[str], str]) –

  • as_argilla_records (bool) –

push_to_huggingface(repo_id, **kwargs)#

Uploads the model to the [Hugging Face Hub](https://hugging-face.cn/docs/hub/models-the-hub).

Parameters:

repo_id (str) –

Return type:

Optional[str]

abstract save(output_dir)#

Saves the model to the specified path.

Parameters:

output_dir (str) –

abstract train(output_dir=None)#

Trains the model.

Parameters:

output_dir (Optional[str]) –

abstract update_config(*args, **kwargs)#

Updates the trainer's configuration. The accepted arguments depend on the trainer subclass.
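The following minimal, self-contained sketch makes the contract above concrete by filling in the abstract methods. It does not import Argilla. `TrainerSkeletonSketch` is a hypothetical, simplified stand-in for ArgillaTrainerSkeleton, and `MajorityClassTrainer` is a toy that always predicts the most frequent training label. Only the method names and their roles come from the reference above; the rest is illustrative.

```python
import json
import os
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any, Dict, List, Optional, Union


class TrainerSkeletonSketch(ABC):
    """Hypothetical, simplified stand-in for the abstract trainer interface."""

    def __init__(self) -> None:
        self._model: Any = None
        self.trainer_kwargs: Dict[str, Any] = {}

    def get_model(self) -> Any:
        return self._model

    def get_trainer_kwargs(self) -> Dict[str, Any]:
        return self.trainer_kwargs

    @abstractmethod
    def init_model(self) -> None: ...

    @abstractmethod
    def init_training_args(self) -> None: ...

    @abstractmethod
    def update_config(self, *args: Any, **kwargs: Any) -> None: ...

    @abstractmethod
    def train(self, output_dir: Optional[str] = None) -> None: ...

    @abstractmethod
    def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, **kwargs: Any) -> Any: ...

    @abstractmethod
    def save(self, output_dir: str) -> None: ...


class MajorityClassTrainer(TrainerSkeletonSketch):
    """Toy trainer: learns the most frequent label and always predicts it."""

    def __init__(self, labels: List[str]) -> None:
        super().__init__()
        self.labels = labels
        self.init_training_args()
        self.init_model()

    def init_model(self) -> None:
        self._model = None  # nothing has been learned yet

    def init_training_args(self) -> None:
        self.trainer_kwargs = {"strategy": "majority"}

    def update_config(self, **kwargs: Any) -> None:
        self.trainer_kwargs.update(kwargs)

    def train(self, output_dir: Optional[str] = None) -> None:
        # most_common(1) breaks ties by first occurrence, so training is deterministic
        self._model = Counter(self.labels).most_common(1)[0][0]
        if output_dir is not None:
            self.save(output_dir)

    def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, **kwargs: Any) -> Any:
        # as_argilla_records is accepted for interface compatibility but ignored in this toy
        texts = [text] if isinstance(text, str) else list(text)
        preds = [self._model for _ in texts]
        return preds[0] if isinstance(text, str) else preds

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "model.json"), "w", encoding="utf-8") as f:
            json.dump({"label": self._model}, f)
```

Because the base class is an `ABC`, instantiating it directly raises `TypeError`. Each concrete trainer must supply every abstract method.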

class argilla.client.feedback.integrations.huggingface.model_card.FrameworkCardData(*args, **kwargs)#

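Parent class for generating the variables to be added to the ModelCard.

Each framework inherits from this class and updates it accordingly.

Parameters: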
  • language (Optional[Union[str, List[str]]]) –

  • license (Optional[str]) –

  • model_name (Optional[str]) –

  • model_id (Optional[str]) –

  • dataset_name (Optional[str]) –

  • dataset_id (Optional[str]) –

  • tags (Optional[List[str]]) –

  • model_summary (Optional[str]) –

  • model_description (Optional[str]) –

  • developers (Optional[str]) –

  • shared_by (Optional[str]) –

  • model_type (Optional[str]) –

  • finetuned_from (Optional[str]) –

  • repo (Optional[str]) –

  • _is_on_huggingface (bool) –

  • framework (Optional[Framework]) –

  • train_size (Optional[float]) –

  • seed (Optional[int]) –

  • framework_kwargs (Dict[str, Any]) –

  • task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –

  • output_dir (Optional[str]) –

  • library_name (Optional[str]) –

  • update_config_kwargs (Dict[str, Any]) –

to_dict()#

The main method for generating the variables that will be written to the model card.

Return type:

Dict[str, Any]
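The model-card classes on this page all follow one pattern. A parent dataclass holds the shared fields. Each framework subclass overrides the `framework` default (for example `Framework.SETFIT`) or adds its own fields (such as `tokenizer` or `cross_encoder`). The `<factory>` defaults in the signatures are how dataclasses display `field(default_factory=...)`.

The sketch below reproduces that pattern using plain strings instead of the `Framework` enum. The class names are hypothetical. Leaving underscore-prefixed fields such as `_is_on_huggingface` out of `to_dict()` is an assumption, not documented behaviour.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Optional


@dataclass
class CardDataSketch:
    """Hypothetical parent: fields shared by every framework's card data."""
    model_id: Optional[str] = None
    dataset_name: Optional[str] = None
    tags: List[str] = field(default_factory=list)  # shown as <factory> in a signature
    seed: Optional[int] = None
    framework: str = "unknown"
    _is_on_huggingface: bool = False

    def to_dict(self) -> Dict[str, Any]:
        # Collect public fields only; underscore-prefixed fields stay internal (assumption).
        return {f.name: getattr(self, f.name) for f in fields(self) if not f.name.startswith("_")}


@dataclass
class SetFitCardDataSketch(CardDataSketch):
    """Hypothetical child that only overrides the framework default."""
    framework: str = "setfit"


@dataclass
class SentenceTransformerCardDataSketch(CardDataSketch):
    """Hypothetical child that also adds a framework-specific field."""
    framework: str = "sentence-transformers"
    cross_encoder: bool = False
```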

SetFit Trainer#

class argilla.training.setfit.ArgillaSetFitTrainer(*args, **kwargs)#
init_model()#

Initializes the model.

init_training_args()#

Initializes the training arguments.

Return type:

predict(text, as_argilla_records=True, **kwargs)#

This function takes a list of strings and returns a list of predictions.

Parameters:
  • text (Union[List[str], str]) – The text(s) to classify.

  • as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as strings. Defaults to True.

Returns:

A list of predictions.
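The classification `predict` methods on this page accept either a single `str` or a `List[str]`. They switch output format based on `as_argilla_records`. The following sketch shows that contract under some assumptions. It assumes a single string yields a single prediction and a list yields a list, which is not stated above. Plain dicts stand in for Argilla record objects, and `format_predictions` and the `classify` callable are hypothetical.

```python
from typing import Any, Callable, Dict, List, Union

Prediction = Union[str, Dict[str, Any]]


def format_predictions(
    text: Union[List[str], str],
    classify: Callable[[List[str]], List[str]],
    as_argilla_records: bool = True,
) -> Union[Prediction, List[Prediction]]:
    """Hypothetical helper mirroring the documented predict() inputs and output switch."""
    single = isinstance(text, str)
    batch = [text] if single else list(text)
    labels = classify(batch)  # one batched model call for all inputs
    if as_argilla_records:
        # Plain dicts stand in for Argilla record objects in this sketch.
        out: List[Prediction] = [{"text": t, "prediction": label} for t, label in zip(batch, labels)]
    else:
        out = list(labels)
    return out[0] if single else out
```

Batching all inputs into one `classify` call is the usual choice for model inference, since per-text calls waste accelerator throughput.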

save(output_dir)#

This function saves the model to the specified path and also saves the label2id and id2label dictionaries to the same path.

Parameters:
  • output_dir (str) – The path to save the model to.
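Saving `label2id` and `id2label` next to the model lets a reloaded model translate its integer outputs back into label names. Below is a minimal standard-library sketch. The file names and the sorted-label id assignment are assumptions, not the format any of these trainers is documented to write. Note that JSON object keys are always strings, so integer ids must be converted back on load.

```python
import json
import os
from typing import Dict, List, Tuple


def build_label_maps(labels: List[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Assign integer ids to labels; sorting makes the mapping reproducible across runs."""
    label2id = {label: i for i, label in enumerate(sorted(set(labels)))}
    id2label = {i: label for label, i in label2id.items()}
    return label2id, id2label


def save_label_maps(output_dir: str, label2id: Dict[str, int], id2label: Dict[int, str]) -> None:
    """Write both mappings into output_dir (hypothetical file names)."""
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "label2id.json"), "w", encoding="utf-8") as f:
        json.dump(label2id, f)
    with open(os.path.join(output_dir, "id2label.json"), "w", encoding="utf-8") as f:
        # JSON object keys are always strings, so the integer ids are written as "0", "1", ...
        json.dump({str(i): label for i, label in id2label.items()}, f)


def load_id2label(output_dir: str) -> Dict[int, str]:
    """Read id2label back, restoring the integer keys."""
    with open(os.path.join(output_dir, "id2label.json"), encoding="utf-8") as f:
        return {int(i): label for i, label in json.load(f).items()}
```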

train(output_dir=None)#

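Creates a SetFitModel from a pretrained model, uses it to create a SetFitTrainer, and then trains the model.

Parameters:

output_dir (Optional[str]) –

update_config(**kwargs)#

Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.

Return type: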
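One way to route a single set of keyword arguments into separate `model_kwargs` and `trainer_kwargs` dictionaries is filtering by signature. For each dictionary, keep only the keys its target callable accepts. The sketch below does this with `inspect.signature`. `load_model` and `make_trainer` are hypothetical stand-ins. Signature-based filtering is an assumed strategy, not a description of what `update_config` does internally.

```python
import inspect
from typing import Any, Callable, Dict


def filter_kwargs_for(func: Callable[..., Any], **kwargs: Any) -> Dict[str, Any]:
    """Keep only the keyword arguments that `func` names explicitly in its signature."""
    accepted = inspect.signature(func).parameters
    return {key: value for key, value in kwargs.items() if key in accepted}


# Hypothetical stand-ins for a model loader and a trainer constructor.
def load_model(model_name: str = "base", num_labels: int = 2) -> None: ...


def make_trainer(num_epochs: int = 1, batch_size: int = 16) -> None: ...


class ConfigSketch:
    def __init__(self) -> None:
        self.model_kwargs: Dict[str, Any] = {}
        self.trainer_kwargs: Dict[str, Any] = {}

    def update_config(self, **kwargs: Any) -> None:
        # Each dictionary only receives the keys its target callable understands;
        # keys that neither target accepts are silently dropped.
        self.model_kwargs.update(filter_kwargs_for(load_model, **kwargs))
        self.trainer_kwargs.update(filter_kwargs_for(make_trainer, **kwargs))
```

A target that itself takes `**kwargs` would need different handling, because its explicit parameter list does not enumerate everything it accepts.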

class argilla.client.feedback.integrations.huggingface.model_card.SetFitModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.SETFIT: 'setfit'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, tokenizer: 'PreTrainedTokenizer' = '')#
Parameters:
  • language (Optional[Union[str, List[str]]]) –

  • license (Optional[str]) –

  • model_name (Optional[str]) –

  • model_id (Optional[str]) –

  • dataset_name (Optional[str]) –

  • dataset_id (Optional[str]) –

  • tags (Optional[List[str]]) –

  • model_summary (Optional[str]) –

  • model_description (Optional[str]) –

  • developers (Optional[str]) –

  • shared_by (Optional[str]) –

  • model_type (Optional[str]) –

  • finetuned_from (Optional[str]) –

  • repo (Optional[str]) –

  • _is_on_huggingface (bool) –

  • framework (Framework) –

  • train_size (Optional[float]) –

  • seed (Optional[int]) –

  • framework_kwargs (Dict[str, Any]) –

  • task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –

  • output_dir (Optional[str]) –

  • library_name (Optional[str]) –

  • update_config_kwargs (Dict[str, Any]) –

  • tokenizer (PreTrainedTokenizer) –

OpenAI Trainer#

class argilla.training.openai.ArgillaOpenAITrainer(*args, **kwargs)#
init_model()#

Initializes the model.

Return type:

init_training_args(training_file=None, validation_file=None, model='curie', n_epochs=None, batch_size=None, learning_rate_multiplier=0.1, prompt_loss_weight=0.1, compute_classification_metrics=False, classification_n_classes=None, classification_positive_class=None, classification_betas=None, suffix=None, hyperparameters=None)#

Initializes the training arguments.

Parameters:
  • training_file (Optional[str]) –

  • validation_file (Optional[str]) –

  • model (str) –

  • n_epochs (Optional[int]) –

  • batch_size (Optional[int]) –

  • learning_rate_multiplier (float) –

  • prompt_loss_weight (float) –

  • compute_classification_metrics (bool) –

  • classification_n_classes (Optional[int]) –

  • classification_positive_class (Optional[str]) –

  • classification_betas (Optional[list]) –

  • suffix (Optional[str]) –

  • hyperparameters (Optional[dict]) –

Return type:

predict(text, as_argilla_records=True, **kwargs)#

This function takes a list of strings and returns a list of predictions.

Parameters:
  • text (Union[List[str], str]) – The text(s) to classify.

  • as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as strings. Defaults to True.

Returns:

A list of predictions.

Return type:

Union[List, str]

save(*arg, **kwargs)#

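This function saves the model to the specified path and also saves the label2id and id2label dictionaries to the same path.

Parameters:

output_dir (str) – The path to save the model to.

Return type:

train(output_dir=None)#

Creates an openai.FineTune object from the pretrained model and sends the data to fine-tune it.

Parameters:

output_dir (Optional[str]) –

Return type:

update_config(**kwargs)#

Updates the model_kwargs dictionary with the keyword arguments passed to the update_config function.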

class argilla.client.feedback.integrations.huggingface.model_card.OpenAIModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.OPENAI: 'openai'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>)#
Parameters:
  • language (Optional[Union[str, List[str]]]) –

  • license (Optional[str]) –

  • model_name (Optional[str]) –

  • model_id (Optional[str]) –

  • dataset_name (Optional[str]) –

  • dataset_id (Optional[str]) –

  • tags (Optional[List[str]]) –

  • model_summary (Optional[str]) –

  • model_description (Optional[str]) –

  • developers (Optional[str]) –

  • shared_by (Optional[str]) –

  • model_type (Optional[str]) –

  • finetuned_from (Optional[str]) –

  • repo (Optional[str]) –

  • _is_on_huggingface (bool) –

  • framework (Framework) –

  • train_size (Optional[float]) –

  • seed (Optional[int]) –

  • framework_kwargs (Dict[str, Any]) –

  • task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –

  • output_dir (Optional[str]) –

  • library_name (Optional[str]) –

  • update_config_kwargs (Dict[str, Any]) –

PEFT (LoRA) Trainer#

class argilla.training.peft.ArgillaPeftTrainer(*args, **kwargs)#
init_model(new=False)#

Initializes the model.

Parameters:

new (bool) –

init_training_args()#

Initializes the training arguments.

predict(text, as_argilla_records=True, **kwargs)#

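This function takes a list of strings and returns a list of predictions.

Parameters:
  • text (Union[List[str], str]) – The text(s) to classify.

  • as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as strings. Defaults to True.

Returns:

A list of predictions.

save(output_dir)#

This function saves the model to the specified path and also saves the label2id and id2label dictionaries to the same path.

Parameters:

output_dir (str) – The path to save the model to.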

update_config(**kwargs)#

Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
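The heading names LoRA. Instead of updating a dense weight matrix of shape d_out x d_in, LoRA trains a low-rank update B @ A. B has shape d_out x r and A has shape r x d_in. The parameter arithmetic below shows why this is cheap. The layer size (4096 x 4096) and rank (r = 8) are illustrative values, not defaults of ArgillaPeftTrainer.

```python
def full_finetune_params(d_in: int, d_out: int) -> int:
    """Trainable parameters when a dense d_out x d_in weight matrix is updated directly."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters of a rank-r update B @ A (B: d_out x r, A: r x d_in)."""
    return d_out * r + r * d_in


if __name__ == "__main__":
    full = full_finetune_params(4096, 4096)
    lora = lora_params(4096, 4096, r=8)
    print(full, lora, f"{lora / full:.2%}")
```

For this layer, the adapter has 65,536 trainable parameters versus 16,777,216 for the full matrix, a 256-fold reduction. The saving grows as r shrinks relative to the layer width.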

class argilla.client.feedback.integrations.huggingface.model_card.PeftModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.PEFT: 'peft'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, tokenizer: 'PreTrainedTokenizer' = '')#
Parameters:
  • language (Optional[Union[str, List[str]]]) –

  • license (Optional[str]) –

  • model_name (Optional[str]) –

  • model_id (Optional[str]) –

  • dataset_name (Optional[str]) –

  • dataset_id (Optional[str]) –

  • tags (Optional[List[str]]) –

  • model_summary (Optional[str]) –

  • model_description (Optional[str]) –

  • developers (Optional[str]) –

  • shared_by (Optional[str]) –

  • model_type (Optional[str]) –

  • finetuned_from (Optional[str]) –

  • repo (Optional[str]) –

  • _is_on_huggingface (bool) –

  • framework (Framework) –

  • train_size (Optional[float]) –

  • seed (Optional[int]) –

  • framework_kwargs (Dict[str, Any]) –

  • task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –

  • output_dir (Optional[str]) –

  • library_name (Optional[str]) –

  • update_config_kwargs (Dict[str, Any]) –

  • tokenizer (PreTrainedTokenizer) –

spaCy Trainer#

class argilla.training.spacy.ArgillaSpaCyTrainer(freeze_tok2vec=False, **kwargs)#
Parameters:

freeze_tok2vec (bool) –

init_training_args()#

This method generates the spaCy config file that is used for training.

Return type:

class argilla.training.spacy.ArgillaSpaCyTransformersTrainer(update_transformer=True, **kwargs)#
Parameters:

update_transformer (bool) –

init_training_args()#

This method generates the spaCy config file that is used for training.

Return type:

class argilla.training.spacy._ArgillaSpaCyTrainerBase(language=None, gpu_id=-1, model=None, optimize='efficiency', *args, **kwargs)#
Parameters:
  • language (Optional[str]) –

  • gpu_id (Optional[int]) –

  • model (Optional[str]) –

  • optimize (Literal['efficiency', 'accuracy']) –
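The `gpu_id` and `optimize` arguments appear to follow spaCy conventions. A negative GPU id (the default, -1) means CPU. `optimize` picks between the 'efficiency' and 'accuracy' presets, as in `spacy init config --optimize`.

Below is a small validation sketch. `check_spacy_args` is hypothetical, and treating `gpu_id=None` as CPU is an assumption. How the trainer itself validates or forwards these values is not documented here.

```python
from typing import Literal, Optional, get_args

Optimize = Literal["efficiency", "accuracy"]


def check_spacy_args(gpu_id: Optional[int] = -1, optimize: str = "efficiency") -> str:
    """Validate the documented constructor values and describe the resulting device."""
    allowed = get_args(Optimize)  # the two Literal values: "efficiency", "accuracy"
    if optimize not in allowed:
        raise ValueError(f"optimize must be one of {allowed}, got {optimize!r}")
    # spaCy convention: a negative GPU id (the default -1) means "run on CPU".
    if gpu_id is None or gpu_id < 0:
        return "cpu"
    return f"gpu:{gpu_id}"
```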

init_model()#

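Initializes the model.

predict(text, as_argilla_records=True, **kwargs)#

Predicts the labels for the given text(s) using the trained pipeline.

Parameters:
  • text (Union[List[str], str]) – A str or a List[str] containing the text(s) to predict labels for.

  • as_argilla_records (bool) – Whether to return the predictions as Argilla records (True) or as dicts (False). Defaults to True.

Returns:

For a str input, a single dict (or a BaseModel if as_argilla_records is True). For a List[str] input, a List[dict] (or a List[BaseModel] if as_argilla_records is True). Either form contains the predictions.

Return type: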

Union[Dict[str, Any], List[Dict[str, Any]], BaseModel, List[BaseModel]]

save(output_dir)#

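Saves the trained pipeline to disk.

Parameters:

output_dir (str) – A str containing the path where the pipeline will be saved.

Return type:

train(output_dir=None)#

Trains the pipeline using spaCy.

Parameters:

output_dir (Optional[str]) – A str containing the path where the trained pipeline will be saved. Defaults to None.

Return type: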

update_config(**spacy_training_config)#

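Updates the spaCy training config.

Disclaimer: currently only the training section of the config is supported. Support for all spaCy config values, for finer control over the training process, is planned. Also note that parameters may differ between CPU and GPU training.

Parameters:

**spacy_training_config – The spaCy training config.

Return type: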
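Since only the `[training]` section of the spaCy config can currently be updated, an update can be modelled as a merge that touches only that section. The sketch below uses a nested dict as a stand-in for a spaCy config. `update_training_config` is hypothetical, and rejecting unknown keys is an assumed safety choice rather than documented behaviour. `max_epochs` and `dropout` are real keys of spaCy's `[training]` block.

```python
import copy
from typing import Any, Dict


def update_training_config(config: Dict[str, Any], **spacy_training_config: Any) -> Dict[str, Any]:
    """Return a copy of `config` with only its "training" section updated."""
    updated = copy.deepcopy(config)  # never mutate the caller's config
    training = updated.setdefault("training", {})
    for key, value in spacy_training_config.items():
        if key not in training:
            # Assumed safety check: refuse keys the [training] section does not define.
            raise KeyError(f"unknown [training] option: {key!r}")
        training[key] = value
    return updated
```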

class argilla.client.feedback.integrations.huggingface.model_card.SpacyTransformersModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.SPACY_TRANSFORMERS: 'spacy-transformers'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, lang: Optional[ForwardRef('spacy.Language')] = None, gpu_id: Optional[int] = -1, optimize: Literal['efficiency', 'accuracy'] = 'efficiency', pipeline: List[str] = <factory>, update_transformer: bool = True)#
Parameters:
  • language (Optional[Union[str, List[str]]]) –

  • license (Optional[str]) –

  • model_name (Optional[str]) –

  • model_id (Optional[str]) –

  • dataset_name (Optional[str]) –

  • dataset_id (Optional[str]) –

  • tags (Optional[List[str]]) –

  • model_summary (Optional[str]) –

  • model_description (Optional[str]) –

  • developers (Optional[str]) –

  • shared_by (Optional[str]) –

  • model_type (Optional[str]) –

  • finetuned_from (Optional[str]) –

  • repo (Optional[str]) –

  • _is_on_huggingface (bool) –

  • framework (Framework) –

  • train_size (Optional[float]) –

  • seed (Optional[int]) –

  • framework_kwargs (Dict[str, Any]) –

  • task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –

  • output_dir (Optional[str]) –

  • library_name (Optional[str]) –

  • update_config_kwargs (Dict[str, Any]) –

  • lang (Optional[spacy.Language]) –

  • gpu_id (Optional[int]) –

  • optimize (Literal['efficiency', 'accuracy']) –

  • pipeline (List[str]) –

  • update_transformer (bool) –

Transformers Trainer#

class argilla.training.transformers.ArgillaTransformersTrainer(*args, **kwargs)#
init_model(new=False)#

Initializes the model.

Parameters:

new (bool) –

init_training_args()#

Initializes the training arguments.

predict(text, as_argilla_records=True, **kwargs)#

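This function takes a list of strings and returns a list of predictions.

Parameters:
  • text (Union[List[str], str]) – The text(s) to classify.

  • as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as strings. Defaults to True.

Returns:

A list of predictions.

save(output_dir)#

This function saves the model to the specified path and also saves the label2id and id2label dictionaries to the same path.

Parameters:

output_dir (str) – The path to save the model to.

train(output_dir)#

Trains the model.

Parameters:

output_dir (str) –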

update_config(**kwargs)#

Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.

class argilla.client.feedback.integrations.huggingface.model_card.TransformersModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.TRANSFORMERS: 'transformers'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, tokenizer: 'PreTrainedTokenizer' = '')#
Parameters:
  • language (Optional[Union[str, List[str]]]) –

  • license (Optional[str]) –

  • model_name (Optional[str]) –

  • model_id (Optional[str]) –

  • dataset_name (Optional[str]) –

  • dataset_id (Optional[str]) –

  • tags (Optional[List[str]]) –

  • model_summary (Optional[str]) –

  • model_description (Optional[str]) –

  • developers (Optional[str]) –

  • shared_by (Optional[str]) –

  • model_type (Optional[str]) –

  • finetuned_from (Optional[str]) –

  • repo (Optional[str]) –

  • _is_on_huggingface (bool) –

  • framework (Framework) –

  • train_size (Optional[float]) –

  • seed (Optional[int]) –

  • framework_kwargs (Dict[str, Any]) –

  • task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –

  • output_dir (Optional[str]) –

  • library_name (Optional[str]) –

  • update_config_kwargs (Dict[str, Any]) –

  • tokenizer (PreTrainedTokenizer) –

SpanMarker Trainer#

class argilla.training.span_marker.ArgillaSpanMarkerTrainer(*args, **kwargs)#
init_model()#

Initializes the model.

Return type:

init_training_args()#

Initializes the training arguments.

Return type:

predict(text, as_argilla_records=True, **kwargs)#

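This function takes a list of strings and returns a list of predictions.

Parameters:
  • text (Union[List[str], str]) – The text(s) to classify.

  • as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as strings. Defaults to True.

Returns:

A list of predictions.

save(output_dir)#

This function saves the model to the specified path and also saves the label2id and id2label dictionaries to the same path.

Parameters:

output_dir (str) – The path to save the model to.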

train(output_dir)#

Trains the SpanMarker model.

Parameters:

output_dir (str) –

update_config(**kwargs)#

Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.

Return type:

TRL Trainer#

class argilla.client.feedback.integrations.huggingface.model_card.TRLModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.TRL: 'trl'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>)#
Parameters:
  • language (Optional[Union[str, List[str]]]) –

  • license (Optional[str]) –

  • model_name (Optional[str]) –

  • model_id (Optional[str]) –

  • dataset_name (Optional[str]) –

  • dataset_id (Optional[str]) –

  • tags (Optional[List[str]]) –

  • model_summary (Optional[str]) –

  • model_description (Optional[str]) –

  • developers (Optional[str]) –

  • shared_by (Optional[str]) –

  • model_type (Optional[str]) –

  • finetuned_from (Optional[str]) –

  • repo (Optional[str]) –

  • _is_on_huggingface (bool) –

  • framework (Framework) –

  • train_size (Optional[float]) –

  • seed (Optional[int]) –

  • framework_kwargs (Dict[str, Any]) –

  • task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –

  • output_dir (Optional[str]) –

  • library_name (Optional[str]) –

  • update_config_kwargs (Dict[str, Any]) –

SentenceTransformer Trainer#

class argilla.client.feedback.training.frameworks.sentence_transformers.ArgillaSentenceTransformersTrainer(dataset, task, prepared_data=None, model=None, seed=None, train_size=1, cross_encoder=False)#
Parameters:
  • dataset (FeedbackDataset) –

  • task (TrainingTaskForSentenceSimilarity) –

  • model (str) –

  • seed (int) –

  • train_size (Optional[float]) –

  • cross_encoder (bool) –

get_model_card_data(**card_data_kwargs)#

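Generates the card data to be used for the ArgillaModelCard.

Parameters:

card_data_kwargs – Extra arguments provided by the user when creating the ArgillaTrainer.

Returns:

A container for the data to be written to the ArgillaModelCard.

Return type:

SentenceTransformerCardData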

init_model()#

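Initializes the model.

Return type:

init_training_args()#

Initializes the training arguments.

Return type:

predict(text, as_argilla_records=False, **kwargs)#

Predicts the similarity of the sentences.

Parameters:
  • text (Union[List[List[str]], Tuple[str, List[str]]]) – The sentences to compute similarities for. Allowed inputs are: (1) a list containing a single sentence (as a string) and a list of sentences to compare it against, or (2) a list of sentence pairs.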

  • as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as plain similarity scores. Defaults to False.

Returns:

A list of predicted similarities.

Return type:

List[float]
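The two accepted input shapes for `text` can be normalized into one flat list of sentence pairs before scoring. In this sketch, `to_sentence_pairs` and `predict_similarity` are hypothetical. Token-set Jaccard overlap stands in for the model's similarity score only so the example runs without sentence-transformers. A real model returns learned similarities.

```python
from typing import List, Sequence, Tuple, Union

Pair = Tuple[str, str]
TextInput = Union[Sequence[Sequence[str]], Tuple[str, List[str]]]


def to_sentence_pairs(text: TextInput) -> List[Pair]:
    """Normalize both documented input shapes into a list of (a, b) pairs."""
    if len(text) == 2 and isinstance(text[0], str) and isinstance(text[1], list):
        # Shape 1: one sentence plus a list of sentences to compare it against.
        query, candidates = text
        return [(query, candidate) for candidate in candidates]
    # Shape 2: an explicit list of [sentence_a, sentence_b] pairs.
    return [(a, b) for a, b in text]


def jaccard(a: str, b: str) -> float:
    """Toy stand-in for a model score: token-set overlap in [0, 1]."""
    sa, sb = set(a.lower().split()), set(b.lower().split())
    union = sa | sb
    return len(sa & sb) / len(union) if union else 0.0


def predict_similarity(text: TextInput) -> List[float]:
    """Score every normalized pair, returning one float per pair."""
    return [jaccard(a, b) for a, b in to_sentence_pairs(text)]
```

Normalizing first means the scoring code only ever handles one shape, matching the documented `List[float]` return type.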

push_to_huggingface(repo_id, **kwargs)#

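Uploads the model to [Hugging Face's model hub](https://hugging-face.cn/models).

The full list of parameters can be found in the [sentence-transformers API docs](https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.save_to_hub).

Parameters:

repo_id (str) – The name of the repository to push the model and tokenizer to. When pushing to an organization, it should include the organization name.

Raises:

NotImplementedError – For CrossEncoder models, which the underlying library does not currently support for this operation.

Return type:

save(output_dir)#

Saves the model to the specified path.

Parameters:

output_dir (str) –

Return type:

train(output_dir=None)#

Trains the model.

Parameters:

output_dir (Optional[str]) –

Return type:

update_config(**kwargs)#

Updates the trainer's configuration. The accepted arguments depend on the trainer subclass.

Return type: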

class argilla.client.feedback.integrations.huggingface.model_card.SentenceTransformerCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.SENTENCE_TRANSFORMERS: 'sentence-transformers'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, cross_encoder: bool = False, trainer_cls: Optional[Callable] = None)#
Parameters:
  • language (Optional[Union[str, List[str]]]) –

  • license (Optional[str]) –

  • model_name (Optional[str]) –

  • model_id (Optional[str]) –

  • dataset_name (Optional[str]) –

  • dataset_id (Optional[str]) –

  • tags (Optional[List[str]]) –

  • model_summary (Optional[str]) –

  • model_description (Optional[str]) –

  • developers (Optional[str]) –

  • shared_by (Optional[str]) –

  • model_type (Optional[str]) –

  • finetuned_from (Optional[str]) –

  • repo (Optional[str]) –

  • _is_on_huggingface (bool) –

  • framework (Framework) –

  • train_size (Optional[float]) –

  • seed (Optional[int]) –

  • framework_kwargs (Dict[str, Any]) –

  • task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –

  • output_dir (Optional[str]) –

  • library_name (Optional[str]) –

  • update_config_kwargs (Dict[str, Any]) –

  • cross_encoder (bool) –

  • trainer_cls (Optional[Callable]) –