栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

深度学习

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

深度学习

一、目的

在我们日常开发环境(pycharm / VSCode)中都自带单词纠错的插件。由于近段时间准备重新梳理下NLP的知识,所以准备从这个单词纠错的插件着手,逐步构建出一个单词纠错神器。
里面主要会涉及输入输出数据集的构建,以及单词纠错网络的搭建。以及输入输出数据集构建的优化,网络框架的优化(会输出tensorflow2.6 以及 pytorch1.8 )

二、单词纠错神器网络搭建思路

第一直觉是sequence to sequence的模型, 输入一个拼写错误的单词,输出一个拼写正确的单词。
基于该直觉构建网络图如下:

网络简单搭建方式:

  • 需要对输入input_1以及input_2进行编码
    • input_2编码的时候需要加入特殊符号表示单词的开始和结束
  • 损失函数确定:交叉熵
  • 评估指标选取:Accuary
  • 梯度优化两个备选: rmsprop 和 adam
  • 网络搭建
三、单词纠错神器网络搭建 3.1 tensorflow 2.0及以上版本
from tensorflow.keras.layers import Dense, LSTM, Input
from tensorflow.keras import Model

def de_right_word_tf2(lstm_units, out_dims, encode_max_len, decode_max_len):
    encoder_lstm = LSTM(lstm_units, return_state=True)
    decoder_lstm = LSTM(lstm_units, return_state=True, return_sequences=True)
    fc = Dense(out_dims, activation='softmax')

    input_1 = Input(shape=(None, encode_max_len))
    encode_out, encode_h, encode_c = encoder_lstm(input_1)
    input_2 = Input(shape=(None, decode_max_len))
    decode_out, decode_h, decode_c = decoder_lstm(input_2, initial_state=[encode_h, encode_c])
    predict_out = fc(decode_out)
    model = Model([input_1, input_2], predict_out)
    model.compile(
        optimizer='rmsprop',
        loss=['categorical_crossentropy'],
        metrics=['accuracy']
    )
    return model
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/307590.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号