栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

用einops直观任性操作Tensor,解决Patch Embedding问题

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

用einops直观任性操作Tensor,解决Patch Embedding问题

首先看一下原图:

是一张jpg格式,512x512分辨率的图像,解码为RGB格式时,shape为[3, 512, 512]

导入einops相关函数

from einops import rearrange, reduce, repeat

常用的就是这三个了,文末有官方教程地址,可全面学习,解决Transform中的第一步的Patch Embedding,rearrange(重新排列,重新整理)就足够了

先增加一个b维度

img = rearrange(img, 'c h w -> 1 c h w')
# print(img.shape) # torch.Size([1, 3, 512, 512])

img就是上图,'c h w’对应你数据最开始的shape,'1 c h w’对应你想要的shape,增加一个维度的话,直接在前面加个1,完事

开始分割成Patch并重新排列

img = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=256, p2=256)
# print(img.shape) # torch.Size([1, 4, 196608])

原来的’b c h w’,现在要在’h w’维度开始分割patch,p1,p2就是patch的大小,因为我原图为512x512,这里简单起见,我就直接分成4块,那每一块就是256x256,这个在后面p1=256,p2=256指定,指定后,它会自动计算h w变为多少,这里肯定是2了,最后转换的shape,因为要送到神经网络中去,shape应该为[b, d1, d2],一个三维的tensor,这里就把分割好的h w重新放到括号里,p1,p2,c作为patch,按照经典思路,打平即可

这样就获得了我们想要的结果,shape:[1, 4, patch打平后的数值],4是patch个数

这样就解决了当分辨率增大时,采用通道数不变,将图像空间所有打平,数据太大的问题

下面是可视化代码,想看看它是不是按照我们的想法分割为patch了,直接放全部代码和结果了

# 接上面操作,看shape变化
img = reduce(img, 'b n p2 -> n p2', 'mean')
# print(img.shape) # torch.Size([4, 196608])

img = rearrange(img, 'n (p1 p2 c) -> n p1 p2 c', p1=256, p2= 256, c=3)
# print(img.shape) # torch.Size([4, 256, 256, 3])

# 画图,因为我是把图像ndarray转成tensor操作的,
# 所以要显示,就把它转换回numpy
fig, ax = plt.subplots(2, 2, figsize=(5, 5))
ax[0][0].imshow(img[0].data.numpy(), cmap='gray')
ax[0][1].imshow(img[1].data.numpy(), cmap='gray')
ax[1][0].imshow(img[2].data.numpy(), cmap='gray')
ax[1][1].imshow(img[3].data.numpy(), cmap='gray')
plt.show()

结果如下:

官方教程:
官方教程

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/655564.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号