栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

使用卷积神经网络进行图像分类211015

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

使用卷积神经网络进行图像分类211015

import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import ToTensor
import numpy as np
import matplotlib.pyplot as plt

# A single transform instance is shared by every image of both dataset splits.
transform = ToTensor()

# Build the two CIFAR-10 splits in order: training split first, then the test split.
cifar10_train, cifar10_test = (
    paddle.vision.datasets.Cifar10(mode=split, transform=transform)
    for split in ('train', 'test')
)


class MyNet(paddle.nn.Layer):
    """Small CNN: three 3x3 convolutions (two followed by max-pooling) and a 2-layer head.

    Args:
        num_classes: number of logits produced per sample by ``forward``.

    ``linear1`` is hard-wired to 1024 input features. With Conv2D's default
    stride 1 / no padding and the 2x2 stride-2 pools, a 3x32x32 input shrinks
    32 -> 30 -> 15 -> 13 -> 6 -> 4, giving 64 * 4 * 4 = 1024 flattened features.
    Other input sizes will not match (assumes CIFAR-sized images — confirm for
    any other dataset).
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # Stage 1: 3 -> 32 channels, followed by a 2x downsample.
        self.conv1 = paddle.nn.Conv2D(in_channels=3, out_channels=32, kernel_size=(3, 3))
        self.pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)

        # Stage 2: 32 -> 64 channels, followed by a 2x downsample.
        self.conv2 = paddle.nn.Conv2D(in_channels=32, out_channels=64, kernel_size=(3, 3))
        self.pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)

        # Stage 3: 64 -> 64 channels, no pooling.
        self.conv3 = paddle.nn.Conv2D(in_channels=64, out_channels=64, kernel_size=(3, 3))

        self.flatten = paddle.nn.Flatten()

        # Classifier head: 1024 flattened features -> 64 hidden units -> class logits.
        self.linear1 = paddle.nn.Linear(in_features=1024, out_features=64)
        self.linear2 = paddle.nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised class logits for the image batch ``x``."""
        # The first two stages share the same conv -> relu -> pool pattern.
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            x = pool(F.relu(conv(x)))

        features = self.flatten(F.relu(self.conv3(x)))
        hidden = F.relu(self.linear1(features))
        return self.linear2(hidden)

# Hyper-parameters consumed by train().
epoch_num = 10           # number of full passes over the training split
batch_size = 32          # mini-batch size for both the training and validation loaders
learning_rate = 1e-3     # step size handed to the Adam optimizer

# Per-epoch validation metrics: train() appends to these and the plot reads them.
val_acc_history, val_loss_history = [], []

def train(model):
    """Train ``model`` on ``cifar10_train`` with Adam, validating after every epoch.

    Runs ``epoch_num`` epochs of shuffled mini-batches of ``batch_size``. The
    training loss is printed every 1000 batches. After each epoch the model is
    evaluated on ``cifar10_test``. The sample-weighted mean accuracy and loss
    are printed and appended to the module-level ``val_acc_history`` and
    ``val_loss_history`` lists.

    Args:
        model: a ``paddle.nn.Layer`` mapping image batches to class logits.
            It is updated in place and left in training mode on return.
    """
    print('start training ... ')
    # turn into training mode
    model.train()

    opt = paddle.optimizer.Adam(learning_rate=learning_rate,
                                parameters=model.parameters())

    train_loader = paddle.io.DataLoader(cifar10_train,
                                        shuffle=True,
                                        batch_size=batch_size)
    valid_loader = paddle.io.DataLoader(cifar10_test, batch_size=batch_size)

    for epoch in range(epoch_num):
        for batch_id, data in enumerate(train_loader()):
            x_data, y_data = _unpack_batch(data)

            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)

            if batch_id % 1000 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))
            loss.backward()
            opt.step()
            opt.clear_grad()

        # evaluate model after one epoch
        avg_acc, avg_loss = _evaluate(model, valid_loader)
        print("[validation] accuracy/loss: {}/{}".format(avg_acc, avg_loss))
        val_acc_history.append(avg_acc)
        val_loss_history.append(avg_loss)


def _unpack_batch(data):
    """Split a loader batch into ``(images, labels)``, adding a unit axis at dim 1 of the labels.

    ``cross_entropy`` and ``paddle.metric.accuracy`` are both called with the
    labels in this form. This assumes the loader yields 1-D labels of shape
    [N], giving [N, 1] here — confirm for other datasets.
    """
    x_data = data[0]
    y_data = paddle.unsqueeze(paddle.to_tensor(data[1]), 1)
    return x_data, y_data


def _evaluate(model, loader):
    """Return ``(mean accuracy, mean loss)`` of ``model`` over every sample ``loader`` yields.

    The pass runs under ``paddle.no_grad()`` so no autograd graph is recorded.
    Each batch is weighted by its size, so a short final batch does not skew
    the averages. Returns ``(nan, nan)`` when the loader yields no samples.
    The model is switched to eval mode for the pass and always restored to
    train mode afterwards, even if evaluation raises.
    """
    model.eval()
    seen = 0
    acc_sum = 0.0
    loss_sum = 0.0
    try:
        with paddle.no_grad():
            for data in loader():
                x_data, y_data = _unpack_batch(data)
                logits = model(x_data)
                loss = F.cross_entropy(logits, y_data)
                acc = paddle.metric.accuracy(logits, y_data)
                # Both metrics are per-batch means; rescale to per-sample sums.
                n = x_data.shape[0]
                acc_sum += acc.numpy().item() * n
                loss_sum += loss.numpy().item() * n
                seen += n
    finally:
        model.train()

    if seen == 0:
        return float('nan'), float('nan')
    return acc_sum / seen, loss_sum / seen

# Build a 10-way classifier and train it. train() appends one entry per epoch
# to val_acc_history / val_loss_history.
model = MyNet(num_classes=10)
train(model)

# Plot validation accuracy per epoch.
plt.plot(val_acc_history, label='validation accuracy')

plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# Start from a [0.5, 0.8] window but widen it so no recorded epoch is clipped.
plt.ylim([min([0.5] + val_acc_history), max([0.8] + val_acc_history)])
plt.legend(loc='lower right')
# Needed to display the figure when run as a script; harmless in notebooks.
plt.show()
    /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
      from collections import MutableMapping
    /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
      from collections import Iterable, Mapping
    /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
      from collections import Sized
    Cache file /home/aistudio/.cache/paddle/dataset/cifar/cifar-10-python.tar.gz not found, downloading https://dataset.bj.bcebos.com/cifar/cifar-10-python.tar.gz 
    Begin to download
    
    Download finished


start training ... 
epoch: 0, batch_id: 0, loss is: [2.2984047]
epoch: 0, batch_id: 1000, loss is: [1.1130366]
[validation] accuracy/loss: 0.5728833675384521/1.1950840950012207
epoch: 1, batch_id: 0, loss is: [1.4412661]
epoch: 1, batch_id: 1000, loss is: [0.7821938]
[validation] accuracy/loss: 0.6347843408584595/1.0210740566253662
epoch: 2, batch_id: 0, loss is: [0.9672767]
epoch: 2, batch_id: 1000, loss is: [1.3251742]
[validation] accuracy/loss: 0.6820088028907776/0.924034059047699
epoch: 3, batch_id: 0, loss is: [0.948807]
epoch: 3, batch_id: 1000, loss is: [1.0935752]
[validation] accuracy/loss: 0.6902955174446106/0.899625837802887
epoch: 4, batch_id: 0, loss is: [0.68473774]
epoch: 4, batch_id: 1000, loss is: [0.9899424]
[validation] accuracy/loss: 0.699181318283081/0.8516970276832581
epoch: 5, batch_id: 0, loss is: [0.59179]
epoch: 5, batch_id: 1000, loss is: [0.60291755]
[validation] accuracy/loss: 0.6930910348892212/0.9029199481010437
epoch: 6, batch_id: 0, loss is: [0.86010605]
epoch: 6, batch_id: 1000, loss is: [0.8406735]
[validation] accuracy/loss: 0.7024760246276855/0.8776425719261169
epoch: 7, batch_id: 0, loss is: [1.0346094]
epoch: 7, batch_id: 1000, loss is: [0.2502644]
[validation] accuracy/loss: 0.7146565318107605/0.834501326084137
epoch: 8, batch_id: 0, loss is: [0.5335045]
epoch: 8, batch_id: 1000, loss is: [0.38619795]
[validation] accuracy/loss: 0.7243410348892212/0.8565839529037476
epoch: 9, batch_id: 0, loss is: [0.55277234]
epoch: 9, batch_id: 1000, loss is: [0.43163693]
[validation] accuracy/loss: 0.7071685194969177/0.9381765723228455


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/326167.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号