栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

bert中的数据输入制作data

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

bert中的数据输入制作data

数据生成器 自动生成训练、验证、测试集 1.需要创建一个类data_generator 这个类继承DataGenerator类 bert4kreas.snippets 这个类主要是做数据生成的迭代器2. 创建 DateProcess() 类 3.测试

‘’’
这的数据是这样的 是一个 txt 文件
每一行是一个样本 label text
所以在后文加载数据的时候使用的5个
‘’’

1.需要创建一个类data_generator 这个类继承DataGenerator类 bert4kreas.snippets 这个类主要是做数据生成的迭代器
from bert4keras.snippets import DataGenerator,sequence_padding # 导入相关的包
from bert4keras.tokenizers import Tokenizer # 导入分词器
from bert4keras.snippets import open as bert4keras_open
import json
import os
import random
# 这是一个数据生成器
class data_generator(DataGenerator):
 # data:数据
 # max_len:句子的最大长度
 # batch_size:小批量数据的条数
 def __init__(self, data, max_len, batch_size, vocab_path, buffer_size None):
 self.max_len max_len
 self.batch_size batch_size
 self.vocab_path vocab_path
 self.data data
 self.buffer_size buffer_size
 # 判断 self.data 中是否有 len 的方法
 if hasattr(self.data, __len__ ):
 self.steps len(self.data) // self.batch_size
 if len(self.data) % self.batch_size ! 0:
 self.steps 1
 else:
 self.steps None
 self.buffer_size buffer_size or self.batch_size * 1000
 def __iter__(self, random False):
 # 建立一个分词器
 tokenizer Tokenizer(self.vocab_path, do_lower_case True)
 # 创建存 token segment label的列表
 batch_token_ids, batch_segment_ids, batch_label_ids [], [], []
 # is_end 判断到没到最后一条数据 最后一条数据is_end True 否则is_end False 
 for is_end, (label, text) in self.sample(random):
 # 将 text 文本进行编码得到 token_ids 与 segment_ids
 token_ids, segment_ids tokenizer.encode(text, maxlen self.max_len)
 # 加入创建好的列表
 batch_token_ids.append(token_ids)
 batch_segment_ids.append(segment_ids)
 batch_label_ids.append([label])
 # 判断 是否是最后一条数据 or 是否达到了一个batch的数量 
 if is_end or len(batch_token_ids) self.batch_size:
 # 对于每个 token segment 进行补全 根据 max_len 进行长度的统一
 batch_token_ids sequence_padding(batch_token_ids)
 batch_segment_ids sequence_padding(batch_segment_ids)
 batch_label_ids sequence_padding(batch_label_ids)
 # 返回每个 batch 的数据
 yield [batch_token_ids, batch_segment_ids], batch_label_ids
 # 重新计数 下一个 batch
 batch_token_ids, batch_segment_ids, batch_label_ids [], [], []
2. 创建 DateProcess() 类
class DateProcess(object):
 def __init__(self, vocab_path, max_len, batch_size):
 self.vocab_path vocab_path
 self.max_len max_len
 self.batch_size batch_size
 def get_label2id(self, train_data_path, model_output_path):
 这个主要实现了将标签与 id 进行替换
 :param train_data_path: 数据的路径
 :param model_output_path: 模型输出的路径
 :return: 返回两个字典 一个是 label- id 一个是 id- label
 label_list []
 with bert4keras_open(train_data_path, r , encoding utf-8 ) as f:
 for text in f:
 label_list.append(text.strip().split( )[0])
 # 将label进行去重
 labels sorted(set(label_list))
 id2label {}
 label2id {}
 for index, label in enumerate(labels):
 label2id[label] index
 id2label[index] label
 # 为了以后预测模型时 id 与 label 好对应 所以将 id2label 进行 json 格式的保存
 with bert4keras_open(os.path.join(model_output_path id2abel_new_new.json ), w ) as f:
 json.dump(id2label, f, ensure_ascii False)
 return label2id, id2label
 def load_data(self, file_name, label2id):
 这个函数用来加载数据 将样本存到一个列表中 每个样本是一个元组 [(label,text),(label,text),(label,text)..........(label,text)]
 这里就通过 label2id(字典)将文本的 label 转化为了 id
 :param file_name: 
 :param label2id: 
 :return: 返回一个列表
 data_list []
 with bert4keras_open(file_name, r , encoding utf-8 ) as f:
 for line in f:
 label label2id[line.strip().split( )[0]]
 text line.strip().split( )[1]
 data_list.append((label, text))
 return data_list
 def generate_data(self, train_data_path, model_output_path):
 这个函数就是将 加载数据 label- id 生成 训练集 测试集 验证机生成器的一个函数
 :param train_data_path: 
 :param model_output_path: 
 :return: 返回5个参数 label2id, id2label, train_data_generate, vail_data_generate, test_data_generate
 # 生成 label- id 的字典
 label2id, id2label self.get_label2id(train_data_path, model_output_path)
 # 生成 数据的列表
 data_list self.load_data(train_data_path, label2id)
 length len(data_list)
 # 进行数据的打乱
 random.shuffle(data_list)
 # 划分训练集 验证集 测试集
 train_data data_list[:int(0.8 * length)]
 vail_data data_list[int(0.8 * length):int(0.9 * length)]
 test_data data_list[int(0.9 * length):]
 # 创建三个数据生成器
 train_data_generate data_generator(train_data, self.max_len, self.batch_size, self.vocab_path)
 vail_data_generate data_generator(vail_data, self.max_len, self.batch_size, self.vocab_path)
 test_data_generate data_generator(test_data, self.max_len, self.batch_size, self.vocab_path)
 return label2id, id2label, train_data_generate, vail_data_generate, test_data_generate
3.测试
if __name__ __main__ :
 data_process DateProcess( ../data/vocab.txt , max_len 128, batch_size 32)
 _, _, train, vail, test data_process.generate_data( ../data/result_sample.txt , ../data/output )
 for token_and_segment,label in test:
 print( * *1000)
 print(token_and_segment[0])
 print()
 print(token_and_segment[1])
 print()
 print(label[:, 0])

测试结果

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/268300.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号