- PyTorch 1.7.0
- PyTorch Lightning 1.3.8
#1998
import torch
import pytorch_lightning as pl
import torch.nn as nn
class LeNet5(pl.LightningModule):
    """Lightning module holding the first conv / pool / conv layers of LeNet5.

    The size comments below use ``out = (in - kernel) / stride + 1`` and
    assume a 32x32 single-channel input, as the original notes do; the
    input size itself is not enforced anywhere in this block.
    """

    def __init__(self):
        """Create the convolution and pooling layers."""
        super().__init__()
        # (32 - 5) / 1 + 1 = 28  ->  28 x 28 x 6
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        # (28 - 2) / 2 + 1 = 14  ->  14 x 14 x 6
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # (14 - 5) / 1 + 1 = 10  ->  10 x 10 x 16
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)


