加载预训练权重的几种方法
自己学习记录
方法一：按参数名匹配加载（同名参数的形状必须一致，否则 load_state_dict 会报错）
# Method 1: copy every pretrained tensor whose name also exists in the model.
# model_path is the path to the pretrained weights; an empty path (or None)
# means "no pretrained weights", and the model keeps its initial values.
if model_path:
    # map_location='cpu' lets a checkpoint saved on GPU be loaded on a CPU-only
    # machine; load_state_dict then copies each tensor onto the model's device.
    pretrained_dict = torch.load(model_path, map_location='cpu')
    # model is the instantiated network. Start from its own state dict and take
    # the pretrained value for every key also present in the checkpoint; keys
    # the checkpoint lacks keep their current values, and checkpoint keys the
    # model lacks are ignored.
    all_params = {k: pretrained_dict.get(k, v) for k, v in model.state_dict().items()}
    # NOTE: a same-named entry with a different shape still makes
    # load_state_dict raise RuntimeError (size mismatch) -- methods 2 and 3
    # handle that case.
    model.load_state_dict(all_params)
方法二：按参数名匹配加载，但跳过指定的层（如形状不一致的 head.cls_loc、head.score），这些层保留模型初始值
# Method 2: like method 1, but never take certain layers from the checkpoint.
if model_path:
    # map_location='cpu' allows GPU-saved checkpoints on CPU-only machines.
    pretrained_dict = torch.load(model_path, map_location='cpu')
    # Any parameter whose name contains one of these substrings keeps the
    # model's own (initial) value. These layers exist in pretrained_dict but
    # with a different size, so loading them would make load_state_dict fail.
    # (Presumably the number of output classes differs -- confirm for your
    # checkpoint.) Extend this tuple to skip further layers.
    skip_keys = ('head.cls_loc', 'head.score')
    all_params = {
        k: v if any(s in k for s in skip_keys) else pretrained_dict.get(k, v)
        for k, v in model.state_dict().items()
    }
    model.load_state_dict(all_params)
方法三：只加载参数名存在且形状一致的权重，其余参数保持模型初始值
# Method 3: load only the part of the checkpoint that fits the current model.
if model_path:
    model_dict = model.state_dict()
    # map_location='cpu' allows GPU-saved checkpoints on CPU-only machines.
    checkpoint = torch.load(model_path, map_location='cpu')
    # Keep a checkpoint entry only if the model has an entry with the same name
    # AND the same shape. The membership test comes first so checkpoint keys
    # the model does not have are skipped instead of raising KeyError.
    pretrained_dict = {
        k: v for k, v in checkpoint.items()
        if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
    }
    # Report what was dropped, so a silently ignored layer is easy to spot.
    skipped = [k for k in checkpoint if k not in pretrained_dict]
    if skipped:
        print('Skipped checkpoint entries (missing from model or shape mismatch):', skipped)
    # Overwrite matching entries; everything else keeps the model's own values.
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)



