栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch-Lightning中的训练器—Trainer

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch-Lightning中的训练器—Trainer

Pytorch-Lightning中的训练器—Trainer Trainer() 常用参数
参数名称含义默认值接受类型
callbacks添加回调函数或回调函数列表None(ModelCheckpoint默认值)Union[List[Callback], Callback, None]
enable_checkpointing是否使用callbacksTruebool
gpus使用的gpu数量(int)或gpu节点列表(list或str)None(不使用GPU)Union[int, str, List[int], None]
precision指定训练精度32(full precision)Union[int, str]
default_root_dir模型保存和日志记录默认根路径None(os.getcwd())Optional[str]
logger设置日志记录器(支持多个),若没设置logger的save_dir,则使用default_root_dirTrue(默认日志记录)Union[LightningLoggerbase, Iterable[LightningLoggerbase], bool]
max_epochs最多训练轮数(指定为**-1可以设置为无限次**)None(1000)Optional[int]
min_epochs最少训练轮数None(1)Optional[int]
max_steps最大网络权重更新次数-1(禁用)Optional[int]
min_steps最少网络权重更新次数None(禁用)Optional[int]
weights_save_path权重保存路径(优先级高于default_root_dir),ModelCheckpoint未定义路径时将使用该路径None(default_root_dir)Optional[str]
log_every_n_steps更新n次网络权重后记录一次日志50int
auto_scale_batch_size自动搜索最佳batch_size并保存到模型的self.bacth_size中FalseUnion[str, bool]
auto_lr_find自动搜索最佳学习率并存储到self.lr或self.learing_rateFalseUnion[str, bool]
accumulate_grad_batches每k次batches累计一次梯度NoneUnion[int, Dict[int, int], None]
check_val_every_n_epoch每n个train epoch执行一次验证1int
num_sanity_val_steps开始训练前加载n个验证数据进行测试,k=-1时加载所有验证数据2int
额外的解释

这里max_steps/min_steps中的step就是指的是优化器的step,优化器每step一次就会更新一次网络权重梯度累加(Gradient Accumulation):受限于显存大小,一些训练任务只能使用较小的batch_size,但一般batch-size越大(一定范围内)模型收敛越稳定效果相对越好;梯度累加可以先累加多个batch的梯度再进行一次参数更新,相当于增大了batch_size。 Trainer.fit() 常用参数

参数名称含义默认值
modelLightningModule实例
train_dataloaders训练数据加载器None
val_dataloaders验证数据加载器None
ckpt_pathckpt文件路径(从这里文件恢复训练)None
datamoduleLightningDataModule实例None
ckpt_path参数详解(从之前的模型恢复训练)

​ 使用该参数指定一个模型ckpt文件(需要保存整个模型,而不是仅仅保存模型权重),Trainer将从ckpt文件的下一个epoch继续训练。

示范
net = MyNet(...)
trainer = pl.Trainer(...)
# 假设模型保存在./ckpt中
trainer.fit(net, train_iter, val_iter, ckpt_path='./ckpt/myresult.ckpt')
使用注意

请不要使用Trainer()中的resume_from_checkpoint参数,该参数未来将被丢弃,请使用Trainer.fit()的ckpt_path参数 Trainer.test() 常用参数

参数名称含义默认值
modelLightningModule实例None(使用**fit()**传递的模型)
verbose是否打印测试结果True
dataloaders测试数据加载器(可以使用torch.utils.data.DataLoader)None
ckpt_pathckpt文件路径(从这里文件恢复训练)None
datamoduleLightningDataModule实例None
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/714902.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号