import torch
import torch.nn as nn
import torchvision
from #所在程序名字 import #自己网络保存的名字
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import numpy as np
def test_mydata():
im = plt.imread('33333.jpg') #自己手写的图片读入
images = Image.open('33333.jpg') #可以转换为黑底白字,读取更准确
images = images.resize((28,28))
images = images.convert('L')
transform = transforms.ToTensor()
images = transform(images)
images = images.resize(1,1,28,28)
# 加载网络和参数
model = #加上自己网络的名字()
model.load_state_dict(torch.load('pathh'))
model.eval()
outputs = model(images)
values, indices = outputs.data.max(1) # 返回最大概率值和下标
plt.title('{}'.format((int(indices[0]))))
plt.imshow(im)
plt.show()
test_mydata()