PaddlePaddle / PaddleNLP: steps to switch to loading a custom dataset from local files


1. Write a custom data-reading function and wire up its configuration. Note that the parameter name data_path must line up: the keyword argument passed to load_dataset has to use the same name.
   
from paddlenlp.datasets import load_dataset

def read_out(data_path):
    """
    Yield one example per line of a local file: the token column and the
    tag column are separated by '\t', items within a column by '\002'.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            line_stripped = line.strip().split('\t')
            if not line_stripped:
                break
            if len(line_stripped) == 2:
                tokens = line_stripped[0].split("\002")
                tags = line_stripped[1].split("\002")
            else:
                # Unlabeled line: only a token column is present.
                tokens = line_stripped[0].split("\002")
                tags = []
            yield {"tokens": tokens, "labels": tags}


train_ds = load_dataset(read_out, data_path=path_out_train_, lazy=False)
test_ds = load_dataset(read_out, data_path=path_out_test_, lazy=False)
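For reference, a minimal sketch of the file format this reader assumes (the file name here is hypothetical):

# One sample per line: token column and tag column separated by '\t',
# items within a column joined by the '\002' control character.
tokens = ["海", "钓", "比", "赛", "地", "点"]
tags = ["O", "O", "O", "O", "O", "O"]
with open("train_sample.txt", "w", encoding="utf-8") as f:
    f.write("\002".join(tokens) + "\t" + "\002".join(tags) + "\n")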
2. Labels are no longer loaded for you automatically; load them yourself from the local label file.
label_vocab = load_dict(dict_path_)  # label string -> id
label_num = len(label_vocab)
no_entity_id = label_num - 1
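The post does not show load_dict; a minimal sketch consistent with how it is used here (one label per line, the line index becoming the id) might be:

def load_dict(dict_path):
    """Sketch only: map each label (one per line) to its line index."""
    vocab = {}
    with open(dict_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            vocab[line.strip('\n')] = i
    return vocab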
3. Pay special attention: the built-in loaders convert labels to ids automatically. With locally loaded data you must pass the label vocabulary into the feature-conversion function and do the label-to-id conversion yourself.
def tokenize_and_align_labels(example, tokenizer, no_entity_id, label_vocab, max_seq_len=512):
    """
    Tokenize one example and convert its string labels to ids, padding the
    label positions added for [CLS]/[SEP] with no_entity_id.
    """
    labels = example['labels']
    example = example['tokens']
    tokenized_input = tokenizer(
        example,
        return_length=True,
        is_split_into_words=True,
        max_seq_len=max_seq_len)

    # -2 for [CLS] and [SEP]
    if len(tokenized_input['input_ids']) - 2 < len(labels):
        labels = labels[:len(tokenized_input['input_ids']) - 2]
    # Read custom data locally, the system will not automatically convert it, you must manually convert label to id
    tokenized_input['labels'] = [no_entity_id] + [label_vocab[x] for x in labels] + [no_entity_id]
    tokenized_input['labels'] += [no_entity_id] * (
        len(tokenized_input['input_ids']) - len(tokenized_input['labels']))
    return tokenized_input
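As a usage sketch, the conversion function can be bound with functools.partial and applied through the dataset's map method (the model name below is a placeholder; substitute the checkpoint you actually use):

from functools import partial

from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")  # placeholder checkpoint
trans_func = partial(
    tokenize_and_align_labels,
    tokenizer=tokenizer,
    no_entity_id=no_entity_id,
    label_vocab=label_vocab,
    max_seq_len=128)
train_ds = train_ds.map(trans_func)
test_ds = test_ds.map(trans_func)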
4. At prediction time the inputs must not be tensors: convert tensors to numpy arrays, or feed plain Python data directly.
    def do_predict(self, title, text):
        """
        Entry function
        """
        # Create dataset, tokenizer and dataloader.
        predict_ds, raw_data = self.pre_data(text)

        pred_list = []
        len_list = []
        num_of_examples = len(predict_ds)
        start_idx = 0
        while start_idx < num_of_examples:
            end_idx = start_idx + args.batch_size
            end_idx = end_idx if end_idx < num_of_examples else num_of_examples
            batch_data = [
                self.trans_func(example) for example in predict_ds[start_idx:end_idx]
            ]
            start_idx += args.batch_size

            input_ids, token_type_ids, length = self.batchify_fn(batch_data)
            self.input_handles[0].copy_from_cpu(input_ids)  # must be numpy, not a tensor
            self.input_handles[1].copy_from_cpu(token_type_ids)  # must be numpy, not a tensor
            self.predictor.run()
            logits = self.output_handle.copy_to_cpu()
            pred = np.argmax(logits, axis=-1)  # output is numpy
            pred_list.append(pred)
            len_list.append(length)

        preds = self.parse_decodes(predict_ds, self.id2label, pred_list, len_list)
        result = self.post_data(preds, title, raw_data)
        return result
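The batchify_fn used above is not shown in the post. A sketch of one that already yields numpy arrays, built from PaddleNLP's Pad/Stack/Tuple collators (assuming the tokenizer from the earlier sketch), could be:

from paddlenlp.data import Pad, Stack, Tuple

# Pad and Stack return numpy arrays by default, which is exactly what
# copy_from_cpu expects; no paddle tensors are involved.
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),       # input_ids
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type_ids
    Stack(dtype='int64'),                              # seq_len
): fn(samples)

If a pipeline does hand you paddle tensors, calling .numpy() on them before copy_from_cpu is enough.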