- tensorboard
- tensorboard使用
- pytorch模型保存与加载
- 断点续训练
可视化工具:tensorboard。
支持标量、图像、文本、音频、视频和embedding等多种数据可视化。
运行机制:python脚本中记录可视化的数据–>将需要的数据存储在硬盘中–>tensorboard在终端读取数据进行可视化。
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(comment="test tensorboard")
for x in range(100):
writer.add_scalar('y=2x', x*2, x)
writer.close()
读取文件,在pycharm终端中进行读取
tensorboard --logdir=./runstensorboard使用
SummaryWriter( log_dir=None, comment='', filename_suffix='') # 提供创建event file的高级接口 # log_dir:event file输出文件夹 # comment:不指定log_dir时,文件夹后缀 # filename_shuffix:event file文件名后缀 # log_dir有设置值时,comment不会其作用。
- 记录标量
add_scalar( tag, scalar_value, global_step=None, walltime=None) # tag: 标签名,唯一标识 # scalar_value:要记录的标量 # global_step:x轴 add_scalars( main_tag, tag_scalar_dict, global_step=None, walltime=None) # main_tag:标签名 # tag_scalar_dict:key是变量的tag,value是变量的值
- 统计直方图与多分位数折线图
add_histogram( tag, values, global_step=None, bins='tensorflow', walltime=None) # tag:标签名 # values:要统计的参数 # global_step:y轴 # bins:取直方图的bins
- 记录图像
add_image( tag, img_tensor, global_step=None, walltime=None, dataformats='CHW') # tag:标签名 # img_tensor:图像数据,注意尺度 # global_step:x轴 # dataformats:数据形式,CHW HWC HW
- 制作网格图像
torchvision.utils.make_grid( tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0) # tensor:图像数据 B*C*H*W形式 # nrow:行数 # padding:图像间距(像素单位) # normalize:是否将像素值标准化 # range:标准化范围 # scale_each:是否单张图维度标准化 # pad_value:padding的像素值
- 可视化模型计算图
add_graph( model, input_to_model=None, verbose=Fasle) # model:模型 必须是nn.Module # input_to_model:输出给模型的数据 # verbose是否打印计算图结构信息 torchsummary( model, input_size, batch_size=-1, device="cuda") # model:pytorch模型 # input_size:模型输入size # batch_size:batch_size # device: cuda cpu 使用 from torchsummary import summarypytorch模型保存与加载
序列化与反序列化
序列化:将内存中的模型转变为二进制的数存储在硬盘中
反序列化:将存储在硬盘中的二进制数转变为模型至内存中
torch.save(obj, f) # obj:对象 f:输出路径 torch.load(f, map_location) # f:文件路径 map_location指定cpu/gpu
保存模型的两种方法
# 1 保存整个模型 torch.save(model, path) # 2 保存模型参数 字典型(键值对) state_dict = model.state_dict() torch.save(state_dict, state_dict_path) # 读取模型参数 state_dict_load = torch.load(state_dict_path) model_n.load_state_dict(state_dict_load)断点续训练



