栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch学习(六)神经网络卷积层

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch学习(六)神经网络卷积层

这里写目录标题
  • 1 数据集:CIFAR10
  • 2 简单一层卷积可视化
    • 1、代码
    • 2、运行结果
    • 3、tensorboard结果可视化

1 数据集:CIFAR10

使用CIFAR10,之前下过所以直接用。没有下载的请看此条笔记进行下载。
1、数据集所在文件夹如下

2、运行下列代码

from torch.utils.data import DataLoader
from torch import nn


#torchvision.datasets.CIFAR10参数说明

#root:数据集根目录
#train(bool,optional) - 如果为True,则创建数据集training.pt,否则创建数据集test.pt。
#download(bool,optional) - 如果为true,则从Internet下载数据集并将其放在根目录中。如果已下载数据集,则不会再次下载。
#transform(callable ,optional) - 一个函数/转换,它接收PIL图像并返回转换后的版本。例如,transforms.RandomCrop
#target_transform(callable ,optional) - 接收目标并对其进行转换的函数/转换
dataset = torchvision.datasets.CIFAR10("E:/PycharmProjects/Pytoch_learning/dataset",
                                       train=False, transform=torchvision.transforms.ToTensor(), download=False)
dataloader = DataLoader(dataset,batch_size=64)

#创建模型
class Tian(nn.Module):
    def __init__(self):  #初始化
        super(Tian, self).__init__()
        self.conv1 = Conv2d(in_channels=3, out_channels=6,kernel_size=3, stride=1,padding=0)

    def forward(self,x):
        x = self.conv1(x)
        return x

#使用模型
ren = Tian()
print(ren)

3.报错
错误:RuntimeError: Dataset not found or corrupted. You can use download=True to download it

已经下好但报错没有找到
解决:之前路径是dataset,使用CIFAR10的绝对路径即可
只用更改代码为以下方式

dataset = torchvision.datasets.CIFAR10("E:/PycharmProjects/Pytoch_learning/dataset/CIFAR10",
                                       train=False, transform=torchvision.transforms.ToTensor(), download=False)

输出

2 简单一层卷积可视化 1、代码
import torchvision
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch import nn
from torch.utils.tensorboard import SummaryWriter


#torchvision.datasets.CIFAR10参数说明

#root:数据集根目录
#train(bool,optional) - 如果为True,则创建数据集training.pt,否则创建数据集test.pt。
#download(bool,optional) - 如果为true,则从Internet下载数据集并将其放在根目录中。如果已下载数据集,则不会再次下载。
#transform(callable ,optional) - 一个函数/转换,它接收PIL图像并返回转换后的版本。例如,transforms.RandomCrop
#target_transform(callable ,optional) - 接收目标并对其进行转换的函数/转换
dataset = torchvision.datasets.CIFAR10("E:/PycharmProjects/Pytoch_learning/dataset/CIFAR10",
                                       train=False, transform=torchvision.transforms.ToTensor(), download=False)
dataloader = DataLoader(dataset,batch_size=64)

#创建模型
class Tian(nn.Module):
    def __init__(self):  #初始化
        super(Tian, self).__init__()
        self.conv1 = Conv2d(in_channels=3, out_channels=3,kernel_size=3, stride=1,padding=0)

    def forward(self,x):
        x = self.conv1(x)
        return x

#使用模型
ren = Tian()
print(ren)

#tensorboard可视化
writer = SummaryWriter(log_dir="E:/PycharmProjects/runs/flower_experiment")
step = 0
#查看数据集的每张图片
for data in dataloader:
    imgs, targets =data
    output = ren(imgs)
    print(output.shape)

    writer.add_images("day8_input", imgs, step)
    writer.add_images("day8_output", output, step)

    step += 1
2、运行结果

右键 -->run

3、tensorboard结果可视化

tensorboard详细教程tensorboard新手友好
在terminal中输入

tensorboard --logdir="E:/PycharmProjects/runs/flower_experiment"

回车后点击蓝色链接即可显示

网页显示结果

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/689284.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号