Code implementation:
def forward(self, x):
    """Vision-Transformer-style forward pass: augment, patchify, prepend the
    CLS token, add positional embeddings, encode, and classify from the CLS slot.

    Args:
        x: batch of input images — original annotations suggest shape
           (B, 3, 32, 32); TODO confirm against the caller.

    Returns:
        Output-head result taken from the CLS token — original annotations
        suggest shape (B, 1); confirm against ``self.out``.
    """
    x = DiffAugment(x, self.diff_aug)  # differentiable augmentation of the batch
    batch = x.shape[0]

    # Patch embedding — annotated upstream as (B, 64, 384); verify.
    x = self.patches(x)

    # Broadcast the learned class token across the batch and prepend it.
    cls_tokens = self.class_embedding.expand(batch, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    x += self.positional_embedding  # learned positional encoding, in-place add
    x = self.droprate(x)

    # NOTE(review): attribute name has a typo ("Transfomer") — it is defined
    # in __init__ outside this view, so it cannot be renamed here alone.
    x = self.TransfomerEncoder(x)

    x = self.norm(x)
    # Classify using only the CLS token (index 0 along the sequence axis).
    return self.out(x[:, 0])



