栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

知识图谱:【知识图谱问答KBQA(六)】——P-tuning V2训练代码解析

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

知识图谱:【知识图谱问答KBQA(六)】——P-tuning V2训练代码解析

文章目录

一.arguments.py

DataTrainingArguments类ModelArguments类QuestionAnwseringArguments类get_args()函数 二.run.py

Step 1. 获取所有参数Step 2. 根据任务名称选择导入对应的get_trainerStep 3. 将参数args传入get_trainer,得到trainer

1)根据模型名称或路径加载tokenizer2)根据tokenizer和参数data_args、training_args加载数据集dataset3)根据模型名称或路径、dataset加载模型配置config4)根据模型参数和模型配置加载模型(get_model)5)根据model、训练参数、tokenizer以及dataset初始化并返回trainer Step 4. 模型训练、验证及测试

一.arguments.py DataTrainingArguments类

关于我们将输入模型进行训练和评估的数据参数

task_name.任务名称dataset_name.数据集名称dataset_config_name.要使用的数据集的配置名称max_seq_length.标记化(tokenization)后的最大总输入序列长度。 比这长的序列将被截断,短的序列将被填充。overwrite_cache.是否覆盖缓存的预处理数据集pad_to_max_length.是否将所有样本填充到 max_seq_length。 如果为 False,将在批处理时动态填充样本到批处理中的最大长度max_train_samples.出于调试目的或更快的训练,将训练示例的数量截断为该值(如果已设置)max_eval_samples.出于调试目的或更快的训练,将验证示例的数量截断为该值(如果已设置)max_predict_samples.出于调试目的或更快的训练,将测试示例的数量截断为该值(如果已设置)train_file.包含训练数据的 csv 或 json 文件validation_file.包含验证数据的 csv 或 json 文件test_file.包含测试数据的 csv 或 json 文件template_id.要使用的特定prompt字符串 ModelArguments类

关于我们将从哪个模型/配置/标记器进行微调的参数

model_name_or_path.从 huggingface.co/models 下载预训练模型的路径或模型标识符config_name.如果与 model_name 不同,则需指定预训练的配置名称或路径tokenizer_name.如果与 model_name 不同,则需指定预训练的标记器名称或路径cache_dir.用于存储从 huggingface.co 下载的预训练模型的路径use_fast_tokenizer.是否使用快速分词器之一(由分词器库支持)model_revision.要使用的特定模型版本(可以是分支名称、标签名称或提交 ID)use_auth_token.是否使用模型加密prefix.训练时使用P-Tuning V2prompt.训练时使用P-Tuningpre_seq_len.prompt的长度prefix_projection.在前缀嵌入上应用两层 MLP 头prefix_hidden_size.如果使用前缀投影,则前缀编码器中 MLP 投影头的隐藏层大小hidden_dropout_prob.dropout比例 QuestionAnwseringArguments类

n_best_size.寻找答案时生成的 n 最佳预测的总数max_answer_length.可以生成的答案的最大长度version_2_with_negative.如果为真,有些例子没有答案null_score_diff_threshold.用于选择空答案的阈值:如果最佳答案的分数小于空答案的分数减去此阈值,则本示例选择空答案。 仅在 version_2_with_negative=True 时有用 get_args()函数

用于解析P-Tuning V2中的所有参数。

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, QuestionAnwseringArguments))
args = parser.parse_args_into_dataclasses()
return args
二.run.py Step 1. 获取所有参数
args = get_args()
_, data_args, training_args, _ = args
Step 2. 根据任务名称选择导入对应的get_trainer Step 3. 将参数args传入get_trainer,得到trainer 1)根据模型名称或路径加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
    )
2)根据tokenizer和参数data_args、training_args加载数据集dataset 3)根据模型名称或路径、dataset加载模型配置config
 config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            num_labels=dataset.num_labels,
            label2id=dataset.label2id,
            id2label=dataset.id2label,
            finetuning_task=data_args.dataset_name,
            revision=model_args.model_revision,
        )
4)根据模型参数和模型配置加载模型(get_model)

通过模型参数可以选择三种不同的训练方式:

训练方式1:P-Tuning V2(prefix=True)

    if model_args.prefix:
        config.hidden_dropout_prob = model_args.hidden_dropout_prob
        config.pre_seq_len = model_args.pre_seq_len
        config.prefix_projection = model_args.prefix_projection
        config.prefix_hidden_size = model_args.prefix_hidden_size
        
        model_class = PREFIX_MODELS[config.model_type][task_type]
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            revision=model_args.model_revision,
        )

训练方式2:P-Tuning(prefix=False && prompt=True)

    elif model_args.prompt:
        config.pre_seq_len = model_args.pre_seq_len
        model_class = prompt_MODELS[config.model_type][task_type]
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            revision=model_args.model_revision,
        )

训练方式3:fine-tuning(prefix=False && prompt=False)

    else:
        model_class = AUTO_MODELS[task_type]
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            revision=model_args.model_revision,
        )

        bert_param = 0
        if fix_bert:
            if config.model_type == "bert":
                for param in model.bert.parameters():
                    param.requires_grad = False
                for _, param in model.bert.named_parameters():
                    bert_param += param.numel()
            elif config.model_type == "roberta":
                for param in model.roberta.parameters():
                    param.requires_grad = False
                for _, param in model.roberta.named_parameters():
                    bert_param += param.numel()
            elif config.model_type == "deberta":
                for param in model.deberta.parameters():
                    param.requires_grad = False
                for _, param in model.deberta.named_parameters():
                    bert_param += param.numel()
        all_param = 0
        for _, param in model.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('***** total param is {} *****'.format(total_param))
5)根据model、训练参数、tokenizer以及dataset初始化并返回trainer
# Initialize our Trainer
    trainer = baseTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset.train_dataset if training_args.do_train else None,
        eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
        compute_metrics=dataset.compute_metrics,
        tokenizer=tokenizer,
        data_collator=dataset.data_collator,
        test_key=dataset.test_key
    )


    return trainer, None
Step 4. 模型训练、验证及测试
    if training_args.do_train:
        train(trainer, training_args.resume_from_checkpoint, last_checkpoint)
    
    if training_args.do_eval:
        evaluate(trainer)

    if training_args.do_predict:
        predict(trainer, predict_dataset)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/740421.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号