import json
import re
from typing import List
class CharacterTokenizer:
    """
    A character-level tokenizer that encodes text to ids and decodes ids back to text.
    Encoding maps characters to integers, but the set of characters in the wild is
    unbounded: we can always run into new characters, and such characters cannot be
    sufficiently trained on from the training set, so for the time being we represent
    them with the unknown token.
    After encoding, the unbounded character set becomes a finite set of ids. At the
    decoding stage, ids are mapped back to the original characters: characters that
    can be recovered are the well-trained ones, while characters that cannot be
    recovered can be thought of as lost to distortion.
    """
    def __init__(self):
        # Special tokens. The exact literals are an arbitrary choice; any strings
        # that cannot collide with a single real character will do.
        self.pad_token = '<pad>'      # right-padding for fixed-length sequences
        self.digit_token = '<digit>'  # all digits collapse to this token
        self.alpha_token = '<alpha>'  # all ASCII letters collapse to this token
        self.unk_token = '<unk>'      # characters missing from the vocabulary
        self.token2id = None          # populated by load_vocab()
    def build_vocab(self, corpus_file_path, dump_file_path, min_count=5):
        """Count character frequencies in a JSON-lines corpus and dump the vocabulary."""
        token2freq = {}
        with open(corpus_file_path) as fin:
            for line in fin:
                line = json.loads(line.strip())['text']
                for ch in line:
                    # Digits and ASCII letters are covered by dedicated special
                    # tokens, so they are excluded from the frequency counts.
                    if re.match(r'\d', ch) is not None:
                        continue
                    elif re.match(r'[a-zA-Z]', ch) is not None:
                        continue
                    token2freq[ch] = token2freq.get(ch, 0) + 1
        # Keep characters seen at least min_count times, most frequent first,
        # with the special tokens prepended at fixed positions.
        token2freq = sorted(token2freq.items(), key=lambda x: x[1], reverse=True)
        tokens = [pair[0] for pair in token2freq if pair[1] >= min_count]
        tokens = [self.pad_token, self.unk_token, self.digit_token, self.alpha_token] + tokens
        print(f'vocabulary size: {len(tokens)}')
        with open(dump_file_path, 'w') as fout:
            for token in tokens:
                fout.write(token + '\n')
        print('vocabulary built!')
    def load_vocab(self, vocab_file):
        """Load a previously dumped vocabulary and build the token -> id mapping."""
        with open(vocab_file) as fin:
            tokens = fin.read().splitlines()
        self.token2id = {token: idx for idx, token in enumerate(tokens)}
    def encode_tokens(self, token_list, padding=False, max_length=None) -> List[int]:
        assert self.token2id is not None, 'you MUST load vocab first!'
        id_list = []
        for token in token_list:
            if re.match(r'\d', token) is not None:
                id_list.append(self.token2id[self.digit_token])
            elif re.match(r'[a-zA-Z]', token) is not None:
                id_list.append(self.token2id[self.alpha_token])
            elif token in self.token2id:
                id_list.append(self.token2id[token])
            else:
                id_list.append(self.token2id[self.unk_token])
        if padding and max_length is not None:
            # Truncate to max_length, then right-pad with the pad id if short.
            id_list = id_list[:max_length] + [self.token2id[self.pad_token]] * max(0, max_length - len(id_list))
        return id_list
    def decode_ids(self, id_list, truncate_pad_tokens=True) -> List[str]:
        assert self.token2id is not None, 'you MUST load vocab first!'
        id2token = {v: k for k, v in self.token2id.items()}
        token_list = [id2token[idx] for idx in id_list]
        if truncate_pad_tokens:
            # Strip trailing pad tokens from the decoded sequence.
            while token_list and token_list[-1] == self.pad_token:
                token_list.pop()
        return token_list
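

# Usage sketch, not part of the class above: it assumes a JSON-lines corpus
# ('corpus.jsonl') whose lines each carry a 'text' field, as build_vocab
# expects; both file paths here are hypothetical.
if __name__ == '__main__':
    tokenizer = CharacterTokenizer()
    tokenizer.build_vocab('corpus.jsonl', 'vocab.txt', min_count=5)
    tokenizer.load_vocab('vocab.txt')
    ids = tokenizer.encode_tokens(list('abc, 你好123'), padding=True, max_length=16)
    print(ids)
    # Letters and digits collapse to the <alpha>/<digit> ids, so decoding them
    # is lossy: the round trip recovers '你好' but not 'abc' or '123'.
    print(tokenizer.decode_ids(ids))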