最近看了几篇使用transformer的文章,于是想用其中的一个transformer模块来替换另一个方法的骨干网络(backbone),替换完之后跑起来感觉没有什么效果,想着可能是transformer模型要用预训练会好一些。但是。由于是自己把原来方法的backbone替换掉,因此没有现成的直接可以使用的预训练模型来使用,只能从两个方法中提取相应模块的权重然后整合起来当做预训练模型使用。
原理不同的预训练模型之所以能够东拼一块、西拼一块成为一个可以用的预训练模型,是因为在预训练模型中有相应的键值对(key-value),只要把预训练模型中的对应到自己使用的网络中的键值对进行更新就好了。简单的举个例子:
# 使用的网络中有这样的减值对:
{'conv1': [1, 1, 1, 0, 0]}
# 加载的预训练模型中也有这样的键值对,但是值不同,这样你就可以通过字典更新的方式获得到预训练模型的训练权重了。
{'conv1': [1, 1, 1, 0, 0]} -> {'conv1': [0, 1, 0, 0, 1]}
但是值得注意的一点是,一定要根据自己构建的网络中的键值对来对预训练模型的键值对进行提取,否则将会更新失败。
更新backbone的预训练权重过程 1、查看网络的键值对和预训练模型的键值对首先查看需要替换的backbone在原始的预训练模型中的键值对,因为这个是作为backbone使用,所以一般打印出来的信息会有‘backbone'几个字,挺好辨认的。代码和效果如下:
def extract_backbone():
# 查看backbone部分的预训练模型键值对
backbone_model_path = "backbone.pth"
backbone_train_model = torch.load(backbone_model_path)
print(backbone_train_model .keys())
打印结果如下:
odict_keys(['backbone.SA_modules.0.local_chunk.pe.0.conv.weight', 'backbone.SA_modules.0.local_chunk.pe.0.bn.weight', ...., 'bbox_head.vote_module.vote_conv.0.bn.bias']
可以看到在打印中的信息就能看到’backbone'几个字,也就能知道我们要提取的数据范围。
接下来查看网络的键值对:
model = net()
model_stat_dict = model.state_dict()
print(model_stat_dict.keys())
打印结果:
odict_keys(['SA_modules.0.local_chunk.pe.0.conv.weight', 'SA_modules.0.local_chunk.pe.0.bn.weight', ...]
可以看到里面的键的名字虽然不是完全一样'backbone.SA_modules.0.local_chunk.pe.0.conv.weight'对比'SA_modules.0.local_chunk.pe.0.conv.weight',但是能够知道之间的对应关系,也就能知道该怎么样更新字典的键值对。
2、提取键值对并更新方法:观察法,观察对应的键值对相差什么样的字符串,然后把多余的字符串去掉,如,把'backbone.SA_modules.0.local_chunk.pe.0.conv.weight'换为'SA_modules.0.local_chunk.pe.0.conv.weight',
# 提取backbone部分的权重
backbone_stat_dict = {}
for i in backbone_train_model.keys():
if 'backbone' in i and 'FP_modules.1.mlps' not in i:
backbone_stat_dict[i.replace('backbone.', '')] = backbone_train_model[i]
这里的方法是判断字符串是否在另一个字符串中来定位自己想要的键值对,另外一个就是,因为我把网络的最后一层的输出改变了,例如原始输出是256,现在改为288,那么最后一层的训练权重就不能用了,只能使用默认的值。这一步是需要自己debug或者慢慢观察的出来的,遇到什么bug就解决什么bug就好了。
接下来对自己新建的网络进行权值更新:
# 更新网络权重
model_stat_dict.update(backbone_stat_dict)
model.load_state_dict(model_stat_dict)
更新非backbone部分的预训练权重过程
1、查看键值对
def extract_others():
other_model_path = "xxx.pth"
other_train_model = torch.load(other_model_path )
print(other_train_model['model'].keys())
打印结果:
odict_keys(['module.backbone_net.sa1.mlp_module.layer0.conv.weight',...]
这里的代码跟提取backbone的差不多,多了一个[‘model’]是因为这个预训练模型的所有键值对放在一个叫model的字典中,保存的层次不一样而已,可以看到里面也有’backbone’几个字,但是这一次,提取的就不是‘backbone'部分的预训练权重了。这一次是要把‘backbone部分的训练权重丢掉,保留其他部分。
具体代码:
other_state_dict = {}
for i in other_train_model['model'].keys():
if 'backbone_net' not in i and i.replace('module.', '') in model_stat_dict .keys():
other_state_dict [i.replace('module.', '')] = other_train_model['model'][i]
这里提取的是除backbone部分和在新建的网络中的键值对。
2、更新键值对 model_stat_dict.update(other_state_dict)
model.load_state_dict(model_stat_dict)
结果对比
新建网络的默认权重:
('prediction_heads.5.bn1.weight', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]
更新后的模型权重:
('prediction_heads.5.bn1.weight', tensor([1.5746, 0.8308, 0.9783, 1.7613, 1.2984, 1.7066, 1.6914, 1.4509, 1.0602,
0.6725, 1.0243, 1.0239, 1.0145, 0.8561, 1.4968, 0.7426, 1.1622, 1.1951,
0.8682, 1.7863, 1.1278, 1.1146, 0.9362, 0.7673, 0.6439, 1.0142, 0.7825,
1.2053, 0.8184, 0.5493, 1.6683, 2.4341, 0.9867, 1.4273, 0.8302, 2.2120,
1.2883, 2.2580, 1.8059, 0.6444, 1.3589, 1.1272, 1.2261, 1.7947, 0.7312,
1.6581, 0.7312, 1.1488, 0.9693, 1.6703, 1.7478, 0.6688, 1.5604, 0.9614,
1.3638, 2.7505, 1.5183, 1.5955, 1.9048, 1.5618, 1.2638, 1.6549, 1.1343,
1.5455, 1.1127, 1.0814, 0.8313, 0.7031, 1.1536, 5.1275, 0.5989, 1.6834,
0.7864, 1.0093, 0.9050, 0.8332, 0.7380, 1.9826, 0.8156, 0.6324, 1.0322,
0.8536, 1.6749, 1.0331, 0.8382, 0.6750, 1.2850, 0.9476, 1.0836, 0.8562,
0.8736, 0.7233, 1.4038, 0.9900, 1.5904, 0.8788, 1.1214, 0.8128, 1.7798,
1.5962, 1.4050, 0.8402, 1.7244, 1.1827, 1.3041, 0.8613, 0.7466, 1.8008,
1.3645, 1.1653, 0.9632, 1.7416, 1.7852, 1.6817, 1.4639, 1.6094, 2.0025,
0.5836, 1.4742, 1.2245, 1.1398, 1.2036, 1.2999, 1.6483, 1.0462, 1.8296,
1.2137, 0.7727, 0.8674, 1.1087, 0.7098, 1.5879, 1.0915, 1.0083, 1.5613,
0.6884, 1.1696, 1.7591, 4.0460, 1.2088, 1.5664, 0.8087, 1.6567, 1.0537,
2.0586, 2.6935, 0.9094, 0.9196, 0.6560, 0.6821, 1.3052, 0.8119, 0.9941,
0.7237, 2.0793, 1.7808, 0.8971, 1.3992, 0.8429, 4.0478, 1.2344, 1.4197,
1.8055, 0.4954, 1.3991, 1.2639, 1.6844, 1.7331, 1.5301, 0.9634, 1.1549,
5.1205, 1.0622, 0.8827, 0.5879, 1.4933, 1.4872, 0.7220, 0.6562, 0.8938,
1.5842, 1.6547, 1.1611, 1.1708, 0.8881, 0.6681, 0.6075, 1.8197, 1.6702,
2.4868, 0.8145, 1.9318, 1.3400, 1.4768, 1.7392, 0.9595, 1.2901, 1.5173,
1.3438, 1.6758, 1.0352, 1.1626, 1.8464, 1.7514, 0.8785, 1.2588, 1.2789,
1.1080, 1.1457, 1.7604, 2.5747, 1.7570, 0.5919, 1.4384, 0.8145, 4.1783,
1.1613, 0.8638, 0.7660, 3.4310, 1.7689, 0.8914, 1.2057, 0.8660, 0.7634,
0.6558, 1.7399, 1.2640, 1.6545, 1.7261, 1.1976, 0.6633, 1.5383, 3.7988,
0.7230, 1.3519, 0.5643, 0.5462, 1.0492, 0.7756, 1.6172, 1.1215, 0.8326,
1.5099, 0.8606, 2.4929, 0.4736, 1.5953, 1.1830, 0.6277, 1.9548, 1.5805,
1.6124, 1.5544, 1.3946, 1.4689, 1.7732, 1.8245, 1.2772, 1.3543, 1.6477,
2.2249, 1.4015, 1.1677, 0.6915, 1.3482, 0.9322, 1.4587, 1.7425, 1.5820,
0.7349, 1.2672, 3.4010, 1.0449, 0.6197, 1.6636, 1.7152, 0.6885, 1.6232,
1.3337, 1.7424, 0.9620, 1.2479, 3.6048, 1.0400, 1.2665, 0.6770, 1.6031],
device='cuda:0'))
可以看到,刚刚创建的网络中的权重有一大部分都是0或1,而经过更新之后就不是这样子了。
总结实现提取预训练模型的权重主要的思想就是对比、找出预训练模型和自己新建网络的公共部分的键值对。如果遇到的是没有键的预训练模型,那就比较难受了,就不是本文的讨论范围了,总的代码如下:
# 查看backbone部分的预训练模型键值对
backbone_model_path = "backbone.pth"
backbone_train_model = torch.load(backbone_model_path)
print(backbone_train_model .keys())
# 查看网络的键值对
model = net()
model_stat_dict = model.state_dict()
print(model_stat_dict.keys())
# 提取backbone部分的权重
backbone_stat_dict = {}
for i in backbone_train_model.keys():
if 'backbone' in i and 'FP_modules.1.mlps' not in i:
backbone_stat_dict[i.replace('backbone.', '')] = backbone_train_model[i]
# 更新网络权重
model_stat_dict.update(backbone_stat_dict)
model.load_state_dict(model_stat_dict)
print('---------- load backbone pretrain model to new network successfully !!! ---------')
# 查看除backbone部分其他模块在另一个预训练模型的键值对
other_model_path = "other.pth"
other_train_model = torch.load(other_model_path)
print(other_train_model['model'].keys())
# 提取其他模块的训练权重
other_state_dict = {}
for i in other_train_model['model'].keys():
if 'backbone_net' not in i and i.replace('module.', '') in model_stat_dict .keys():
other_state_dict [i.replace('module.', '')] = other_train_model['model'][i]
# 更新权重
model_stat_dict.update(other_state_dict)
model.load_state_dict(model_stat_dict)
print('---------- load other pretrain model to new network successfully !!! ---------')
结语
本篇文章到此结束,本人水平有限,如有问题敬请指出。



