栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

BiLSTM 命名实体识别任务

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

BiLSTM 命名实体识别任务

import os
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn as nn
from sklearn.metrics import f1_score


def build_corpus(split, make_vocab=True, data_dir="data"):
    """Read one split of a character-level BMES corpus.

    The file ``<data_dir>/<split>.char.bmes`` is expected to hold one
    ``<char> <tag>`` pair per line, with blank lines separating sentences.

    Args:
        split: one of ``'train'``, ``'dev'``, ``'test'``.
        make_vocab: when True, also build and return the word/tag index maps.
        data_dir: directory containing the corpus files.

    Returns:
        ``(word_lists, tag_lists)``, or ``(word_lists, tag_lists, word2id, tag2id)``
        when ``make_vocab`` is True. Sentences are ordered by ascending length
        (ties keep file order). Both maps receive one extra ``''`` entry with
        the next free id; the rest of the script uses that key as the
        unknown-word / padding index.

    Raises:
        ValueError: if a non-blank line does not contain exactly two fields.
    """
    assert split in ['train', 'dev', 'test']

    word_lists = []
    tag_lists = []
    word_list = []
    tag_list = []
    path = os.path.join(data_dir, split + ".char.bmes")
    with open(path, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()  # also drops '\n' / '\r\n'
            if not fields:
                # Blank line = sentence boundary; runs of blank lines add nothing.
                if word_list:
                    word_lists.append(word_list)
                    tag_lists.append(tag_list)
                    word_list = []
                    tag_list = []
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{path}:{lineno}: expected '<char> <tag>', got {line.rstrip()!r}"
                )
            word, tag = fields
            word_list.append(word)
            tag_list.append(tag)
    # The last sentence is not followed by a blank line if the file lacks one.
    if word_list:
        word_lists.append(word_list)
        tag_lists.append(tag_list)

    # Sort sentences and their tags with one shared permutation so they can
    # never fall out of alignment.
    order = sorted(range(len(word_lists)), key=lambda i: len(word_lists[i]))
    word_lists = [word_lists[i] for i in order]
    tag_lists = [tag_lists[i] for i in order]

    if make_vocab:
        word2id = build_map(word_lists)
        tag2id = build_map(tag_lists)
        # '' can never be a real token (str.split() never yields empty fields).
        # Register it exactly once so its id equals the base vocabulary size and
        # stays a valid row of a table sized len(word2id) / len(tag2id).
        word2id[''] = len(word2id)
        tag2id[''] = len(tag2id)
        return word_lists, tag_lists, word2id, tag2id
    return word_lists, tag_lists

def build_map(lists):
    """Assign consecutive integer ids to the distinct elements of nested sequences.

    Ids follow first-occurrence order when scanning ``lists`` row by row, so
    the first element seen gets 0, the next new one 1, and so on.
    """
    first_seen = dict.fromkeys(item for seq in lists for item in seq)
    return {item: idx for idx, item in enumerate(first_seen)}

class MyDataset(Dataset):
    """Dataset of (character sequence, tag sequence) pairs for sequence labelling.

    Items are returned as index lists; ``pro_batch_data`` is the collate
    function that right-pads a batch to its longest sentence and builds tensors.

    Args:
        datas: list of character sequences.
        tags: list of tag sequences, parallel to ``datas``.
        word_2_index: character -> id map; must contain the ``""`` key, used
            for unknown characters and for padding.
        tag_2_index: tag -> id map; must contain the ``""`` key, used for padding.
    """

    def __init__(self, datas, tags, word_2_index, tag_2_index):
        self.datas, self.tags = datas, tags
        self.word_2_index, self.tag_2_index = word_2_index, tag_2_index

    def __getitem__(self, index):
        """Return ``(char_ids, tag_ids)`` for sentence ``index``; unseen chars map to ``""``."""
        sentence, labels = self.datas[index], self.tags[index]
        char_ids = [self.word_2_index.get(ch, self.word_2_index[""]) for ch in sentence]
        tag_ids = [self.tag_2_index[label] for label in labels]
        return char_ids, tag_ids

    def __len__(self):
        size = len(self.tags)
        assert len(self.datas) == size
        return size

    def pro_batch_data(self, batch_datas):
        """Collate ``[(char_ids, tag_ids), ...]`` into two padded 2-D int64 tensors.

        Rows are right-padded to the longest character sequence in the batch
        with the ``""`` id of the respective map. Tensors are created on the
        module-level ``device`` (assigned in the ``__main__`` section), so this
        raises NameError when that global has not been set.
        """
        char_rows, tag_rows = [], []
        for char_ids, tag_ids in batch_datas:
            char_rows.append(char_ids)
            tag_rows.append(tag_ids)
        width = max(len(row) for row in char_rows)

        char_pad = self.word_2_index[""]
        padded_chars = [row + [char_pad] * (width - len(row)) for row in char_rows]
        tag_pad = self.tag_2_index[""]
        padded_tags = [row + [tag_pad] * (width - len(row)) for row in tag_rows]

        return (
            torch.tensor(padded_chars, dtype=torch.int64, device=device),
            torch.tensor(padded_tags, dtype=torch.long, device=device),
        )



class Mymodel(nn.Module):
    """Embedding -> (Bi)LSTM -> linear tagger applied at every time step.

    Args:
        corpus_num: vocabulary size (rows of the embedding table).
        embedding_num: embedding dimension.
        hidden_num: LSTM hidden size per direction.
        class_num: number of tag classes.
        bi: use a bidirectional LSTM; the classifier input then doubles.
    """

    def __init__(self, corpus_num, embedding_num, hidden_num, class_num, bi=True):
        super().__init__()
        self.embedding = nn.Embedding(corpus_num, embedding_num)
        self.lstm = nn.LSTM(embedding_num, hidden_num, batch_first=True, bidirectional=bi)
        features = hidden_num * 2 if bi else hidden_num
        self.classifier = nn.Linear(features, class_num)
        self.cross_loss = nn.CrossEntropyLoss()

    def forward(self, batch_data, batch_tag=None):
        """Tag ``batch_data`` (a batch-first index tensor).

        Always stores the flattened argmax predictions in ``self.pre``.
        Returns the cross-entropy loss over all positions when ``batch_tag``
        is given, otherwise ``None``.
        """
        hidden, _ = self.lstm(self.embedding(batch_data))
        logits = self.classifier(hidden)
        self.pre = torch.argmax(logits, dim=-1).reshape(-1)
        if batch_tag is None:
            return None
        flat_logits = logits.reshape(-1, logits.shape[-1])
        return self.cross_loss(flat_logits, batch_tag.reshape(-1))




def test():
    """Interactively tag lines typed by the user and print ``char_tag`` pairs.

    Reads the module-level globals ``word_2_index``, ``model``, ``index_2_tag``
    and ``device`` (assigned in the ``__main__`` section). Characters missing
    from ``word_2_index`` map to its ``""`` entry. Returns on EOF or Ctrl-C.
    """
    model.eval()
    while True:
        try:
            text = input("请输入:")
        except (EOFError, KeyboardInterrupt):
            print()
            return
        if not text:
            # Nothing to tag; also avoids feeding a zero-length sequence to the LSTM.
            continue
        unk = word_2_index[""]
        text_index = torch.tensor(
            [[word_2_index.get(ch, unk) for ch in text]],
            dtype=torch.int64,
            device=device,
        )
        with torch.no_grad():  # inference only: don't build an autograd graph
            model.forward(text_index)
        pre = [index_2_tag[i] for i in model.pre.tolist()]

        print([f'{w}_{s}' for w, s in zip(text, pre)])




if __name__ == "__main__":
    # Module-level globals below (device, word_2_index, model, index_2_tag) are
    # read by MyDataset.pro_batch_data and test().
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    train_data, train_tag, word_2_index, tag_2_index = build_corpus("train", make_vocab=True)
    dev_data, dev_tag = build_corpus("dev", make_vocab=False)
    # Ids are assigned in insertion order, so position i holds the tag with id i.
    index_2_tag = list(tag_2_index)

    corpus_num = len(word_2_index)
    class_num = len(tag_2_index)

    epoch = 10
    train_batch_size = 50
    dev_batch_size = 100
    embedding_num = 101
    hidden_num = 107
    bi = True
    lr = 0.001

    train_dataset = MyDataset(train_data, train_tag, word_2_index, tag_2_index)
    train_dataloader = DataLoader(train_dataset, train_batch_size, shuffle=False, collate_fn=train_dataset.pro_batch_data)

    dev_dataset = MyDataset(dev_data, dev_tag, word_2_index, tag_2_index)
    dev_dataloader = DataLoader(dev_dataset, dev_batch_size, shuffle=False, collate_fn=dev_dataset.pro_batch_data)

    # Move the model first so the optimizer is built over the final parameters.
    model = Mymodel(corpus_num, embedding_num, hidden_num, class_num, bi).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    pad_tag = tag_2_index[""]  # id used to pad tag rows in pro_batch_data

    for e in range(epoch):
        model.train()
        for batch_data, batch_tag in train_dataloader:
            train_loss = model.forward(batch_data, batch_tag)
            train_loss.backward()
            opt.step()
            opt.zero_grad()

        model.eval()
        all_pre = []
        all_tag = []
        dev_loss_sum = 0.0
        dev_batches = 0
        with torch.no_grad():  # evaluation only: no autograd graph needed
            for dev_batch_data, dev_batch_tag in dev_dataloader:
                dev_loss_sum += model.forward(dev_batch_data, dev_batch_tag).item()
                dev_batches += 1
                all_pre.extend(model.pre.cpu().tolist())
                all_tag.extend(dev_batch_tag.reshape(-1).cpu().tolist())

        # Padded positions are not real tokens, so keep them out of the score.
        gold = [t for t in all_tag if t != pad_tag]
        pred = [p for t, p in zip(all_tag, all_pre) if t != pad_tag]
        score = f1_score(gold, pred, average="micro")
        dev_loss = dev_loss_sum / max(dev_batches, 1)  # mean over dev batches
        print(f"{e},f1_score:{score:.3f},dev_loss:{dev_loss:.3f},train_loss:{train_loss.item():.3f}")
    test()

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/866810.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号