栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch 计算模型的FLOPs和参数量

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch 计算模型的FLOPs和参数量

安装:pip install ptflops

单独使用:

import torch
from ptflops import get_model_complexity_info
flops, params = get_model_complexity_info(model, [1,32,32], as_strings=True, print_per_layer_stat=True)
print(flops, params)

批量处理:

import torch, os
from ptflops import get_model_complexity_info


class Cal_Params():
    def __init__(self, model_name, device='cuda'):
        self.model_name = model_name
        self.path = r'models/{}'.format(model_name)
        self.model = get_model(self.path).to(torch.device(device))
        self.input_size = (1, self.model.size, self.model.size)

    def get_params(self, save_file, verbose=True):
        filepath = os.path.join(self.path, 'params.txt')
        f = open(filepath, 'w')
        flops, params = get_model_complexity_info(self.model, self.input_size, as_strings=True,
                                                  print_per_layer_stat=True, ost=f)
        display('%9s | %11s | %9s' % (self.model_name, flops, params), file=save_file, verbose=verbose)


def display(string, file=None, verbose=True):
    if file != None:
        print(string, file=file)
    if verbose:
        print(string)
        # devnull = open(os.devnull, 'w')
        # print(string, file=devnull)
        # devnull.close()


if __name__ == '__main__':
    save_file = open('all_params.txt', 'w')
    model_names = ['model1', 'model2', 'model3']
    losses = ['L1', 'L2']
    display('%9s | %11s | %9s' % ('Model', 'FLOPs', 'Params'), file=save_file)
    try:
        for model_name in model_names:
            cp = Cal_Params(model_name)
            cp.get_params(save_file)
    except Exception as e:
        print('Error: {}'.format(e))
    finally:
        save_file.close()

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/348474.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号