pytorch中计算网络模型的参数量和Flops@TOC
#1、参数量的计算
(1)from thop import profile
if __name__ == '__main__':
import torch
from thop import profile
model = Net()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, (input,))
print('flops: ', flops, 'params: ', params)#直接print输出的是个数
# 转换后的M表示百万个,不是存储单位,G表示每秒10亿 (=10^9) 次的浮点运算,
print('》》》》》》》》》》》Flops:', str(flops / 1000 ** 3) + 'G')
print('》》》》》》》》》》》Params:', str(params / 1000 ** 2) + 'M')
Net()是自己搭建的网络模型
(2)from torchstat import stat
from torchstat import stat
if __name__ == '__main__':
# 导入模型,输入一张输入图片的尺寸
model = Net()
stat(model, (3, 224, 224))
Net()是自己搭建的网络模型
(3)from torchinfo import summary
from torchinfo import summary
if __name__ == '__main__':
# 导入模型,输入一张输入图片的尺寸
batch_size = 1
model = Net()
summary(model, input_size=(batch_size, 3, 224, 224))
Net()是自己搭建的网络模型



