简单来说就是增加一个新的维度,在这个新的维度上进行torch.cat()操作。
import torch a = torch.randn(3,4) # tensor([[-0.0974, -1.3577, -0.5162, -0.9748], # [-1.0509, -0.7450, -0.7226, -1.6895], # [-0.7616, 1.0055, 0.5779, -2.0157]]) b = torch.randn(3,4) # tensor([[-0.3454, 1.2769, -0.3882, -1.4049], # [-0.3809, -1.2949, -0.6149, 1.1036], # [ 0.9674, 1.2621, 1.7883, -0.7552]]) c = torch.cat((a, b), dim=0) # tensor([[-0.0974, -1.3577, -0.5162, -0.9748], # [-1.0509, -0.7450, -0.7226, -1.6895], # [-0.7616, 1.0055, 0.5779, -2.0157], # [-0.3454, 1.2769, -0.3882, -1.4049], # [-0.3809, -1.2949, -0.6149, 1.1036], # [ 0.9674, 1.2621, 1.7883, -0.7552]]) c.shape # torch.Size([6, 4]) d = torch.stack((a,b), dim=0) # tensor([[[-0.0974, -1.3577, -0.5162, -0.9748], # [-1.0509, -0.7450, -0.7226, -1.6895], # [-0.7616, 1.0055, 0.5779, -2.0157]], # [[-0.3454, 1.2769, -0.3882, -1.4049], # [-0.3809, -1.2949, -0.6149, 1.1036], # [ 0.9674, 1.2621, 1.7883, -0.7552]]]) d.shape # torch.Size([2, 3, 4])



