import torch
from model import Model
# train for one epoch to learn unique features
def _module_global(name):
    """Return the module-level global ``name``.

    Backward-compatible fallback for settings that ``train`` originally read
    implicitly from the enclosing script's globals.

    Raises:
        NameError: if no global called ``name`` is defined.
    """
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name {name!r} is not defined; pass {name}= to train() explicitly"
        ) from None


def train(net, data_loader, train_optimizer, *, temperature=None, epoch=None, epochs=None):
    """Train ``net`` for one epoch with a contrastive (NT-Xent-style) loss.

    Each batch provides two views ``pos_1`` / ``pos_2`` of the same B samples.
    The third element, ``target``, is ignored. Both views are passed through
    ``net`` and the projections are stacked into 2B rows. Each row is then
    trained to single out its partner view among the other 2B - 1 rows.

    Args:
        net: module whose forward returns ``(feature, out)``; only ``out`` is
            used by the loss. Presumably ``out`` is L2-normalised by the model,
            so the dot products are cosine similarities -- TODO confirm in
            ``model.Model``.
        data_loader: iterable of ``(pos_1, pos_2, target)`` batches.
        train_optimizer: optimizer that updates ``net``'s parameters.
        temperature: similarity temperature. Defaults to the module-level
            ``temperature`` global, for backward compatibility.
        epoch, epochs: used only in the progress-bar text. They default to
            the module-level globals of the same names.

    Returns:
        The mean per-sample loss over the epoch, with each batch weighted by
        its actual size.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no batches.
        NameError: if a needed setting is neither passed nor defined globally.
    """
    if temperature is None:
        temperature = _module_global('temperature')
    try:
        from tqdm import tqdm  # optional: only drives the progress display
    except ImportError:
        tqdm = None

    net.train()
    # Move inputs to wherever the model lives. The original code hard-coded
    # .cuda(), which matches this whenever the model is on the GPU.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    if tqdm is not None:
        if epoch is None:
            epoch = _module_global('epoch')
        if epochs is None:
            epochs = _module_global('epochs')
        train_bar = tqdm(data_loader)
    else:
        train_bar = data_loader

    total_loss, total_num = 0.0, 0
    for pos_1, pos_2, _target in train_bar:
        pos_1 = pos_1.to(device, non_blocking=True)
        pos_2 = pos_2.to(device, non_blocking=True)
        # Use the real batch size: the last batch can be smaller than the
        # configured one when the loader does not drop it.
        batch_size = pos_1.size(0)
        _, out_1 = net(pos_1)
        _, out_2 = net(pos_2)

        # [2*B, D]
        out = torch.cat([out_1, out_2], dim=0)
        # [2*B, 2*B] temperature-scaled similarity logits
        logits = torch.mm(out, out.t()) / temperature
        # Drop each row's similarity with itself -> [2*B, 2*B-1]
        not_self = ~torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
        logits = logits.masked_select(not_self).view(2 * batch_size, -1)
        # [2*B] logit of each row's partner view (row i pairs with row i +/- B)
        pos_logits = torch.sum(out_1 * out_2, dim=-1) / temperature
        pos_logits = torch.cat([pos_logits, pos_logits], dim=0)
        # -log(exp(pos) / sum(exp(logits))), evaluated in log space. The
        # exp-then-divide form overflows float32 for small temperatures.
        loss = (torch.logsumexp(logits, dim=-1) - pos_logits).mean()

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        total_num += batch_size
        total_loss += loss.item() * batch_size
        if tqdm is not None:
            train_bar.set_description(
                'Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, total_loss / total_num)
            )
    return total_loss / total_num
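# ————————————————
# 版权声明:本文为CSDN博主「xxxxxxxxxx13」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
# (Copyright notice: this is an original article by CSDN blogger "xxxxxxxxxx13", licensed under CC BY-SA 4.0. Any repost must include a link to the original source and this notice.)
# 原文链接:https://blog.csdn.net/xxxxxxxxxx13/article/details/110820373
# (Original article: https://blog.csdn.net/xxxxxxxxxx13/article/details/110820373)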