I recently ran into a very basic question: when PyTorch's dropout is applied to an n-dimensional tensor, does it randomly drop an entire row (a vector along some dimension), or individual elements? A quick experiment to check:
import torch
m = torch.nn.Dropout(p=0.2)
input = torch.randn(5, 5)
output = m(input)
input:
tensor([[-0.2266, -0.6318, -0.1628, -0.0301, -1.2603],
[-2.3018, 0.7649, 1.3658, -0.6601, 0.1574],
[-0.7697, 0.1300, -1.9488, 0.9426, -0.2315],
[ 0.9873, 0.7713, 1.3725, -0.6127, 0.2403],
[ 0.5574, 1.3104, -0.1863, 0.9430, -0.3442]])
output:
tensor([[-0.2832, -0.7898, -0.2035, -0.0376, -0.0000],
[-2.8772, 0.9561, 1.7072, -0.8251, 0.1968],
[-0.9622, 0.1626, -2.4360, 1.1782, -0.2893],
[ 1.2341, 0.9641, 1.7156, -0.7659, 0.3003],
[ 0.6968, 1.6380, -0.2329, 1.1788, -0.4302]])
So dropout zeroes each element independently with equal probability, rather than dropping an entire row or column. The surviving elements also change in value because, in the official implementation, each retained element is multiplied by 1/(1-p).
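The 1/(1-p) scaling can be verified directly: every nonzero element of the output should equal the corresponding input element divided by (1-p). A minimal sketch (the seed and p value are arbitrary choices for the demo):

```python
import torch

torch.manual_seed(0)
p = 0.2
m = torch.nn.Dropout(p=p)
x = torch.randn(5, 5)
y = m(x)

# Elements that survived dropout are nonzero; check they were
# rescaled by 1/(1-p) relative to the input.
mask = y != 0
assert torch.allclose(y[mask], x[mask] / (1 - p))
```

Note that this scaling (and dropout itself) only happens in training mode; after calling m.eval(), the module becomes an identity mapping.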
Applying dropout to a 3-D tensor:
input = torch.randn(5, 5, 5)
output = m(input)
input:
tensor([[[-0.7328, -0.7222, 1.3446, -0.0297, 2.2102],
[-0.2430, 0.7217, -0.0666, 1.2258, 0.0429],
[-1.5523, -0.5620, -2.0723, 0.0276, 0.6448],
[-0.2670, -0.5384, -0.3922, 0.9496, -0.5713],
[-1.4619, 0.7644, 0.9520, -0.1513, -0.2254]],
[[ 0.0849, -2.3184, -0.1957, 0.1423, -0.5450],
[-0.1516, -0.7531, 1.7559, 0.9391, 0.9862],
[ 1.0796, 0.5263, -1.2996, -0.8324, 0.8036],
[ 0.6910, -1.5230, 1.4050, 1.6258, -0.9268],
[-3.4878, 0.7519, -0.6886, 0.0373, 1.1346]],
[[-1.7822, -0.8222, 1.5164, 2.6609, 0.2105],
[ 1.4420, 0.5239, 1.7459, -0.7058, 0.7801],
[ 1.4009, 0.4207, 0.3810, -0.7673, 0.8988],
[ 0.3676, 1.1389, 0.9903, -0.6542, -0.7219],
[ 0.2029, -1.2014, -0.0530, 0.6527, -0.4523]],
[[ 2.1250, 1.6481, -0.0844, -0.3846, -0.0184],
[-1.5917, 0.8343, -0.2303, 0.7135, -1.2380],
[ 0.6418, -0.3958, 0.3442, -1.9055, 0.8149],
[-0.6485, 1.9034, 0.9087, 1.5293, -0.6239],
[-0.2155, -0.0097, 3.0727, -0.0537, -0.9891]],
[[-0.4705, 0.4225, 0.2590, -0.6873, 0.8627],
[ 1.0333, -0.9247, 0.8617, -0.4957, -0.3247],
[-1.1925, -0.5471, 0.9413, -0.5192, 0.2911],
[ 2.4236, -1.5812, 0.6202, -1.8823, 0.1273],
[-0.4075, 0.0873, -0.6647, -1.2492, 1.0147]]])
output:
tensor([[[-0.0000, -0.9027, 1.6808, -0.0371, 0.0000],
[-0.0000, 0.9021, -0.0833, 1.5322, 0.0536],
[-1.9404, -0.7025, -0.0000, 0.0345, 0.8060],
[-0.3338, -0.6730, -0.0000, 0.0000, -0.7141],
[-1.8274, 0.9555, 1.1901, -0.0000, -0.2817]],
[[ 0.1061, -2.8980, -0.2447, 0.0000, -0.6813],
[-0.1895, -0.9414, 2.1948, 1.1738, 1.2328],
[ 1.3494, 0.6579, -0.0000, -1.0405, 1.0044],
[ 0.8637, -1.9038, 1.7562, 2.0322, -1.1585],
[-4.3598, 0.9398, -0.0000, 0.0466, 1.4183]],
[[-2.2277, -1.0278, 1.8955, 3.3262, 0.2631],
[ 1.8025, 0.6548, 2.1824, -0.0000, 0.9751],
[ 0.0000, 0.5259, 0.4763, -0.0000, 1.1235],
[ 0.4594, 0.0000, 1.2379, -0.8177, -0.9023],
[ 0.2536, -1.5017, -0.0663, 0.8159, -0.5653]],
[[ 2.6562, 0.0000, -0.1055, -0.0000, -0.0230],
[-0.0000, 1.0429, -0.2879, 0.8918, -1.5475],
[ 0.8022, -0.4948, 0.4302, -0.0000, 1.0186],
[-0.8106, 2.3793, 1.1359, 1.9116, -0.7799],
[-0.2694, -0.0122, 0.0000, -0.0672, -0.0000]],
[[-0.5882, 0.0000, 0.3237, -0.8591, 1.0784],
[ 1.2917, -1.1558, 1.0771, -0.6197, -0.4059],
[-1.4907, -0.0000, 1.1766, -0.6490, 0.3639],
[ 0.0000, -1.9765, 0.7752, -2.3528, 0.0000],
[-0.5093, 0.1091, -0.8309, -1.5615, 1.2683]]])
The same holds here: dropout does not drop an entire row or slice, but instead zeroes individual elements independently, scattered across the tensor.
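This element-wise behavior can also be checked statistically: over a large tensor, the fraction of zeroed elements should be close to p. A small sketch (the tensor size and tolerance are arbitrary):

```python
import torch

torch.manual_seed(0)
p = 0.2
m = torch.nn.Dropout(p=p)
x = torch.randn(100, 100)
y = m(x)

# The empirical drop rate over 10,000 elements should be close to p.
drop_rate = (y == 0).float().mean().item()
assert abs(drop_rate - p) < 0.05
```

If you do want to drop entire feature maps rather than individual elements, PyTorch provides torch.nn.Dropout2d for that purpose.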