栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

基于CNN+LSTM+CTC的不定长电表数字识别(Pytorch模型篇)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

基于CNN+LSTM+CTC的不定长电表数字识别(Pytorch模型篇)

最近在做一个项目的某个模块,主要涉及文字识别的相关技术,文字识别主要分为两个步骤,文字检测与识别,本文主要针对文字识别的板块搭建模型,主流的就要属CRNN+CTC了。今天就送上案例实操,也是自己动手搭建的,分享一点心得。
做的过程中也是查看了许多相关文献和网站,这里主推一篇知乎文章,讲的真的很好,附上链接:一文读懂CRNN+CTC文字识别
要实现文字识别的最终落地,包括搭建模型,构造自己的datasets,然后开始training,本文先讲搭建模型。
废话不多说,直接上code。

在这里插入代码片import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """
    每一个ResidualBlock,需要保证输入和输出的维度不变
    所以卷积核的通道数都设置成一样
    """
    def __init__(self, channel):
        super().__init__()
        self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)

    def forward(self, x):
        """
        ResidualBlock中有跳跃连接;
        在得到第二次卷积结果时,需要加上该残差块的输入,
        再将结果进行激活,实现跳跃连接 ==> 可以避免梯度消失
        在求导时,因为有加上原始的输入x,所以梯度为: dy + 1,在1附近
        """
        y = F.relu(self.conv1(x))
        y = self.conv2(y)

        return F.relu(x + y)

class myLSTM(nn.Module):
    def __init__(self,input_size,hidden_size,nout):
        super(myLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size,hidden_size,num_layers=2,bidirectional=True)
        self.linear = nn.Linear(2*hidden_size,nout)

    def forward(self,input):
        output,(h_n,c_n) = self.lstm(input)
        T,B,H = output.size()
        rec = output.view(T*B,H)
        lout = self.linear(rec)
        lout = lout.view(T,B,-1)#为了满足CTCloss的输入
        return lout

class myCNN(nn.Module):
    def __init__(self):
        super(myCNN, self).__init__()
        self.resblock1 = ResidualBlock(32)
        self.resblock2 = ResidualBlock(64)
        self.resblock3 = ResidualBlock(128)
        self.conv1 = nn.Sequential(
            #layer_1
            nn.Conv2d(3,32,3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            #layer_2
            nn.Conv2d(32,64,3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv3 = nn.Sequential(
            # layer_3
            nn.Conv2d(64, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
    def forward(self,input):
        out1 = self.conv1(input)
        out1 = self.resblock1(out1)
        out2 = self.conv2(out1)
        out2 = self.resblock2(out2)
        out3 = self.conv3(out2)
        out3 = self.resblock3(out3)
        return out3

class CRNN(nn.Module):
    def __init__(self,nclass,nhidden):
        super(CRNN, self).__init__()
        self.cnn = nn.Sequential(myCNN())
        self.lstm = nn.Sequential(
            myLSTM(4*128,nhidden,nhidden),
            myLSTM(nhidden,nhidden,nclass),
        )
    def forward(self,input):
        conv = self.cnn(input)
        batch,channel,h,w = conv.shape
        #print(conv.shape)
        conv = conv.permute(0,3,2,1)
        # conv = conv.squeeze(dim =2)#[B,C,W]
        conv = conv.reshape(batch, -1, 4*128)
        conv = conv.permute(1,0,2)#input for lstm[T,N,C]
        out = self.lstm(conv)
        return out

图片大小Wx32x3,宽度不做要求但要求数据集统一大小,便于训练,如果H不同,可以参考卷积网络的输出自行修改即可,我这里经过卷积之后的H为4.
CNN模型我们搭建简单的ResNet,LSTM选用双向的,因此要注意输出。
代码基于pytorch框架,十分简洁明了。
码字不易,感谢点赞,下期出版训练代码以及数据集制作的代码。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/304268.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号