栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

NMT平行语料划分数据集

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

NMT平行语料划分数据集

目标:

将数据集按比例划分为 train、test、val。

对平行语料处理后如下图所示:

步骤:
  1. 随机打乱数据集
  2. 划分数据集
  3. 划分平行语料
代码如下:
import os
import random


def data_split(config, file, train_ratio=0.98, shuffle=True):
    """
    :param config: 数据文件所在的文件夹名
    :param file: 要处理数据的文件名(全称)
    :param train_ratio: 训练集占比
    :param shuffle: 是否打乱
    :return: None
    """
    with open(os.path.join(config, file), 'r', encoding='utf-8') as fp:  # 用拼接config与file
        lines = fp.read().strip().split('n')
    n = len(lines)
    if shuffle:
        random.shuffle(lines)  # 随机打乱数据集

    train_len = int(n * train_ratio)
    val_len = int(n * (1 - train_ratio) / 2)

    train_data = lines[:train_len]  # 训练集
    val_data = lines[train_len:(train_len + val_len + 1)]  # 验证集
    test_data = lines[(train_len + val_len + 1):]  # 测试集

    train = 'train.txt'
    test = 'test.txt'
    val = 'val.txt'
    s_config = os.path.join(config, 'dataset')  # 划分后数据集存放位置

    with open(os.path.join(config, train), 'w', encoding='utf-8') as fp:
        fp.write("n".join(train_data))
        para_divide(s_config, train)
    with open(os.path.join(config, test), 'w', encoding='utf-8') as fp:
        fp.write("n".join(test_data))
        para_divide(s_config, test)
    with open(os.path.join(config, val), 'w', encoding='utf-8') as fp:
        fp.write("n".join(val_data))
        para_divide(s_config, val)

    print('总共有数据:{}条'.format(n))
    print('训练集:{}条'.format(len(train_data)))
    print('测试集:{}条'.format(len(test_data)))
    print('验证集:{}条'.format(len(val_data)))


def para_divide(config, data):  # 划分平行语料
    f_data = open(os.path.join(config[:4], data), 'r', encoding='utf-8', errors='ignore')  # config[:4]='data'
    en = open(os.path.join(config, (data[:-4]+'.en')), 'w', encoding='utf-8')  # data[:-4]去除文件后缀
    zh = open(os.path.join(config, (data[:-4]+'.zh')), 'w', encoding='utf-8')

    line = f_data.readline()
    while line:
        l, r = line.strip().split('t')  # 按t划分

        l = l.strip() + 'n'
        r = r.strip() + 'n'

        en.writelines(l)
        zh.writelines(r)

        line = f_data.readline()

    en.close()
    zh.close()
    f_data.close()


def main():
    data_split("data", "en-zh.txt")


if __name__ == '__main__':
    main()

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/689006.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号