栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

使用pytorch进行训练的步骤

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

使用pytorch进行训练的步骤

下面是《深度学习框架PyTorch:入门与实践》中的代码,是一个标准的模型训练器,主要包括4个步骤:

准备数据准备网络模型准备损失函数和优化器训练

def train(**kwargs):
    # 1、准备数据
    train_data = DogCat(opt.train_data_root, train=True)
    val_data = DogCat(opt.train_data_root, train + false)
    train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data, opt.batch_size, shuffle=False,
                                num_workers=opt.num_workers)

    # 2、准备网络模型
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu:
        model.cuda()

    # 3、准备损失函数和优化器
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr,
                             weight_decay=opt.weight_decay)

    # 4、训练
    loss_meter = meter.AveragevalueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e100
    for epoch in range(opt.max_epoch):
        for i, (data, label) in enumerate(train_dataloader):

            # 训练模型参数
            input = Variable(data)
            target = Variable(label)
            if opt.use_gpu:
                input = input.cuda()
                target = target.cuda()
            optimizer.zero_grad()
            score = model(input)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()

            # 更新统计指标及可视化
            loss_meter.add(loss.data[0])
            confusion_matrix.add(score.data, target.data)

            if i % opt.print_freg == opt.print_freg - 1:
                vis.plot('loss', loss_meter.value()[0])
                
        # 保存模型
        model.save()

        # 计算验证集上的指标及可视化
        val_cm, val_accuracy = val(model, val_dataloader)
        vis.plot('val_accuracy', val_accuracy)

        vis.log(
            "epoch: {epoch}, lr: {lr}, loss: {loss}, train_cm: {train_cm},val_cm : { val_cm}"
                .format(
                epoch=epoch,
                loss=loss_meter.value()[0],
                val_cm=str(val_cm.value()),
                train_cm=str(confusion_matrix.value()),
                lr=lr))

        # 如采损失不再下降,则降低学习率
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        previous_loss = loss_meter.value()[0]

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/767564.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号