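Environment: PyTorch 1.7.1
Problem: when generating adversarial examples, we often need to update the same object (e.g. an adversarial perturbation) over many rounds of backpropagation, as in the following snippet: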
import torch
import torch.nn.functional as F

def attack_update(self, perturbation, x, y, model, device):
    x = x.to(device)
    y = y.to(device)
    model = model.to(device)
    if perturbation is None:
        # a single perturbation shared by every sample in the batch
        perturbation = torch.zeros_like(x[0])
    perturbation = perturbation.to(device)
    perturbation.requires_grad = True
    with torch.enable_grad():
        # turn adv into batches: the perturbation is broadcast over the batch dimension
        adv = perturbation + x
        adv = torch.clamp(adv, min=0.0, max=1.0)
        pred = model(adv)
        # F.cross_entropy applies log_softmax internally, so pass the raw logits
        loss = F.cross_entropy(pred, y)
        grads = torch.autograd.grad(loss, perturbation)[0]
        if self.optimizer == 'sgd':
            # gradient ascent step on the loss, then project back into [-eps, eps]
            perturbation = perturbation + self.lr * grads
            perturbation = torch.clamp(perturbation, min=-self.eps, max=self.eps)
    return perturbation
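However, this code has a problem. At the start of the computation, perturbation is set to requires_grad = True, and the function then returns a tensor under the same name. After one pass, the returned perturbation is the output of perturbation + self.lr * grads followed by torch.clamp. It therefore carries a grad_fn that links it back to the leaf tensor from the previous call. In other words, after one update perturbation is no longer a leaf node. When it is passed back into the function, the line perturbation.requires_grad = True raises a RuntimeError, because requires_grad can only be changed on leaf tensors.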
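A minimal reproduction of this error might look like the following. It is stripped down to a plain tensor: the helper update and the toy loss (p ** 2).sum() are hypothetical stand-ins for attack_update and the model's loss.

```python
import torch

def update(p):
    # hypothetical stand-in for attack_update: one gradient step on a toy loss
    p.requires_grad = True
    loss = (p ** 2).sum()
    grads = torch.autograd.grad(loss, p)[0]
    return p + 0.1 * grads  # no detach: the result carries a grad_fn

p = torch.zeros(3)
p = update(p)  # first call: p was a leaf, so this works
print(p.is_leaf, p.grad_fn is not None)  # prints: False True

error = None
try:
    update(p)  # second call: p is no longer a leaf, so setting requires_grad fails
except RuntimeError as e:
    error = str(e)
print(error)
```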
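Solution: the fix is simple. Detach perturbation from the computation graph when returning it. The returned tensor is then a fresh leaf with requires_grad=False. On the next call, setting requires_grad = True succeeds, and a new graph is built from scratch without getting tangled up with the previous one.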
return perturbation.detach()
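Below is a sketch of the fixed version running in a multi-step attack loop. It makes several assumptions: a hypothetical toy linear classifier on random 3x4x4 inputs, no device handling, and lr/eps passed as arguments instead of self.lr/self.eps.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def attack_update(perturbation, x, y, model, lr=0.01, eps=0.1):
    # free-function version of the method above; lr and eps are hypothetical defaults
    if perturbation is None:
        perturbation = torch.zeros_like(x[0])
    perturbation.requires_grad = True  # valid: perturbation is a leaf on every call
    with torch.enable_grad():
        adv = torch.clamp(perturbation + x, min=0.0, max=1.0)
        loss = F.cross_entropy(model(adv), y)
        grads = torch.autograd.grad(loss, perturbation)[0]
    perturbation = perturbation + lr * grads
    perturbation = torch.clamp(perturbation, min=-eps, max=eps)
    return perturbation.detach()  # cut the graph so the next call starts from a fresh leaf

torch.manual_seed(0)
# hypothetical toy setup: 8 samples of 3x4x4 "images", 5 classes
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 5))
x = torch.rand(8, 3, 4, 4)
y = torch.randint(0, 5, (8,))

perturbation = None
for step in range(10):
    perturbation = attack_update(perturbation, x, y, model)

print(perturbation.shape, perturbation.is_leaf, perturbation.requires_grad)
# prints: torch.Size([3, 4, 4]) True False
```

Because each call returns a detached leaf, the loop can run for any number of steps without hitting the requires_grad error.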



