栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

深度学习:可视化-结果loss acc可视化及测试数据显示

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

深度学习:可视化-结果loss acc可视化及测试数据显示

文章预览:
  • 可视化train,test的loss acc 案例:交通指示牌识别案例-history数组
  • 可视化测试结果

可视化train,test的loss acc 案例:交通指示牌识别案例-history数组

代码地址:E:项目例程猫狗分类迁移学习猫狗_resnet18_2 猫狗分类_迁移学习可视化

  1. 导入库
 from collections import defaultdict
  1. 训练函数中构建一个默认value为list的字典
history = defaultdict(list)  # 构建一个默认value为list的字典
  1. 训练函数中保存train_loss,train_acc,test_loss,test_acc结果
history['train_acc'].append(train_accuracy)
        history['train_loss'].append(train_loss)
        history['val_acc'].append(val_accuracy)
        history['val_loss'].append(val_loss)
  1. 训练函数返回
return model, history
  1. 训练模型调用时
# 调用训练函数训练
model_conv, history = train_model(
    model_conv,
    criterion,
    optimizer_conv,
    exp_lr_scheduler,
    num_epochs=30
)
  1. 绘制函数 两张图,每个图两个曲线,写法固定
# 绘制 loss, acc  写法固定:两张表
def plot_training_history(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
    ax1.plot(history['train_loss'], label='train loss')
    ax1.plot(history['val_loss'], label='val loss')

    ax1.set_ylim([-0.05, 1.05])
    ax1.legend()
    ax1.set_ylabel('Loss')
    ax1.set_xlabel('Epoch')

    ax2.plot(history['train_acc'], label='train acc')
    ax2.plot(history['val_acc'], label='val acc')

    ax2.set_ylim([-0.05, 1.05])
    ax2.legend()
    ax2.set_ylabel('Accuracy')
    ax2.set_xlabel('Epoch')

    fig.suptitle('Training History')

plot_training_history(history)

结果曲线展示:

可视化测试结果

代码地址:E:项目例程猫狗分类迁移学习猫狗_resnet18_2 猫狗分类_迁移学习可视化

  1. 定义结果可视化函数
    注意:plt.subplot(4,4,i+1) 应根据batch_size修改,本案例batch_size=16.故为4x4
# 测试结果可视化函数
def visualize_model(model):
    model.eval()
    with torch.no_grad():
        inputs, labels = next(iter(dataloaders['val']))
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        preds = outputs.argmax(1)

        plt.figure(figsize=(9, 9))
        for i in range(inputs.size(0)):
            plt.subplot(4,4,i+1)  #根据batch_size修改
            plt.axis('off')
            plt.title(f'pred: {class_names[preds[i]]}|true: {class_names[labels[i]]}')
            im = no_normalize(inputs[i].cpu())
            plt.imshow(im)
        plt.savefig('train.jpg')
        plt.show()
  1. 调用函数
# 测试结果可视化
visualize_model(model_conv)

效果展示:

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/293928.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号