栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

怎样判断一个操作是否是可导的(Pytorch)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

怎样判断一个操作是否是可导的(Pytorch)

在Pytorch中,看一个操作是否可导,即经过这个操作梯度是否还能顺利传递。

 可以看到,经过+操作后得到的z,仍能保持梯度的传递

而像torch.argmax(),  torch.eq() 这些操作就不行了,这些操作就是不可导的

即,遇到不可导的,你反向传播都会出问题,程序自己就会报错

 

而如果像soft argmax

import torch
import torch.nn as nn

def soft_argmax(x):
	"""
	Arguments: voxel patch in shape (batch_size, channel, H, W, depth)
	Return: 3D coordinates in shape (batch_size, channel, 3)
	"""
	# alpha is here to make the largest element really big, so it
	# would become very close to 1 after softmax
	alpha = 10000.0 
	N,C,L = x.shape
	soft_max = nn.functional.softmax(x*alpha,dim=2)
	soft_max = soft_max.view(x.shape)
	indices_kernel = torch.arange(start=0, end=L).unsqueeze(0)
	# indices_kernel = indices_kernel.view((H,W,D))
	# indices_kernel = indices_kernel.view(H,W)
	conv = soft_max*indices_kernel
	indices = conv.sum(2)
	# z = indices%D
	# y = (indices).floor()%W
	# x = (((indices).floor())/W).floor()%H
	# coords = torch.stack([x,y,z],dim=2)
	# coords = torch.stack([x,y],dim=2)
    #coords[0][0]代表第一个channel的最大点的坐标值
    #coords[0][1]代表第2个channel的最大点的坐标值
	return indices
 
if __name__ == "__main__":
	x = torch.randn(1024,16,35*35,requires_grad=True) # (batch_size, channel, H, W, depth)
	coords = soft_argmax(x)
    #coords是[b,c,2]
	print(coords)

 操作就是可导的

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/725375.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号