栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【无标题】

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【无标题】

pytorch中计算网络模型的参数量和Flops@TOC

#1、参数量的计算

(1)from thop import profile

if __name__ == '__main__':
    import torch
    from thop import profile
    model = Net()
    input = torch.randn(1, 3, 224, 224)
    flops, params = profile(model, (input,))
    print('flops: ', flops, 'params: ', params)#直接print输出的是个数
    # 转换后的M表示百万个,不是存储单位,G表示每秒10亿 (=10^9) 次的浮点运算,
    print('》》》》》》》》》》》Flops:', str(flops / 1000 ** 3) + 'G')
    print('》》》》》》》》》》》Params:', str(params / 1000 ** 2) + 'M')

Net()是自己搭建的网络模型
(2)from torchstat import stat

from torchstat import stat
if __name__ == '__main__':
# 导入模型,输入一张输入图片的尺寸
    model = Net()
    stat(model, (3, 224, 224))

Net()是自己搭建的网络模型
(3)from torchinfo import summary

from torchinfo import summary
if __name__ == '__main__':
# 导入模型,输入一张输入图片的尺寸
    batch_size = 1
    model = Net()
    summary(model, input_size=(batch_size, 3, 224, 224))

Net()是自己搭建的网络模型

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/879385.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号