栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch-Lightning--Tuner

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch-Lightning--Tuner

Pytorch-Lightning–Tuner lr_find() 参数详解
参数名称含义默认值
modelLightningModule实例
train_dataloaders训练数据加载器None
val_dataloaders验证数据加载器None
datamoduleLightningDataModule实例None
min_lr学习率最小值1e-08
max_lr学习率最大值1
num_training测试学习率的训练轮数100
mode学习率寻找策略,分为指数(默认)和线性(linear)exponential
early_stop_threshold当任意一点的loss>=early_stop_threshold*best_loss时停止搜索,设置为None禁用该项4.0
update_attr将搜索到的学习率更新到模型参数中False
使用注意

暂时只支持单个优化器暂不支持DDP 用法

使用self.learing_rate或self.lr作为学习率参数

class LitModel(LightningModule):
    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=(self.lr or self.learning_rate))


model = LitModel()

# 开启 auto_lr_find标志
trainer = Trainer(auto_lr_find=True)
# 寻找合适的学习率
trainer.tune(model)

使用其他的学习率变量名称

model = LitModel()

# 设置为自己的学习率超参数名称 my_value
trainer = Trainer(auto_lr_find="my_value")

trainer.tune(model)

使用lr_find()查看自动搜索学习率的结果

model = MyModelClass(hparams)
trainer = Trainer()

# 运行学习率搜索
lr_finder = trainer.tuner.lr_find(model)

# 查看搜索结果
lr_finder.results

# 绘制学习率搜索图,suggest参数指定是否显示建议的学习率点
fig = lr_finder.plot(suggest=True)
fig.show()

# 获取最佳学习率或建议的学习率
new_lr = lr_finder.suggestion()

# 更新模型的学习率
model.hparams.lr = new_lr

# 训练模型
trainer.fit(model)

scale_batch_size() 参数详解
参数名称含义默认值
modelLightningModule实例
train_dataloaders训练数据加载器None
val_dataloaders验证数据加载器None
datamoduleLightningDataModule实例None
mode学习率寻找策略,分为幂次方(默认)和二分(binsearch)power
steps_per_trial每次测试当前batch_size的训练step数量3
init_val初始batch_size大小2
max_trials算法结束前batch_size最大增量25
batch_arg_name存储batch_size的属性名'batch_size'

Returns:搜索结果

将在如下地方寻找batch_arg_name

modelmodel.hparamstrainer.datamodule (如果datamodule传递给了tune()) 使用注意

暂时不支持DDP模式

由于需要使用模型的batch_arg_name属性,因此不能直接将dataloader直接传递给trainer.fit(),否则此功能将失效,需要在模型中加载数据

原来模型中的batch_arg_name属性将被覆盖

train_dataloader()应该依赖于batch_arg_name属性

def train_dataloader(self):
    return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
用法 使用Trainer中的auto_scale_batch_size属性
# 默认不执行缩放
trainer = Trainer(auto_scale_batch_size=None)

# 设置搜索策略
trainer = Trainer(auto_scale_batch_size=None | "power" | "binsearch")

# 寻找最佳batch_szie,并自动设置到模型的batch_size属性中
trainer.tune(model)
使用scale_batch_size()
# 返回搜索结果
new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)

# 覆盖原来的属性(这个过程是自动的)
model.hparams.batch_size = new_batch_size
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/714782.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号