栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

34 - Swin-Transformer论文精讲及其PyTorch逐行复现

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

34 - Swin-Transformer论文精讲及其PyTorch逐行复现

文章目录
  • 1. 两种方法实现Patch_Embedding
  • 2. 多头自注意力(Multi_Head_Self_Attention)

1. 两种方法实现Patch_Embedding
import torch
from torch.nn import functional as F



# method_1 : using unfold to achieve the patch_embedding
# step_1: unfold the image
# step_2: unfold_output@weight
def image2embed_naive(image, patch_size, weight):
	"""
	:param image: [bs,in_channel,height,width]
	:param patch_size:
	:param weight : weight.shape=[patch_depth=in_channel*patch_size*patch_size,model_dim_C]
	:return: patch_embedding,it shape is [batch_size,num_patches,model_dim_C]
	"""

	# patch_depth = in_channel*patch_size*patch_size
	# image_output.shape = [batch_size,num_patch,patch_depth=in_channel*patch_size*patch_size]
	image_output = F.unfold(image, kernel_size=(patch_size, patch_size),
							stride=(patch_size, patch_size)).transpose(-1, -2)

	# change the final_channel dimension from patch_depth to model_dim_C
	patch_embedding = image_output @ weight

	return patch_embedding



# using F.conv2d to achieve the patch_embedding
def image2conv(image, weight, patch_size):
	# image =[batch_size,in_channel,height,width]
	# weight = [out_channels,in_channels,kernel_h,kernel_w]
	conv_output = F.conv2d(image, weight=weight, stride=patch_size)
	bs, oc, oh, ow = conv_output.shape
	patch_embedding = conv_output.reshape(bs, oc, oh * ow).transpose(-1,-2)

	return patch_embedding


batch_size = 1
in_channel = 2
out_channel = 5
height = 3
width = 4
input = torch.randn(batch_size, in_channel, height, width)

patch_size = 2

weight1_depth = in_channel * patch_size * patch_size

weight1_model_c = out_channel

weight1 = torch.randn(weight1_depth,weight1_model_c)

weight2_out_channel = weight1_model_c


weight2 = weight1.transpose(0,1).reshape(weight1_model_c,in_channel,patch_size,patch_size)

output1 = image2embed_naive(input, patch_size, weight1)

output2 = image2conv(input, weight2, patch_size)


# flag the check output1 is the same for output2
# if flag is true ,they are the same
flag = torch.isclose(output1,output2)
print(f"flag={flag}")
print(f"output1={output1}")
print(f"output2={output2}")
print(f"output1.shape={output1.shape}")
print(f"output2.shape={output2.shape}")
2. 多头自注意力(Multi_Head_Self_Attention)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/883054.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号