在使用mse损失函数进行对抗攻击的时候 loss反向传播一直报错 最终改成如下形式才能够正常运行
loss11 F.mse_loss(logits, logits_target, reduction none ).sum(axis 1) loss12 F.mse_loss(logits, logits_true, reduction none ).sum(axis 1) loss1 4 * loss11 - loss12 loss torch.mean(loss) optimizer.zero_grad() loss.backward()
同时 如果使用交叉熵损失函数则不需要reduction none及torch.mean操作。
原因暂时未知



