1.clone()主要用于模块复用 数据进行复制,不共享同一内存,梯度可以回溯
c=torch.tensor(1.0,requires_grad=True) b=c*2 d=b**2 (**) b_=b.clone() e_=b_**3 e_.backward(retain_graph=True) """ b.zero_() 这里的b是d.backward()的回溯节点(**),在回溯前不能进行in place 操作, 目的保证梯度计算正确,但如果是b_.zero_()就不会报错,因为clone不共享内存 """ d.backward() print(c.grad) #tensor(32.)
这里单独查看b_.grad或者b.grad都不存在,因为他们是中间变量,不需要保存,更新也是只更新叶子节点,此外要设置retain_graph=True,因为有一条线路上先进行了梯度回溯,为节省显存计算图会释放。
2.detach()主要用于数据的提取,共享同一内存,强制require_grad=False(即使设置为True也不进行梯度回溯)
c=torch.tensor(1.0,requires_grad=True) b=c*2 w=b**2 b_=b.detach() q=torch.tensor(1.0,requires_grad=True) e_=q**b_ e_.backward() #b_.zero_() 因为detach共享内存,这里进行in palce操作会报错 w.backward() print(q.grad) #tensor(2.)



