Hand-rolling the code: deep image matting (7)

The previous two articles spent an enormous amount of space taking the dataset apart until nothing was left. Put it this way: the one thing in the entire DIM codebase scarier than the model architecture is the dataset handling (only after dismantling it did I understand that no normal person's brain produces something like this). With the dataset figured out, it's worth taking a fresh look at the train function, together with the valid function that looks almost exactly like it.

def train(train_loader, model, optimizer, epoch, logger):
    model.train()  # train mode (enables dropout and batch-norm training behavior)

    losses = AverageMeter()

    # Batches
    for i, (img, alpha_label) in enumerate(train_loader):
        # Move to GPU, if available
        img = img.type(torch.FloatTensor).to(device)  # [N, 4, 320, 320]
        alpha_label = alpha_label.type(torch.FloatTensor).to(device)  # [N, 2, 320, 320]
        alpha_label = alpha_label.reshape((-1, 2, im_size * im_size))  # [N, 2, 320*320]

        # Forward prop.
        alpha_out = model(img)  # [N, 320, 320]
        alpha_out = alpha_out.reshape((-1, 1, im_size * im_size))  # [N, 1, 320*320]

        # Calculate loss
        # loss = criterion(alpha_out, alpha_label)
        loss = alpha_prediction_loss(alpha_out, alpha_label)

        # Back prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        clip_gradient(optimizer, grad_clip)

        # Update weights
        optimizer.step()

        # Keep track of metrics
        losses.update(loss.item())

        # Print status

        if i % print_freq == 0:
            status = 'Epoch: [{0}][{1}/{2}]\t' \
                     'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader), loss=losses)
            logger.info(status)

    return losses.avg


The losses object is an instance of a class that hasn't come up before, so let's pull it out on its own:

class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

As the docstring says, the class exists to track a metric's most recent value, average, sum, and count. It defines three methods: the __init__ initializer (every class gets one), reset, and update. We'll see how it's used in context later on.
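
A quick feel for how it behaves, with made-up values:

losses = AverageMeter()
losses.update(0.9)             # val=0.9, sum=0.9, count=1, avg=0.9
losses.update(0.5, n=2)        # one value observed over a batch of 2
print(losses.val, losses.avg)  # 0.5 0.6333... = (0.9 + 0.5 * 2) / 3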

On the subject of Dataset and DataLoader, this article explains things well:

【小白学PyTorch】3 浅谈Dataset和Dataloader_忽逢桃林的博客-CSDN博客
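
As a bare-bones illustration of the Dataset-to-DataLoader relationship, here is a toy stand-in that mimics the shapes of the DIM data (not the repo's actual dataset class):

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        img = torch.randn(4, 320, 320)          # stand-in for the 4-channel input
        alpha_label = torch.zeros(2, 320, 320)  # stand-in for matte values + mask
        return img, alpha_label

loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True)
img, alpha_label = next(iter(loader))
print(img.shape, alpha_label.shape)  # [4, 4, 320, 320] and [4, 2, 320, 320]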

Now, let's re-examine the training process.

    for i, (img, alpha_label) in enumerate(train_loader):
        # Move to GPU, if available
        img = img.type(torch.FloatTensor).to(device)  # [N, 4, 320, 320]
        alpha_label = alpha_label.type(torch.FloatTensor).to(device)  # [N, 2, 320, 320]
        alpha_label = alpha_label.reshape((-1, 2, im_size * im_size))  # [N, 2, 320*320]

        # Forward prop.
        alpha_out = model(img)  # [N, 320, 320]
        alpha_out = alpha_out.reshape((-1, 1, im_size * im_size))  # [N, 1, 320*320]

        # Calculate loss
        # loss = criterion(alpha_out, alpha_label)
        loss = alpha_prediction_loss(alpha_out, alpha_label)

        # Back prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        clip_gradient(optimizer, grad_clip)

        # Update weights
        optimizer.step()

        # Keep track of metrics
        losses.update(loss.item())

        # Print status

        if i % print_freq == 0:
            status = 'Epoch: [{0}][{1}/{2}]\t' \
                     'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader), loss=losses)
            logger.info(status)

Each pass through train_loader yields a batch of batch_size (img, alpha_label) pairs produced by the dataset's __getitem__: img is the 4-channel tensor from the earlier articles, and alpha_label carries the normalized matte values plus the per-pixel flag for whether the trimap value is 128. From here on things are much easier to follow. After img and alpha_label are converted to FloatTensor and moved to the GPU, alpha_label gets reshaped; earlier on I assumed the reshape(-1, 2, ...) would turn it into two columns, but that's not what actually happens.

What comes out is a tensor of shape [N, 2, 102400]. The -1 as the first argument tells reshape to infer that dimension from the total element count, so the leading batch dimension N is preserved while each sample is flattened into 2 rows of 320 × 320 = 102400 values. As for why the spatial dimensions get flattened into one, see below.
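
A minimal check of that inference:

import torch

im_size = 320
alpha_label = torch.zeros(8, 2, im_size, im_size)       # [N, 2, 320, 320] off the loader
flat = alpha_label.reshape((-1, 2, im_size * im_size))
print(flat.shape)  # torch.Size([8, 2, 102400]) -- the -1 resolves to N = 8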

alpha_out holds what the model returns for img, so it's worth laying the model code out again.

class DIMModel(nn.Module):
    def __init__(self, n_classes=1, in_channels=4, is_unpooling=True, pretrain=True):
        super(DIMModel, self).__init__()

        self.in_channels = in_channels
        self.is_unpooling = is_unpooling
        self.pretrain = pretrain

        self.down1 = segnetDown2(self.in_channels, 64)
        self.down2 = segnetDown2(64, 128)
        self.down3 = segnetDown3(128, 256)
        self.down4 = segnetDown3(256, 512)
        self.down5 = segnetDown3(512, 512)

        self.up5 = segnetUp1(512, 512)
        self.up4 = segnetUp1(512, 256)
        self.up3 = segnetUp1(256, 128)
        self.up2 = segnetUp1(128, 64)
        self.up1 = segnetUp1(64, n_classes)

        self.sigmoid = nn.Sigmoid()

        if self.pretrain:
            import torchvision.models as models
            vgg16 = models.vgg16()
            self.init_vgg16_params(vgg16)

    def forward(self, inputs):
        # inputs: [N, 4, 320, 320]
        down1, indices_1, unpool_shape1 = self.down1(inputs)
        down2, indices_2, unpool_shape2 = self.down2(down1)
        down3, indices_3, unpool_shape3 = self.down3(down2)
        down4, indices_4, unpool_shape4 = self.down4(down3)
        down5, indices_5, unpool_shape5 = self.down5(down4)

        up5 = self.up5(down5, indices_5, unpool_shape5)
        up4 = self.up4(up5, indices_4, unpool_shape4)
        up3 = self.up3(up4, indices_3, unpool_shape3)
        up2 = self.up2(up3, indices_2, unpool_shape2)
        up1 = self.up1(up2, indices_1, unpool_shape1)

        x = torch.squeeze(up1, dim=1)  # [N, 1, 320, 320] -> [N, 320, 320]
        x = self.sigmoid(x)

        return x

    def init_vgg16_params(self, vgg16):
        blocks = [self.down1, self.down2, self.down3, self.down4, self.down5]

        ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
        features = list(vgg16.features.children())

        vgg_layers = []
        for _layer in features:
            if isinstance(_layer, nn.Conv2d):
                vgg_layers.append(_layer)

        merged_layers = []
        for idx, conv_block in enumerate(blocks):
            if idx < 2:
                units = [conv_block.conv1.cbr_unit, conv_block.conv2.cbr_unit]
            else:
                units = [
                    conv_block.conv1.cbr_unit,
                    conv_block.conv2.cbr_unit,
                    conv_block.conv3.cbr_unit,
                ]
            for _unit in units:
                for _layer in _unit:
                    if isinstance(_layer, nn.Conv2d):
                        merged_layers.append(_layer)

        assert len(vgg_layers) == len(merged_layers)

        for l1, l2 in zip(vgg_layers, merged_layers):
            if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                if l1.weight.size() == l2.weight.size() and l1.bias.size() == l2.bias.size():
                    l2.weight.data = l1.weight.data
                    l2.bias.data = l1.bias.data

The final two statements of forward give alpha_out's shape away:

x = torch.squeeze(up1, dim=1)  # [N, 1, 320, 320] -> [N, 320, 320]
x = self.sigmoid(x)  

return x

In other words, alpha_out leaves the model with shape [N, 320, 320]: torch.squeeze drops the singleton channel dimension. The final reshape then turns it into [N, 1, 102400].
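
A quick dummy-input check of those shapes (a sketch: it assumes DIMModel and its segnet blocks are importable from the repo, and passes pretrain=False to skip the VGG16 weight copy):

import torch

model = DIMModel(pretrain=False)
dummy = torch.randn(2, 4, 320, 320)            # [N, 4, 320, 320]
out = model(dummy)
print(out.shape)                               # torch.Size([2, 320, 320])
print(out.reshape((-1, 1, 320 * 320)).shape)   # torch.Size([2, 1, 102400])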

Once the model has produced alpha_out and its shape has been rearranged, alpha_out and alpha_label essentially line up, and the next step is computing the loss and backpropagating. Having gone through the dataset code in full, the parts of the loss calculation that used to be murky are now clear.

def alpha_prediction_loss(y_pred, y_true):
    mask = y_true[:, 1, :]                    # 0/1 mask: 1 where the trimap pixel is 128
    diff = y_pred[:, 0, :] - y_true[:, 0, :]  # predicted matte minus ground-truth matte
    diff = diff * mask                        # keep the error only in the unknown region
    num_pixels = torch.sum(mask)              # how many unknown pixels there are
    return torch.sum(torch.sqrt(torch.pow(diff, 2) + epsilon_sqr)) / (num_pixels + epsilon)

It's invoked like this: loss = alpha_prediction_loss(alpha_out, alpha_label)

alpha_out corresponds to y_pred

alpha_label corresponds to y_true

So alpha_out has shape [N, 1, 102400] and alpha_label has shape [N, 2, 102400]. Each alpha_label sample started out as [2, 320, 320]: channel 0 holds the normalized matte values, while channel 1 holds only 0s and 1s, recording whether that pixel's trimap value is 128, i.e. the mask. After the reshape, row 0 is the normalized matte and row 1 is the mask. With that, the variables inside the loss function are easy to read.

mask: the channel of y_true recording which pixels are 128, that is, the ambiguous boundary region of the trimap

y_pred: alpha_out, i.e. the [N, 1, 102400] tensor the model produces from the fully processed 4-channel img

diff, then, is the comparison of y_pred and y_true along their first channel: the matte the model predicts versus the ground-truth matte. Multiplying the difference by mask keeps it at pixels whose trimap value is 128 and zeroes it everywhere else, so diff expresses exactly how far alpha_out strays from alpha_label in the trimap's 128-valued region. num_pixels counts the 1s in the mask, i.e. the total number of those 128-pixels. Finally, the squared difference plus the small constant epsilon_sqr goes through torch.sqrt (the constant is there for numerical stability rather than to prevent overfitting: it keeps the square root well-behaved and differentiable when diff is exactly zero), all the values are summed, and the sum is divided by num_pixels + epsilon (another guard, this time against dividing by zero). This matches the loss defined in the original paper quite closely.
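
A toy run of the loss on hand-made tensors (hypothetical numbers; epsilon and epsilon_sqr are assumed to be the repo's small config constants, here set inline):

import torch

epsilon = 1e-6
epsilon_sqr = epsilon ** 2

y_pred = torch.tensor([[[0.8, 0.5, 0.1, 0.0]]])    # [1, 1, 4] predicted matte
alpha  = torch.tensor([[1.0, 0.4, 0.0, 1.0]])      # ground-truth matte
mask   = torch.tensor([[1.0, 1.0, 0.0, 0.0]])      # only the first two pixels are unknown
y_true = torch.stack([alpha, mask], dim=1)         # [1, 2, 4]

diff = (y_pred[:, 0, :] - y_true[:, 0, :]) * mask  # [[-0.2, 0.1, 0.0, 0.0]]
loss = torch.sum(torch.sqrt(diff ** 2 + epsilon_sqr)) / (torch.sum(mask) + epsilon)
print(loss)  # roughly (0.2 + 0.1) / 2 = 0.15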

At this point the training function is essentially covered, but one new function appears near the end: clip_gradient(optimizer, grad_clip). Let's see what it's for.

def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.
    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

Per the docstring, the function clips the gradients computed during backpropagation to keep them from exploding. grad_clip is set to 5 at the start. For the details of the technique, this article is worth a read: 梯度爆炸的解决办法:clip gradient_u010814042的博客-CSDN博客_clip gradient
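
For reference, PyTorch ships a utility that does the same element-wise clamping; here is a minimal sketch on a stand-in model (clip_grad_norm_, which rescales by the total gradient norm instead, is the other common option and behaves differently):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)                    # stand-in for DIMModel
model(torch.randn(8, 4)).sum().backward()

# clip each gradient element into [-5, 5], like the clamp_ loop above
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=5.0)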

As training runs, the loop logs the current epoch, the batch index, and the loss every print_freq batches. With that, the train function is done. And having read this far, the valid function hardly needs a separate look: it's the same flow with quite a few steps removed (too lazy to write it up).
