
PyTorch children(), modules(), named



```python
import torch
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        # Feature extractor: convolution / activation / pooling stack
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        # Classifier head: fully connected layers with dropout
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        # Flatten [N, 128, 6, 6] to [N, 128 * 6 * 6] before the Linear layers
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        # modules() walks every nested layer, not just the two Sequential containers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


if __name__ == "__main__":
    model = AlexNet()

    # children(): direct sub-modules only (features, classifier)
    print("model children:")
    for module in model.children():
        print(module)

    # modules(): the model itself plus all sub-modules, recursively
    print("model modules:")
    for module in model.modules():
        print(module)

    # named_children(): like children(), but yields (name, module) pairs
    print("model named children:")
    for name, module in model.named_children():
        print("name: {}, module: {}".format(name, module))

    # named_modules(): like modules(), but yields (name, module) pairs
    print("model named modules:")
    for name, module in model.named_modules():
        print("name: {}, module: {}".format(name, module))

    # named_parameters(): (name, parameter) pairs such as "features.0.weight"
    print("model named parameters:")
    for name, parameter in model.named_parameters():
        print("name: {}, parameter: {}".format(name, parameter))

    # parameters(): the parameter tensors without names
    print("parameters:")
    for parameter in model.parameters():
        print("parameter: {}".format(parameter))
```
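The difference between children() and modules() is easier to see on a model much smaller than AlexNet. The sketch below uses a hypothetical nested nn.Sequential (not from the original post) and records only the class names and dotted names that each iterator yields.

```python
import torch.nn as nn

# Hypothetical nested model, for illustration only:
# the outer Sequential holds a Linear and an inner Sequential.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.Sequential(
        nn.ReLU(),
        nn.Linear(8, 2),
    ),
)

# children(): only the direct sub-modules, one level deep.
child_types = [type(m).__name__ for m in model.children()]

# modules(): the model itself plus every sub-module, recursively,
# visited in pre-order (parent before its children).
module_types = [type(m).__name__ for m in model.modules()]

# named_modules(): the same traversal with dotted names; the root is "".
module_names = [name for name, _ in model.named_modules()]

print(child_types)   # ['Linear', 'Sequential']
print(module_types)  # ['Sequential', 'Linear', 'Sequential', 'ReLU', 'Linear']
print(module_names)  # ['', '0', '1', '1.0', '1.1']
```

For AlexNet the same pattern holds: children() yields just features and classifier, while modules() also yields the AlexNet instance itself and every Conv2d, ReLU, MaxPool2d, Dropout and Linear inside the two containers.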

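This distinction is why _initialize_weights in the AlexNet code iterates over self.modules(): the Conv2d and Linear layers are nested inside nn.Sequential containers, so children() would never hand them to the isinstance checks. A minimal sketch, assuming a hypothetical TinyNet that only mimics AlexNet's two-container layout:

```python
import torch.nn as nn


# Hypothetical stand-in for AlexNet's layout: the layers live inside
# two nn.Sequential containers rather than directly on the model.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU(inplace=True))
        self.classifier = nn.Sequential(nn.Linear(4, 2))


net = TinyNet()

# children() only sees the two Sequential containers, so no match.
convs_via_children = sum(isinstance(m, nn.Conv2d) for m in net.children())

# modules() recurses into the containers and finds the real layers.
convs_via_modules = sum(isinstance(m, nn.Conv2d) for m in net.modules())
linears_via_modules = sum(isinstance(m, nn.Linear) for m in net.modules())

print(convs_via_children, convs_via_modules, linears_via_modules)  # 0 1 1
```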
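named_parameters() and parameters() follow the same recursive traversal but yield the learnable tensors rather than the modules, and layers without weights (ReLU, Flatten, pooling) simply contribute nothing. The sketch below, on a hypothetical small convolutional model chosen only for illustration, collects each parameter's dotted name and shape and counts the total number of weights.

```python
import torch.nn as nn

# Hypothetical small model, for illustration only.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 10),
)

# named_parameters(): (dotted name, Parameter) pairs, recursively.
# ReLU (index 1) and Flatten (index 2) have no parameters.
shapes = {name: tuple(p.shape) for name, p in model.named_parameters()}

# parameters(): the same tensors without names, e.g. for an optimizer
# or for counting weights.
total = sum(p.numel() for p in model.parameters())

print(shapes)
# {'0.weight': (8, 3, 3, 3), '0.bias': (8,), '3.weight': (10, 128), '3.bias': (10,)}
print(total)  # 1514
```

In the AlexNet listing the names take the form features.0.weight, features.0.bias, classifier.1.weight and so on, because the prefix is the attribute name of the container followed by the layer's index inside it.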

When reposting, please credit the source: this article is reposted from www.mshxw.com
Article URL: https://www.mshxw.com/it/268247.html

Copyright (c) 2021-2022 MSHXW.COM
