#传入train 使用训练集的数据处理方法处理数据
train_num len(train_dataset)#将训练集中的图片个数赋值给train_num
# { daisy :0, dandelion :1, roses :2, sunflower :3, tulips :4}
flower_list train_dataset.class_to_idx#获取分类名称所对应的索引
cla_dict dict((val, key) for key, val in flower_list.items()
#遍历所获取的分类以及索引的字典 并且将key,values交换位置
# write dict into json file
json_str json.dumps(cla_dict, indent 4) #将字典编码成json格式
with open( class_indices.json , w ) as json_file:
json_file.write(json_str)
batch_size 32#定义batch_size 32
nw min([os.cpu_count(), batch_size if batch_size 1 else 0, 8]) # number of workers
print( Using {} dataloader workers every process .format(nw))
train_loader torch.utils.data.DataLoader(train_dataset,
batch_size batch_size, shuffle True,
num_workers 0)
#train_loader函数是为了随机在数据集中获取一批批数据 num_workers 0加载数据的线程个数 在windows系统下该数为 0 意思为在windows系统下使用一个主线程加载数据
validate_dataset datasets.ImageFolder(root os.path.join(image_path, val ),
transform data_transform[ val ])
val_num len(validate_dataset)
validate_loader torch.utils.data.DataLoader(validate_dataset,
batch_size 4, shuffle True,
num_workers 0)
print( using {} images for training, {} images for validation. .format(train_num,
val_num))
net AlexNet(num_classes 5, init_weights True)#num_classes 5花有5种类别 初始化权重
net.to(device)#将该网络分配到制定的设备上 gpu或者cpu
loss_function nn.CrossEntropyLoss()#定义损失函数 针对多类别的损失交叉熵函数
# pata list(net.parameters())
optimizer optim.Adam(net.parameters(), lr 0.0002)
#定义一个Adam优化器 优化对象是所有可训练的参数 定义学习率为0.0002 通过调试获得的最佳学习率
epochs 10
save_path ./AlexNet.pth #保存准确率最高的那次模型的路径
best_acc 0.0#最佳准确率
train_steps len(train_loader)
for epoch in range(epochs):
# train
net.train()#使用net.train()方法 该方法中有dropout
running_loss 0.0#使用running_loss方法统计训练过程中的平均损失
train_bar tqdm(train_loader)
for step, data in enumerate(train_bar):#遍历数据集
images, labels data#将数据分为图像标签
optimizer.zero_grad()#清空之前的梯度信息
outputs net(images.to(device))#通过正向传播的到输出
loss loss_function(outputs, labels.to(device))#指定设备gpu或者cpu,通过Loss_function函数计算预测值与真实值之间的差距
loss.backward()#将损失反向传播到每一个节点
optimizer.step()#通过optimizer更新每一个参数
# print statistics
running_loss loss.item()#累加损失
#print train process
rate (step 1)/len(train_loader)
a * * int(rate*50)
b . * int((1-rate)*50)
print( rtrain loss:{:^3.0f}%[{}- {}]{:.3f} .format(int(rate*100),a,b,loss),end )
print()