栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

2021SC@SDUSC山东大学软件学院软件工程应用与实践--YOLOV5代码分析(十)plots.py-2

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

2021SC@SDUSC山东大学软件学院软件工程应用与实践--YOLOV5代码分析(十)plots.py-2

2021SC@SDUSC

目录

前言

plot_lr_scheduler函数

 plot_val_txt函数

plot_targets_txt函数

plot_val_study函数

plot_labels函数

profile_idetection函数

plot_evolve函数

plot_results函数

feature_visualization函数

总结


前言

这篇继续分析plots.py的代码

plot_lr_scheduler函数
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
    # Plot LR simulating training for full epochs
    optimizer, scheduler = copy(optimizer), copy(scheduler)  # do not modify originals
    y = []
    for _ in range(epochs):
        scheduler.step()
        y.append(optimizer.param_groups[0]['lr'])
    plt.plot(y, '.-', label='LR')
    plt.xlabel('epoch')
    plt.ylabel('LR')
    plt.grid()
    plt.xlim(0, epochs)
    plt.ylim(0)
    plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
    plt.close()

optimizer:优化器

scheduler:学习率调整器

epochs:训练迭代次数

save_dir:保存路径

该函数将模拟整个训练过程,记录下每个epoch的学习率,并画在图上并保存到save_dir

copy函数将optimizer和scheduler复制一份,防止改变原来的值

接下来遍历epoch,scheduler执行一步,调整一下学习率,并将每个epoch的学习率添加到列表末尾。

最后将学习率的调整过程画在图上并保存下来。

import torchvision
model=torchvision.models.resnet50(pretrained=False)
opt=torch.optim.Adam(params=model.parameters(),lr=0.01)
scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(opt,100)
plot_lr_scheduler(opt,scheduler,100,'.')

执行以上代码后将会在当前目录下得到一张LR.png的图片

 可以看到我们得到了100个epoch的以cos退火的学习率变化过程的图片。

 plot_val_txt函数
def plot_val_txt():  # from utils.plots import *; plot_val()
    # Plot val.txt histograms
    x = np.loadtxt('val.txt', dtype=np.float32)
    box = xyxy2xywh(x[:, :4])
    cx, cy = box[:, 0], box[:, 1]
 
    fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
    ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
    ax.set_aspect('equal')
    plt.savefig('hist2d.png', dpi=300)
 
    fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
    ax[0].hist(cx, bins=600)
    ax[1].hist(cy, bins=600)
    plt.savefig('hist1d.png', dpi=200)

该函数比较简单,就是将val.txt这个文本文件读取成numpy格式,val.txt原本保存的是验证时的输出结果,数据是(x1,y1,x2,y2,p,c),读出前四个,也就是预测框的位置,将其转换为xywh格式,xy是预测框中心的坐标,cx和cy分别是预测框中心点的横坐标和纵坐标。

接下来分别以1维和2维的形式画出直方图并保存下来。

由于项目没有val.txt文件,我的电脑配置也跑不了这个项目,因此就无法运行一下将图片打开来看了。

plot_targets_txt函数
def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
    # Plot targets.txt histograms
    x = np.loadtxt('targets.txt', dtype=np.float32).T
    s = ['x targets', 'y targets', 'width targets', 'height targets']
    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    ax = ax.ravel()
    for i in range(4):
        ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
        ax[i].legend()
        ax[i].set_title(s[i])
    plt.savefig('targets.jpg', dpi=200)

这个函数与上面一个基本无差,不过这个函数是对目标框进行可视化,读取targets.txt为numpy数据,分别对x、y、w、h画出它们的直方图,最后保存下来。

plot_val_study函数
def plot_val_study(file='', dir='', x=None):  # from utils.plots import *; plot_val_study()
    # Plot file=study.txt generated by val.py (or plot all study*.txt in dir)
    save_dir = Path(file).parent if file else Path(dir)
    plot2 = False  # plot additional results
    if plot2:
        ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()
 
    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
    # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
    for f in sorted(save_dir.glob('study*.txt')):
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
        x = np.arange(y.shape[1]) if x is None else np.array(x)
        if plot2:
            s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
            for i in range(7):
                ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
                ax[i].set_title(s[i])
 
        j = y[3].argmax() + 1
        ax2.plot(y[5, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
                 label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
 
    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
             'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
 
    ax2.grid(alpha=0.2)
    ax2.set_yticks(np.arange(20, 60, 5))
    ax2.set_xlim(0, 57)
    ax2.set_ylim(30, 55)
    ax2.set_xlabel('GPU Speed (ms/img)')
    ax2.set_ylabel('COCO AP val')
    ax2.legend(loc='lower right')
    f = save_dir / 'study.png'
    print(f'Saving {f}...')
    plt.savefig(f, dpi=300)

 file:指定要可视化的文件,如果为空则可视化dir目录下的所有study*.txt文件

dir:如上,file为空时可视化当前目录下所有的study*.txt文件

x:图表的横坐标,为空时则为y长度的列表

 save_dir需要可视化的文件根目录,同时也是可视化后的图片的保存路径

接下来遍历save_dir目录下的所有study*.txt的文件,y为从文件读取出来的nunpy数据,x为参数x,若参数x为空则x为y的长度的一个列表

这里的plot2我想是作者在写代码时的测试,因为给写死了为False,一些代码根本不会执行到,应该是作者在测试的时候手动更改的。如果plot2为True的话会可视化一些评估指标。

接下来的操作就是在将数据可视化,画在图上,并保存下来,没什么好讲的,这里执行以下代码会更清楚,但由于电脑配置只能放弃。

plot_labels函数
def plot_labels(labels, names=(), save_dir=Path('')):
    # plot dataset labels
    print('Plotting labels... ')
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
    nc = int(c.max() + 1)  # number of classes
    x = pd.Dataframe(b.transpose(), columns=['x', 'y', 'width', 'height'])
 
    # seaborn correlogram
    sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
    plt.close()
 
    # matplotlib labels
    matplotlib.use('svg')  # faster
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
    y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # update colors bug #3195
    ax[0].set_ylabel('instances')
    if 0 < len(names) < 30:
        ax[0].set_xticks(range(len(names)))
        ax[0].set_xticklabels(names, rotation=90, fontsize=10)
    else:
        ax[0].set_xlabel('classes')
    sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
    sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
 
    # rectangles
    labels[:, 1:3] = 0.5  # center
    labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
    img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
    for cls, *box in labels[:1000]:
        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # plot
    ax[1].imshow(img)
    ax[1].axis('off')
 
    for a in [0, 1, 2, 3]:
        for s in ['top', 'right', 'left', 'bottom']:
            ax[a].spines[s].set_visible(False)
 
    plt.savefig(save_dir / 'labels.jpg', dpi=200)
    matplotlib.use('Agg')
    plt.close()

labels:数据集的标签

names:类别名

save_dir:保存路径

这个函数将数据集中的标签可视化,方便观察数据集的标签

c和b分别为类别和目标框

nc为类别的数量

x为目标框,将其转换成pandas的数据帧

接下来用seaborn.pairplot对x进行可视化,该函数用于展示变量两两之间的关系,线性或非线性,具体可参考Python可视化 | Seaborn5分钟入门(七)——pairplot - 知乎

接下来分别画出中心点分布、宽高分布,其中横坐标为类别,当类别数少于30时显示类别的名称。

最后就是画出目标框的分布。

总的来说是对数据集中的label进行可视化,分别对目标框的中心点、宽高、类别等进行可视化,便于观察它们的分布。

profile_idetection函数
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
    # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
    ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
    s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
    files = list(Path(save_dir).glob('frames*.txt'))
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, ndmin=2).T[:, 90:-30]  # clip first and last rows
            n = results.shape[1]  # number of rows
            x = np.arange(start, min(stop, n) if stop else n)
            results = results[:, x]
            t = (results[0] - results[0].min())  # set t0=0s
            results[0] = x
            for i, a in enumerate(ax):
                if i < len(results):
                    label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
                    a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
                    a.set_title(s[i])
                    a.set_xlabel('time (s)')
                    # if fi == len(files) - 1:
                    #     a.set_ylim(bottom=0)
                    for side in ['top', 'right']:
                        a.spines[side].set_visible(False)
                else:
                    a.remove()
        except Exception as e:
            print('Warning: Plotting error for %s; %s' % (f, e))
    ax[1].legend()
    plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)

 start:开始时间

stop:结束时间

labels:标签

save_dir:保存路径

 将每张图片输出的*.txt进行可视化

files为给定目录下的所有文件

对每个文件先加载出相应的数据,对数据进行数值大小限定处理,设定x轴为时间,y轴为文件读取出来的数据

可视化后进行保存

plot_evolve函数
def plot_evolve(evolve_csv='path/to/evolve.csv'):  # from utils.plots import *; plot_evolve()
    # Plot evolve.csv hyp evolution results
    evolve_csv = Path(evolve_csv)
    data = pd.read_csv(evolve_csv)
    keys = [x.strip() for x in data.columns]
    x = data.values
    f = fitness(x)
    j = np.argmax(f)  # max fitness index
    plt.figure(figsize=(10, 12), tight_layout=True)
    matplotlib.rc('font', **{'size': 8})
    for i, k in enumerate(keys[7:]):
        v = x[:, 7 + i]
        mu = v[j]  # best single result
        plt.subplot(6, 5, i + 1)
        plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
        plt.plot(mu, f.max(), 'k+', markersize=15)
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})  # limit to 40 characters
        if i % 5 != 0:
            plt.yticks([])
        print('%15s: %.3g' % (k, mu))
    f = evolve_csv.with_suffix('.png')  # filename
    plt.savefig(f, dpi=200)
    plt.close()
    print(f'Saved {f}')

该函数对evolve.txt文件进行可视化,其保存的是算法学习时的超参数进化,画出其分布

x从7开始是因为前面保存的是模型的评估指标,后面才是超参数

其实就是把在整个训练过程中的超参数变化画出来,保存为图片方便观察

plot_results函数
def plot_results(file='path/to/results.csv', dir=''):
    # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
    save_dir = Path(file).parent if file else Path(dir)
    fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
    ax = ax.ravel()
    files = list(save_dir.glob('results*.csv'))
    assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
    for fi, f in enumerate(files):
        try:
            data = pd.read_csv(f)
            s = [x.strip() for x in data.columns]
            x = data.values[:, 0]
            for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
                y = data.values[:, j]
                # y[y == 0] = np.nan  # don't show zero values
                ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
                ax[i].set_title(s[j], fontsize=12)
                # if j in [8, 9, 10]:  # share train and val loss y axes
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
        except Exception as e:
            print(f'Warning: Plotting error for {f}: {e}')
    ax[1].legend()
    fig.savefig(save_dir / 'results.png', dpi=200)
    plt.close()

对result文件进行可视化,一些操作与前面的函数基本没有区别

feature_visualization函数
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
    """
    x:              Features to be visualized
    module_type:    Module type
    stage:          Module stage within model
    n:              Maximum number of feature maps to plot
    save_dir:       Directory to save results
    """
    if 'Detect' not in module_type:
        batch, channels, height, width = x.shape  # batch, channels, height, width
        if height > 1 and width > 1:
            f = f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # filename
 
            blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0, block by channels
            n = min(n, channels)  # number of plots
            fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # 8 rows x n/8 cols
            ax = ax.ravel()
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            for i in range(n):
                ax[i].imshow(blocks[i].squeeze())  # cmap='gray'
                ax[i].axis('off')
 
            print(f'Saving {save_dir / f}... ({n}/{channels})')
            plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')
            plt.close()

可视化每层网络的输出

x:网络中间层的输出

module_type:模块类型

stage:模块的层次

n:最多可视化多少个channel

save_dir:保存路径

 torch.chunk将x[0], 在channel维度,分割成channels个块,返回一个tuple 

接下来绘制每一个通道的特征图

最后保存下来。

总结

这部分代码比较繁杂,都是一些可视化的内容,便于对模型的效果进行评估,然而由于我没有相关的文件,不能运行代码,有些地方的理解不是很清楚,不过这部分也不是什么重要的代码,大概知道是在做什么就可以了,不用扣的太细。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/444513.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号