栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

第四周作业:卷积神经网络(Part2)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

第四周作业:卷积神经网络(Part2)

vgg结果

正确率卡在了91%

resnet结果

第一次迁移只有50左右的正确率 修改了损失函数有了80左右的正确率。

lenet结果

自己的cnn网络
import os
import torch
from torchvision import transforms,datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
from PIL import Image
device torch.device( cuda:0 if torch.cuda.is_available() else cpu )
print( Using gpu: %s % torch.cuda.is_available())

数据预处理

data_transform transforms.Compose([
 transforms.Resize(84),
 transforms.CenterCrop(84),
 transforms.ToTensor(),
 transforms.Normalize(mean [0.485, 0.456, 0.406],std [0.229, 0.224, 0.225])

获取数据 先把数据的路径 名称 标签都写入txt 然后通过txt获得信息

class MyDataSet(Dataset):
 def __init__(self, txtPath, data_transform):
 self.imgPathArr []
 self.labelArr []
 with open(txtPath, rb ) as f:
 txtArr f.readlines()
 for i in txtArr:
 fileArr str(i.strip(), encoding unicode_escape ).split( )
 self.imgPathArr.append(fileArr[0])
 self.labelArr.append(fileArr[1])
 self.transforms data_transform
 def __getitem__(self, index):
 label np.array(int(self.labelArr[index]))
 img_path self.imgPathArr[index]
 pil_img Image.open(img_path)
 if self.transforms:
 data self.transforms(pil_img)
 else:
 pil_img np.asarray(pil_img)
 data torch.from_numpy(pil_img)
 return data, label
 def __len__(self):
 return len(self.imgPathArr)

一个简单的cnn网络

class Net(nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 nn.Sequential( 
 nn.Conv2d(3, 16, 5, 1, 2), 
 nn.ReLU(), 
 nn.MaxPool2d(kernel_size 2), 
 self.conv2 nn.Sequential( 
 nn.Conv2d(16, 32, 5, 1, 2), 
 nn.ReLU(), 
 nn.MaxPool2d(2), 
 self.out nn.Linear(32 * 21 * 21, 2) 
 def forward(self, x):
 x self.conv1(x)
 x self.conv2(x)
 x x.view(x.size(0), -1) 
 output self.out(x)
 return output, x 
if __name__ __main__ :
 train_dataset MyDataSet( D:/猫狗大战数据集/cat_dog/train.txt , data_transform) 
 train_loader torch.utils.data.DataLoader(train_dataset,batch_size 4,shuffle True,num_workers 4)
 test_dataset MyDataSet( D:/猫狗大战数据集/cat_dog/text.txt , data_transform) 
 test_loader torch.utils.data.DataLoader(test_dataset,batch_size 1,shuffle True,num_workers 4)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/268009.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号