使用方法很简单：只需要在 utils 目录下的 general.py 或 plots.py 中添加如下函数（该函数用到 os、math、torch，请确认文件开头已导入）：
import math
import os

import matplotlib.pyplot as plt
import torch
from torchvision import transforms
def feature_visualization(features, model_type, model_id, feature_num=64):
    """Save up to ``feature_num`` channel maps of ``features`` as one PNG grid.

    The image is written to
    ``features/<type>_<model_id>_feature_map_<feature_num>.png``.

    Args:
        features: Feature-map tensor; assumed (batch, channels, height, width)
            -- TODO confirm for non-standard layers. Only the first image of
            the batch is plotted.
        model_type: Dotted layer type such as ``'models.common.SPP'``; its
            last dotted component goes into the file name.
        model_id: Layer index, used in the file name.
        feature_num: Maximum number of channels to plot. It is capped at the
            number of channels actually present in ``features``.
    """
    save_dir = "features/"
    os.makedirs(save_dir, exist_ok=True)

    # Cap at the channel count; asking for more channels than exist used to
    # raise IndexError.
    n = max(0, min(feature_num, features.shape[1]))
    # Smallest square grid holding all n maps. int(sqrt(n)) is too small for
    # non-square n (e.g. 10 -> 3x3) and add_subplot would then raise.
    side = math.ceil(math.sqrt(n))

    fig = plt.figure()
    try:
        for i in range(n):
            # Channel i of the first image; detach/cpu so CUDA tensors and
            # tensors that require grad can be converted to numpy.
            fmap = features[0, i].detach().cpu().numpy()
            ax = fig.add_subplot(side, side, i + 1)
            ax.set_xticks([])
            ax.set_yticks([])
            # imshow scales the colormap to this map's own min/max. The old
            # ToPILImage path multiplied by 255 and cast to uint8, which
            # corrupts activations outside [0, 1].
            ax.imshow(fmap)
            # For gray maps instead: ax.imshow(fmap, cmap='gray')
        file_name = '{}_{}_feature_map_{}.png'.format(
            model_type.split('.')[-1], model_id, feature_num)
        fig.savefig(os.path.join(save_dir, file_name), dpi=300)
    finally:
        # Release the figure; otherwise one stays open per call.
        plt.close(fig)
接着把 models 目录下 yolo.py 中已有的 forward_once 方法替换为如下代码（注意方法末尾必须保留 return x，否则模型前向输出为 None）：
def forward_once(self, x, profile=False, feature_vis=False):
    """Run ``x`` through every layer in ``self.model`` and return the result.

    Args:
        x: Input fed to the first layer.
        profile: If True, print a profiling line for each layer (FLOPs via
            ``thop``, ``m.np``, average time of 10 runs) and the total time.
        feature_vis: If True, save feature maps of the selected layers with
            ``feature_visualization`` (import it from ``utils.general``).

    Returns:
        The output of the last layer.
    """
    y, dt = [], []  # y: per-layer outputs (None if not needed); dt: timings
    for m in self.model:
        if m.f != -1:  # input is not (only) the previous layer's output
            if isinstance(m.f, int):
                x = y[m.f]
            else:  # several sources; -1 means the previous layer's output
                x = [x if j == -1 else y[j] for j in m.f]
        if profile:
            # FLOPs in units of 1e9
            o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0
            t = time_synchronized()
            for _ in range(10):
                _ = m(x)
            # presumably seconds; total * 100 == ms per run over the 10 runs
            dt.append((time_synchronized() - t) * 100)
            print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
        x = m(x)  # run
        y.append(x if m.i in self.save else None)  # keep only outputs listed in self.save
        # Feature-map visualization hook. To visualize other layers, change
        # the condition, e.g. ``m.i in (17, 20, 23)``.
        if m.type == 'models.common.SPP' and feature_vis:
            print(m.type, m.i)
            feature_visualization(x, m.type, m.i)
    if profile:
        print('%.1fms total' % sum(dt))
    return x
还要在 yolo.py 的开头加入下面这行导入语句：
from utils.general import feature_visualization
如果想查看其他结构的特征图，可以修改 forward_once 中决定是否可视化的那一行 if 条件。
比如要输出最后三个检测层的特征图（即配置文件中的第 17、20、23 层），
可以把代码中的 if m.type == 'models.common.SPP' and feature_vis: 替换为 if m.i in (17, 20, 23) and feature_vis:（m.i 是整数层号，不能写成字符串 m.i=='17, 20, 23'，否则条件永远不成立）。
本人仍在学习中，此笔记仅作学习记录之用，如有侵权，请联系我删除。



