# Accumulate a confusion matrix over the evaluation iterator, sum it across
# processes when torch.distributed is initialized, and derive per-class IoU
# (`ious`) and mean IoU (`miou`) from it.
#
# hist[t, p] counts kept pixels whose label is class t and whose prediction is
# class p. It is accumulated in int64 (the dtype torch.bincount returns)
# because float32 cannot represent counts above 2**24 exactly, so repeated
# "+=" on a large float32 tally would drift.
hist = torch.zeros(n_classes, n_classes, dtype=torch.long).cuda()
with torch.no_grad():  # evaluation only; no autograd graphs are needed
    for i, (imgs, label) in diter:
        # label is expected to be 4-D with a singleton dim 1; after squeezing
        # it must have the same shape as `preds` for the boolean masking below.
        label = label.squeeze(1).cuda()
        imgs = imgs.cuda()
        ##########################################
        # logits = YourNetwork(imgs)[0]          #
        ##########################################
        # Softmax is monotonic, so argmax over the raw logits picks the same
        # class as argmax over probabilities, without the extra pass and
        # without ties introduced by rounding nearly-equal probabilities.
        preds = torch.argmax(logits, dim=1)
        keep = label != self.ignore_label
        # Encode each (label, pred) pair as one flat index. The result has
        # exactly n_classes**2 bins (required by the view) only when every
        # kept label and prediction lies in [0, n_classes).
        binc = torch.bincount(
            label[keep] * n_classes + preds[keep],
            minlength=n_classes ** 2,
        )
        hist += binc.view(n_classes, n_classes)
if dist.is_initialized():
    dist.all_reduce(hist, dist.ReduceOp.SUM)
# Convert once, after all summation, so there is a single rounding step.
hist = hist.float()
# IoU = TP / (TP + FP + FN) = diag / (col_sum + row_sum - diag).
union = hist.sum(dim=0) + hist.sum(dim=1) - hist.diag()
ious = hist.diag() / union  # NaN for classes absent from labels and preds
# Average only over classes that occur; otherwise a single absent class
# would turn the whole mean into NaN.
miou = ious[union > 0].mean()