栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

文本编解码tokenizer

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

文本编解码tokenizer

import json
import re
from typing import List


class CharacterTokenizer:
    """
    Tokenizer的功能是实现文本的编解码。编码,即把字符转成数字,但是实际生活中的字符是无限的,我们总可以遇到新的字符,
    而这些字符在训练集中并不能得到充分训练,于是我们暂时用来表示。
    编码过后,无限的字符变成有限的id;而后,在解码阶段,将id恢复成原始的字符,那些可以恢复的字符都是得到充分训练的,而无法恢复的字符,
    可以理解成失真了。
    """
    def __init__(self):
        self.pad_token = ''
        self.digit_token = ''
        self.alpha_token = ''
        self.unk_token = ''
        self.token2id = None

    def build_vocab(self, corpus_file_path, dump_file_path, min_count=5):
        token2freq = {}
        with open(corpus_file_path) as fin:
            for line in fin:
                line = json.loads(line.strip())['text']
                for ch in line:
                    if re.match(r'd', ch) is not None:
                        continue
                    elif re.match(r'[a-zA-Z]', ch) is not None:
                        continue
                    token2freq[ch] = token2freq.get(ch, 0) + 1
        token2freq = sorted(token2freq.items(), key=lambda x: x[1], reverse=True)
        tokens = [pair[0] for pair in token2freq if pair[1] >= min_count]
        tokens = [self.pad_token, self.unk_token, self.digit_token, self.alpha_token] + tokens
        print(f'vocabulary character: {len(tokens)}')

        with open(dump_file_path, 'w') as fout:
            for token in tokens:
                fout.write(token + 'n')
        print('vocabulary built!')

    def load_vocab(self, vocab_file):
        tokens = open(vocab_file).read().splitlines()
        self.token2id = {token: idx for idx, token in enumerate(tokens)}

    def encode_tokens(self, token_list, padding=False, max_length=None):
        assert self.token2id is not None, 'you MUST load vocab first!'
        id_list = []
        for token in token_list:
            if re.match(r'd', token) is not None:
                id_list.append(self.token2id[self.digit_token])
            elif re.match('[a-zA-Z]', token) is not None:
                id_list.append(self.token2id[self.alpha_token])
            elif token in self.token2id:
                id_list.append(self.token2id[token])
            else:
                id_list.append(self.token2id[self.unk_token])
        if padding and max_length is not None:
            id_list = id_list[:max_length] + [self.token2id[self.pad_token]] * max(0, max_length - len(id_list))
        return id_list

    def decode_ids(self, id_list, truncate_pad_tokens=True):
        assert self.token2id is not None, 'you MUST load vocab first!'
        id2token = {v: k for k, v in self.token2id.items()}
        token_list = [id2token[idx] for idx in id_list]
        if truncate_pad_tokens:
            while token_list and token_list[-1] == self.pad_token:
                token_list.pop(self.pad_token)
        return token_list
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/307398.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号