栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch上实现简单的定点网络

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch上实现简单的定点网络

最近开始学pytorch相关的网络模型训练,作为基础练习训练一个关键点检测模型。
入坑看了这个项目 说明的还是很详细的
https://github.com/tensor-yu/PyTorch_Tutorial

里面的loss函数介绍以处理分类模型为主,而数据处理和读取,pytorch训练流程都讲的很全。因此我这次练点检测的模型只要稍微改一下数据读取和loss的计算即可。

我目前生成的数据保存在txt里,一行包含的信息为“图像路径+4个关键点的八个坐标”,因此在自定义数据模块只做了如下处理:(一般放在utils.py里定义自己的data类)

class MyDataset(Dataset):
    def __init__(self, txt_path, transform = None, target_transform = None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            points_tensor = torch.Tensor([[float(words[1]),float(words[2])],[float(words[3]),float(words[4])],[float(words[5]),float(words[6])],[float(words[7]),float(words[8])]])
            imgs.append((words[0], points_tensor))

这里把8个数分成4组用于对应4个点 其他读取部分照抄教程里的例子即可。

接下来要选择loss函数,我这边就简单的采用模型输出的4个点和label里4个点欧氏距离的和作为loss,查完资料发现计算两点距离在torch里已经实现了,因此可以直接拿来用于loss的计算部分(教程里例子都是分类的一些loss,用不上了) 复杂一点的loss也要自己建一个loss的类,但是pytorch下欧氏距离的实现比较容易,这里就直接改main.py里的loss计算代码了

import torch.nn.functional as F
/******************教程中其他代码略*****************/        
        outputs = net(inputs)   #输出网络推理结果,reshape后计算欧氏距离
        outputs = outputs.reshape(-1,4,2)
        loss = torch.sum(F.pairwise_distance(outputs,labels, p=2))        
        loss.backward()
        optimizer.step()

最后是模型的选择,由于我要实现的点图像比较单一,就把教程例子里的RESNET砍一砍拿来用了,这里只列出动过的部分(resnet.py可以自己在model文件夹里创建一个,在main.py里调用即可)

class ResidualBlock(nn.Module):
    /***************略************/
class ResNets(BasicModule):
    def __init__(self, num_classes=8):   #输出8个数代表四个坐标
        super(ResNets, self).__init__()
        self.model_name = 'resnets'

        # 前几层: 图像转换
        self.pre = nn.Sequential(
                nn.Conv2d(3, 32, 7, 2, 3, bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(3, 2, 1))
        
        # 重复的layer,residual block都被我砍成1个了
        self.layer1 = self._make_layer( 32, 64, 1)
        self.layer2 = self._make_layer( 64, 128, 1, stride=2)   
        self.layer3 = self._make_layer( 128, 256, 1, stride=2)
        self.layer4 = self._make_layer( 256, 256, 1, stride=2)

        #分类用的全连接
        self.fc = nn.Linear(256, num_classes)
    def _make_layer(self,  inchannel, outchannel, block_num, stride=1):
        /***************略************/
    def forward(self, x):
        /***************略************/

其他:lr和优化器optimizer都可以照抄不动,上述修改完后再去除main.py里和分类模型有关的计算和测试代码就可以跑了。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/768631.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号