栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

特征图可视化:可解释的深度学习模型(Pytorch)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

特征图可视化:可解释的深度学习模型(Pytorch)


定义钩子函数

import torchvision.utils as vutil
import cv2
def hook_func(module, input, output):
    """
    Hook function of register_forward_hook

    Parameters:
    -----------
    module: module of neural network
    input: input of module
    output: output of module
    """
    image_name = get_image_name_for_hook(module)
    data = output.clone().detach()
    # data = data.permute(1, 0, 2, 3)
    # vutil.save_image(data, image_name, pad_value=0.5) # 这保存的是每个通道捕捉的语义

    data = data.permute(1,0,2,3).cpu().squeeze()
    pic = (np.mean(data.numpy(),axis=0)*255).astype(np.uint8)
    feature=cv2.resize(pic,(512,512))
    # 根据图像的像素值中最大最小值,将特征图的像素值归一化到了[0,1];
    feature = (feature - np.amin(feature))/(np.amax(feature) - np.amin(feature) + 1e-5) # 注意要防止分母为0! 
    feature = np.round(feature * 255) # [0, 1]——[0, 255],为cv2.imwrite()函数而进行
    feature = cv2.applyColorMap(np.array(feature,np.uint8),2) # 给特征图个颜色  热力图
    cv2.imwrite(image_name,feature)

INSTANCE_FOLDER = "VIS_results"
def get_image_name_for_hook(module):
    """
    Generate image filename for hook function

    Parameters:
    -----------
    module: module of neural network
    """
    os.makedirs(INSTANCE_FOLDER, exist_ok=True)
    base_name = str(module).split('(')[0]
    index = 0
    image_name = '.'  # '.' is surely exist, to make first loop condition True
    while os.path.exists(image_name):
        index += 1
        image_name = os.path.join(
            INSTANCE_FOLDER, '%s_%d.png' % (base_name, index))
    return image_name

在验证处嵌入如下定义

	with torch.no_grad():
        # modules_for_plot = (torch.nn.ReLU, torch.nn.Conv2d,
        #                 torch.nn.MaxPool2d, torch.nn.AdaptiveAvgPool2d)
        names_for_plot = ('module.classifier.fusion','module.classifier.context','module.classifier.context.2','module.classifier.context.2.aspp')
        for name, module in model.named_modules():
            # if isinstance(module, modules_for_plot):
            if name in names_for_plot:
                module.register_forward_hook(hook_func)

        for i, (images, labels) in tqdm(enumerate(loader)):
            if i>=20:
                break

部分参照:https://blog.csdn.net/bby1987/article/details/109590108

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/740090.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号