| 参数名称 | 含义 | 默认值 | 接受类型 |
|---|---|---|---|
| callbacks | 添加回调函数或回调函数列表 | None(ModelCheckpoint默认值) | Union[List[Callback], Callback, None] |
| enable_checkpointing | 是否使用callbacks | True | bool |
| gpus | 使用的gpu数量(int)或gpu节点列表(list或str) | None(不使用GPU) | Union[int, str, List[int], None] |
| precision | 指定训练精度 | 32(full precision) | Union[int, str] |
| default_root_dir | 模型保存和日志记录默认根路径 | None(os.getcwd()) | Optional[str] |
| logger | 设置日志记录器(支持多个),若没设置logger的save_dir,则使用default_root_dir | True(默认日志记录) | Union[LightningLoggerbase, Iterable[LightningLoggerbase], bool] |
| max_epochs | 最多训练轮数(指定为**-1可以设置为无限次**) | None(1000) | Optional[int] |
| min_epochs | 最少训练轮数 | None(1) | Optional[int] |
| max_steps | 最大网络权重更新次数 | -1(禁用) | Optional[int] |
| min_steps | 最少网络权重更新次数 | None(禁用) | Optional[int] |
| weights_save_path | 权重保存路径(优先级高于default_root_dir),ModelCheckpoint未定义路径时将使用该路径 | None(default_root_dir) | Optional[str] |
| log_every_n_steps | 更新n次网络权重后记录一次日志 | 50 | int |
| auto_scale_batch_size | 自动搜索最佳batch_size并保存到模型的self.bacth_size中 | False | Union[str, bool] |
| auto_lr_find | 自动搜索最佳学习率并存储到self.lr或self.learing_rate | False | Union[str, bool] |
| accumulate_grad_batches | 每k次batches累计一次梯度 | None | Union[int, Dict[int, int], None] |
| check_val_every_n_epoch | 每n个train epoch执行一次验证 | 1 | int |
| num_sanity_val_steps | 开始训练前加载n个验证数据进行测试,k=-1时加载所有验证数据 | 2 | int |
这里max_steps/min_steps中的step就是指的是优化器的step,优化器每step一次就会更新一次网络权重梯度累加(Gradient Accumulation):受限于显存大小,一些训练任务只能使用较小的batch_size,但一般batch-size越大(一定范围内)模型收敛越稳定效果相对越好;梯度累加可以先累加多个batch的梯度再进行一次参数更新,相当于增大了batch_size。 Trainer.fit() 常用参数
| 参数名称 | 含义 | 默认值 |
|---|---|---|
| model | LightningModule实例 | |
| train_dataloaders | 训练数据加载器 | None |
| val_dataloaders | 验证数据加载器 | None |
| ckpt_path | ckpt文件路径(从这里文件恢复训练) | None |
| datamodule | LightningDataModule实例 | None |
使用该参数指定一个模型ckpt文件(需要保存整个模型,而不是仅仅保存模型权重),Trainer将从ckpt文件的下一个epoch继续训练。
示范net = MyNet(...) trainer = pl.Trainer(...) # 假设模型保存在./ckpt中 trainer.fit(net, train_iter, val_iter, ckpt_path='./ckpt/myresult.ckpt')使用注意
请不要使用Trainer()中的resume_from_checkpoint参数,该参数未来将被丢弃,请使用Trainer.fit()的ckpt_path参数 Trainer.test() 常用参数
| 参数名称 | 含义 | 默认值 |
|---|---|---|
| model | LightningModule实例 | None(使用**fit()**传递的模型) |
| verbose | 是否打印测试结果 | True |
| dataloaders | 测试数据加载器(可以使用torch.utils.data.DataLoader) | None |
| ckpt_path | ckpt文件路径(从这里文件恢复训练) | None |
| datamodule | LightningDataModule实例 | None |



