The previous two articles spent an enormous amount of space taking the dataset completely apart. Put it this way: the one part of the whole DIM codebase scarier than the model architecture is the dataset processing (only after dissecting it did I realize a normal person's brain does not produce something like this). Now that the dataset is clear, it's worth taking a fresh look at the train function, along with the nearly identical valid function.
def train(train_loader, model, optimizer, epoch, logger):
    model.train()  # train mode (dropout and batchnorm are enabled)
    losses = AverageMeter()
    # Batches
    for i, (img, alpha_label) in enumerate(train_loader):
        # Move to GPU, if available
        img = img.type(torch.FloatTensor).to(device)  # [N, 4, 320, 320]
        alpha_label = alpha_label.type(torch.FloatTensor).to(device)  # [N, 2, 320, 320]
        alpha_label = alpha_label.reshape((-1, 2, im_size * im_size))  # [N, 2, 320*320]
        # Forward prop.
        alpha_out = model(img)  # repo comment says [N, 3, 320, 320]; actually [N, 320, 320]
        alpha_out = alpha_out.reshape((-1, 1, im_size * im_size))  # [N, 1, 320*320]
        # Calculate loss
        # loss = criterion(alpha_out, alpha_label)
        loss = alpha_prediction_loss(alpha_out, alpha_label)
        # Back prop.
        optimizer.zero_grad()
        loss.backward()
        # Clip gradients
        clip_gradient(optimizer, grad_clip)
        # Update weights
        optimizer.step()
        # Keep track of metrics
        losses.update(loss.item())
        # Print status
        if i % print_freq == 0:
            status = 'Epoch: [{0}][{1}/{2}]\t' \
                     'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader), loss=losses)
            logger.info(status)
    return losses.avg
The loss-tracking instance hasn't come up before, so let's look at it on its own:
class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
As the docstring says, these attributes exist to track a metric's most recent value, average, sum, and count. The class defines three methods: the initializer __init__, a reset method, and an update method. We'll come back to how they are used when they show up later.
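To see those attributes in action, here is a quick self-contained run; the class is copied verbatim from above, and the three loss values are just made-up batch losses:

```python
class AverageMeter(object):
    """Keeps track of most recent, average, sum, and count of a metric."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

losses = AverageMeter()
for batch_loss in [0.9, 0.7, 0.5]:
    losses.update(batch_loss)  # n defaults to 1: one scalar per batch

print(losses.val)    # most recent loss (0.5)
print(losses.avg)    # running mean over all updates (~0.7)
print(losses.count)  # number of updates (3)
```

In train, losses.val is the loss of the current batch and losses.avg the running mean over the epoch, which is exactly the pair printed in the status line.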
On the subject of Dataset and DataLoader, this article explains things well:
【小白学PyTorch】3 浅谈Dataset和Dataloader_忽逢桃林的博客-CSDN博客
Now let's re-examine the training loop itself.
for i, (img, alpha_label) in enumerate(train_loader):
    # Move to GPU, if available
    img = img.type(torch.FloatTensor).to(device)  # [N, 4, 320, 320]
    alpha_label = alpha_label.type(torch.FloatTensor).to(device)  # [N, 2, 320, 320]
    alpha_label = alpha_label.reshape((-1, 2, im_size * im_size))  # [N, 2, 320*320]
    # Forward prop.
    alpha_out = model(img)  # repo comment says [N, 3, 320, 320]; actually [N, 320, 320]
    alpha_out = alpha_out.reshape((-1, 1, im_size * im_size))  # [N, 1, 320*320]
    # Calculate loss
    # loss = criterion(alpha_out, alpha_label)
    loss = alpha_prediction_loss(alpha_out, alpha_label)
    # Back prop.
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients
    clip_gradient(optimizer, grad_clip)
    # Update weights
    optimizer.step()
    # Keep track of metrics
    losses.update(loss.item())
    # Print status
    if i % print_freq == 0:
        status = 'Epoch: [{0}][{1}/{2}]\t' \
                 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader), loss=losses)
        logger.info(status)
Each iteration of train_loader yields batch_size of the (img, alpha_label) pairs returned by the dataset's __getitem__: img is the 4-channel tensor built earlier, and alpha_label carries the normalized matte values plus the is-this-pixel-128 label. From here things get much easier to follow. After img and alpha_label are converted to FloatTensor and moved to the GPU, alpha_label gets reshaped. At first I assumed a reshape involving 2 would produce two columns, but it doesn't.
What comes out is, per sample, 2 rows of 102400 columns, i.e. shape [N, 2, 102400]. I searched for a while without finding a clear explanation, but the mechanics are simply that the first argument -1 is inferred: PyTorch divides the total element count by 2 * 102400 and resolves -1 to the batch size N, flattening each 320×320 plane into a row of 102400 values. As for why the spatial dimensions get flattened into one axis at all, see below.
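To make the -1 inference concrete, here is a minimal sketch in NumPy, whose reshape semantics match torch's for this case; im_size is shrunk from the real 320 down to 4 so the shapes stay readable:

```python
import numpy as np

im_size = 4  # stand-in for the real 320, to keep the demo small
batch = 3

# Fake alpha_label: channel 0 = matte values, channel 1 = the 128-mask
alpha_label = np.zeros((batch, 2, im_size, im_size))

# -1 is inferred: total elements / (2 * im_size * im_size) = batch size
flat = alpha_label.reshape((-1, 2, im_size * im_size))
print(flat.shape)  # (3, 2, 16), i.e. [N, 2, im_size*im_size]
```

With the real sizes the same call turns [N, 2, 320, 320] into [N, 2, 102400], which is exactly the layout the loss function indexes by channel.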
alpha_out holds whatever the model returns for img, so it's worth putting the model code on the table again.
class DIMModel(nn.Module):
    def __init__(self, n_classes=1, in_channels=4, is_unpooling=True, pretrain=True):
        super(DIMModel, self).__init__()
        self.in_channels = in_channels
        self.is_unpooling = is_unpooling
        self.pretrain = pretrain
        # SegNet-style encoder
        self.down1 = segnetDown2(self.in_channels, 64)
        self.down2 = segnetDown2(64, 128)
        self.down3 = segnetDown3(128, 256)
        self.down4 = segnetDown3(256, 512)
        self.down5 = segnetDown3(512, 512)
        # Decoder, unpooling with the encoder's max-pool indices
        self.up5 = segnetUp1(512, 512)
        self.up4 = segnetUp1(512, 256)
        self.up3 = segnetUp1(256, 128)
        self.up2 = segnetUp1(128, 64)
        self.up1 = segnetUp1(64, n_classes)
        self.sigmoid = nn.Sigmoid()
        if self.pretrain:
            import torchvision.models as models
            # Note: models.vgg16() as written builds randomly initialized weights;
            # loading the actual ImageNet weights needs models.vgg16(pretrained=True)
            vgg16 = models.vgg16()
            self.init_vgg16_params(vgg16)

    def forward(self, inputs):
        # inputs: [N, 4, 320, 320]
        down1, indices_1, unpool_shape1 = self.down1(inputs)
        down2, indices_2, unpool_shape2 = self.down2(down1)
        down3, indices_3, unpool_shape3 = self.down3(down2)
        down4, indices_4, unpool_shape4 = self.down4(down3)
        down5, indices_5, unpool_shape5 = self.down5(down4)
        up5 = self.up5(down5, indices_5, unpool_shape5)
        up4 = self.up4(up5, indices_4, unpool_shape4)
        up3 = self.up3(up4, indices_3, unpool_shape3)
        up2 = self.up2(up3, indices_2, unpool_shape2)
        up1 = self.up1(up2, indices_1, unpool_shape1)
        x = torch.squeeze(up1, dim=1)  # [N, 1, 320, 320] -> [N, 320, 320]
        x = self.sigmoid(x)
        return x

    def init_vgg16_params(self, vgg16):
        blocks = [self.down1, self.down2, self.down3, self.down4, self.down5]
        ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
        features = list(vgg16.features.children())
        # Collect VGG16's Conv2d layers, skipping ReLU/MaxPool
        vgg_layers = []
        for _layer in features:
            if isinstance(_layer, nn.Conv2d):
                vgg_layers.append(_layer)
        # Collect this model's encoder Conv2d layers in the same order
        merged_layers = []
        for idx, conv_block in enumerate(blocks):
            if idx < 2:
                units = [conv_block.conv1.cbr_unit, conv_block.conv2.cbr_unit]
            else:
                units = [
                    conv_block.conv1.cbr_unit,
                    conv_block.conv2.cbr_unit,
                    conv_block.conv3.cbr_unit,
                ]
            for _unit in units:
                for _layer in _unit:
                    if isinstance(_layer, nn.Conv2d):
                        merged_layers.append(_layer)
        assert len(vgg_layers) == len(merged_layers)
        # Copy VGG weights into the encoder wherever the shapes match
        for l1, l2 in zip(vgg_layers, merged_layers):
            if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                if l1.weight.size() == l2.weight.size() and l1.bias.size() == l2.bias.size():
                    l2.weight.data = l1.weight.data
                    l2.bias.data = l1.bias.data
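A small sanity check on why that assert holds: VGG16's feature extractor contains 13 Conv2d layers, and the SegNet-style encoder here stacks 2+2+3+3+3 convolutions across down1..down5, so the two lists line up one-to-one. Pure arithmetic, no torch needed:

```python
# Convs per encoder block: down1/down2 use segnetDown2, down3-5 use segnetDown3
encoder_convs = [2, 2, 3, 3, 3]

# VGG16's conv layout per stage (the rest of vgg16.features is ReLU/MaxPool)
vgg16_convs = [2, 2, 3, 3, 3]

total = sum(encoder_convs)
assert total == sum(vgg16_convs) == 13
print(total)  # 13 conv layers on each side, so zip() pairs them all
```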
The last two lines of forward reveal what alpha_out actually is:
x = torch.squeeze(up1, dim=1) # [N, 1, 320, 320] -> [N, 320, 320]
x = self.sigmoid(x)
return x
In other words, alpha_out comes out of the model with spatial size 320×320, i.e. shape [N, 320, 320]. The earlier [N, 3, 320, 320] comment appears to be simply a stale comment: n_classes=1, and the torch.squeeze(up1, dim=1) in forward removes the channel dimension entirely. The final reshape then turns alpha_out into [N, 1, 102400].
Once the model has produced alpha_out and its shape has been rearranged, alpha_out and alpha_label look essentially alike, and the next step is computing the loss and back-propagating. Having gone through the dataset end to end, the formerly fuzzy parts of the loss computation become much clearer.
def alpha_prediction_loss(y_pred, y_true):
    mask = y_true[:, 1, :]
    diff = y_pred[:, 0, :] - y_true[:, 0, :]
    diff = diff * mask
    num_pixels = torch.sum(mask)
    return torch.sum(torch.sqrt(torch.pow(diff, 2) + epsilon_sqr)) / (num_pixels + epsilon)
It is called like this: loss = alpha_prediction_loss(alpha_out, alpha_label), so
alpha_out corresponds to y_pred
alpha_label corresponds to y_true
alpha_out therefore has shape [N, 1, 102400] and alpha_label has shape [N, 2, 102400]. Since alpha_label started out as [2, 320, 320] per sample — channel 0 holding the normalized matte values and channel 1 holding only 0s and 1s recording whether that pixel is 128 in the trimap, i.e. the mask — after the reshape, row 0 is the normalized matte and row 1 is the mask. With that, the variables inside the loss function make sense:
mask: where y_true records trimap pixels equal to 128, i.e. the ambiguous boundary region of the trimap
y_pred: alpha_out, the [N, 1, 102400] output the model produces from the processed 4-channel img
diff is the difference between y_pred and y_true on channel 0 — the model's predicted matte versus the ground-truth matte. Multiplying diff by mask keeps the difference where the trimap pixel is 128 and zeroes it everywhere else, so diff ends up expressing the gap between alpha_out and alpha_label only on the uncertain (128) pixels. num_pixels counts the 1s in mask, i.e. the total number of 128 pixels. Finally, torch.pow squares the differences, the small constant epsilon_sqr is added for numerical stability (it keeps the square root differentiable at zero), torch.sqrt takes the root, everything in diff is summed, and the sum is divided by num_pixels + epsilon. All in all this matches the loss defined in the original paper fairly closely.
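As a sanity check on that arithmetic, here is the same loss re-implemented in NumPy on a toy 2×2 "image" flattened to 4 pixels; the epsilon values are assumptions matching the usual config (epsilon = 1e-6, epsilon_sqr = epsilon ** 2):

```python
import numpy as np

epsilon = 1e-6           # assumed config constants
epsilon_sqr = epsilon ** 2

def alpha_prediction_loss(y_pred, y_true):
    mask = y_true[:, 1, :]                          # 1 where trimap == 128
    diff = (y_pred[:, 0, :] - y_true[:, 0, :]) * mask
    num_pixels = np.sum(mask)
    return np.sum(np.sqrt(diff ** 2 + epsilon_sqr)) / (num_pixels + epsilon)

# One sample, 4 pixels
y_true = np.array([[[0.0, 0.5, 1.0, 0.2],    # channel 0: ground-truth matte
                    [0.0, 1.0, 0.0, 1.0]]])  # channel 1: mask (trimap == 128)
y_pred = np.array([[[0.3, 0.7, 0.9, 0.2]]])  # model output, [N, 1, 4]

# Only the two masked pixels contribute: |0.7-0.5| + |0.2-0.2| = 0.2,
# divided by num_pixels = 2
loss = alpha_prediction_loss(y_pred, y_true)
print(loss)  # ~0.1
```

Note how the prediction errors on the unmasked pixels (0.3 vs 0.0, 0.9 vs 1.0) contribute nothing: the loss only cares about the trimap's uncertain region, exactly as described above.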
That more or less wraps up the training function, but one new function appears afterward: clip_gradient(optimizer, grad_clip). Let's see what it does.
def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.
    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
As the docstring says, this is the clip-gradient technique for avoiding exploding gradients. The config sets grad_clip = 5 at the start. For the details, this article is a good reference: 梯度爆炸的解决办法:clip gradient_u010814042的博客-CSDN博客_clip gradient
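What clamp_(-grad_clip, grad_clip) does element-wise can be sketched in NumPy, where np.clip is the same operation; the gradient values below are made up, with grad_clip = 5 as in the config:

```python
import numpy as np

grad_clip = 5  # the value the config sets

# A fake gradient vector with a couple of exploding entries
grad = np.array([-12.0, -3.5, 0.0, 4.9, 87.0])

# Equivalent of param.grad.data.clamp_(-grad_clip, grad_clip):
# values inside [-5, 5] pass through, anything outside is pinned to the bound
clipped = np.clip(grad, -grad_clip, grad_clip)
print(clipped)  # [-5.  -3.5  0.   4.9  5. ]
```

In PyTorch itself, torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip) performs the same per-element clamping, so the hand-rolled loop over optimizer.param_groups is a design choice rather than a necessity.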
Once training runs, the loss for the current epoch and iteration gets printed at each print_freq interval, and with that the train function is done. Having read this far, there's no real need to walk through valid as well — it has far fewer steps (and I'm too lazy to write it up).



