正确率卡在了91%
第一次迁移只有50%左右的正确率，修改了损失函数后有了80%左右的正确率。
import os import torch from torchvision import transforms,datasets import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable import numpy as np from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils import torch.optim as optim from PIL import Image device torch.device( cuda:0 if torch.cuda.is_available() else cpu ) print( Using gpu: %s % torch.cuda.is_available())
数据预处理
# Preprocessing applied to every image:
#   1. Resize so the shorter side is 84 px (aspect ratio preserved).
#   2. Centre-crop to exactly 84x84. The CNN's linear layer is sized for this
#      (84 -> 42 -> 21 after two 2x max-pools).
#   3. Convert to a float tensor in [0, 1] with channels first.
#   4. Normalise each RGB channel with the commonly used ImageNet mean/std.
data_transform = transforms.Compose([
    transforms.Resize(84),
    transforms.CenterCrop(84),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
获取数据：先把数据的路径、名称、标签都写入txt文件，然后通过txt文件获取这些信息。
class MyDataSet(Dataset):
    """Image-classification dataset driven by a plain-text index file.

    Each non-empty line of the index file holds an image path and an integer
    label separated by a single space, e.g. ``D:/data/cat.0.jpg 0``.

    Args:
        txtPath: Path of the index file.
        data_transform: Optional callable applied to the PIL image. When it
            is falsy, the raw pixels are returned as a uint8 tensor instead.
    """

    def __init__(self, txtPath, data_transform):
        self.imgPathArr = []  # image file paths, one per sample
        self.labelArr = []    # labels kept as strings; parsed to int lazily
        # NOTE(review): lines are decoded with 'unicode_escape'. That turns
        # backslash sequences such as '\t' into control characters and reads
        # raw non-ASCII bytes as Latin-1. Windows backslash paths, or UTF-8/GBK
        # Chinese paths written verbatim, will therefore be corrupted. Confirm
        # how the index file was written (forward slashes / escaped text).
        with open(txtPath, 'rb') as f:
            for raw_line in f:
                line = str(raw_line.strip(), encoding='unicode_escape')
                if not line:
                    # Skip blank lines (e.g. a trailing newline at EOF) that
                    # previously raised IndexError on fields[1].
                    continue
                fields = line.split(' ')
                self.imgPathArr.append(fields[0])
                self.labelArr.append(fields[1])
        self.transforms = data_transform

    def __getitem__(self, index):
        """Return ``(data, label)`` for sample *index*; label is a 0-d int array."""
        label = np.array(int(self.labelArr[index]))
        img_path = self.imgPathArr[index]
        # Close the file handle promptly. Also force 3 channels: grayscale,
        # RGBA or palette images would otherwise break the 3-channel
        # Normalize and the stacking of a batch.
        with Image.open(img_path) as raw_img:
            pil_img = raw_img.convert('RGB')
        if self.transforms:
            data = self.transforms(pil_img)
        else:
            # np.array copies into a writable buffer, which avoids the
            # non-writable-array warning from torch.from_numpy.
            data = torch.from_numpy(np.array(pil_img))
        return data, label

    def __len__(self):
        return len(self.imgPathArr)
一个简单的CNN网络
class Net(nn.Module):
    """Small two-stage CNN classifier for 3x84x84 inputs.

    Args:
        num_classes: Number of output logits (default 2, the original
            cat/dog setting).
    """

    def __init__(self, num_classes=2):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
            # 3 -> 16 channels, 5x5 kernel, stride 1, padding 2 keeps H and W.
            nn.Conv2d(3, 16, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 84x84 -> 42x42
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),   # 16 -> 32 channels, size preserved
            nn.ReLU(),
            nn.MaxPool2d(2),              # 42x42 -> 21x21
        )
        # The flattened feature size is only correct for 84x84 inputs.
        self.out = nn.Linear(32 * 21 * 21, num_classes)

    def forward(self, x):
        """Return ``(logits, flattened_features)`` for a batch *x*."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32*21*21)
        output = self.out(x)
        return output, x
if __name__ == '__main__':
    # The __main__ guard is required on Windows when num_workers > 0, because
    # DataLoader worker processes re-import this module.
    train_dataset = MyDataSet('D:/猫狗大战数据集/cat_dog/train.txt', data_transform)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
    # NOTE(review): the test index is named 'text.txt'. Confirm this is the
    # real filename and not a typo for 'test.txt'.
    test_dataset = MyDataSet('D:/猫狗大战数据集/cat_dog/text.txt', data_transform)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=4)



