PyTorch 7. Methods for Saving and Loading PyTorch Models
Saving and loading models
Ordinary Python objects (anything picklable) can be saved and loaded with the torch.save and torch.load functions:
import torch

x1 = {"d": "df", "dd": "ddf"}
torch.save(x1, 'a1.pt')
x2 = torch.load('a1.pt')
Next, let's look at a model's state_dict(), which returns an OrderedDict of the model's learnable parameters:
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP()
net.state_dict()
Output:
OrderedDict([('hidden.weight',
tensor([[-0.4195, 0.2609, 0.4325],
[-0.4031, 0.2078, 0.2077]])),
('hidden.bias', tensor([ 0.0755, -0.1408])),
('output.weight', tensor([[0.2473, 0.6614]])),
('output.bias', tensor([0.6191]))])
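Optimizers expose a state_dict() as well, holding their hyperparameters and internal buffers; this is what the checkpoint example further below saves. A small sketch using SGD on the MLP above:
# Optimizer state_dict sketch: 'param_groups' holds hyperparameters, 'state' holds per-parameter buffers
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
print(optimizer.state_dict())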
torch.save & torch.load
- Saving the entire model
If you save the entire model, you do not need to create a model instance beforehand when loading; the model and its parameters come back together:
torch.save(net, path)
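For example, a minimal round-trip sketch (the file name 'net_whole.pt' is just a placeholder):
# Save the whole module, then load it without re-creating MLP()
path = 'net_whole.pt'
torch.save(net, path)

net_loaded = torch.load(path)   # returns a ready-to-use MLP instance
net_loaded.eval()               # switch to eval mode before inference
Note that the class definition (here MLP) must still be importable when loading, because torch.save pickles a reference to the class rather than its code.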
- Saving only the model parameters
If you save only the parameters, you must first create a model instance when loading:
state_dict = net.state_dict()
torch.save(state_dict, path)
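The loading counterpart, sketched with the MLP defined above and a placeholder file name:
# Save only the parameters, then restore them into a freshly constructed model
path = 'net_params.pt'
torch.save(net.state_dict(), path)

net2 = MLP()                               # the instance must exist before loading
net2.load_state_dict(torch.load(path))     # copy the parameters into it
net2.eval()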
- Fine-tuning / resuming a model
If training is accidentally interrupted, or you want to use the model as the starting point for fine-tuning, saving the parameters alone is not enough: you also need to save the current training epoch and the optimizer state.
This is where the so-called checkpoint comes in; it is simply a dictionary:
checkpoint = {
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch
}
# save
torch.save(checkpoint, path_checkpoint)
# load
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch  # keep the LR scheduler in sync with the resumed epoch
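Building on the loaded checkpoint, resuming the training loop might look like the sketch below; max_epoch and train_one_epoch() are hypothetical names not defined in this article:
# Resume training from the checkpointed epoch (max_epoch and train_one_epoch are placeholders)
for epoch in range(start_epoch + 1, max_epoch):
    train_one_epoch(net, optimizer)    # one pass over the training data
    scheduler.step()                   # advance the LR schedule

    # overwrite the checkpoint each epoch so training can be resumed again later
    torch.save({
        "model_state_dict": net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
    }, path_checkpoint)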
Saving and loading models across devices
- Loading on the CPU a model that was trained and saved on the GPU:
device = torch.device('cpu')
model = MyModel()
model.load_state_dict(torch.load('net_params.pth', map_location=device))
- Loading on the GPU a model that was trained and saved on the GPU:
device = torch.device('cuda')
model = MyModel()
model.load_state_dict(torch.load('net_params.pth'))
model.to(device)
Here the map_location argument alone is not enough; you must call model.to(torch.device("cuda")) to convert the model's parameters into CUDA tensors.
The data must also be moved to the GPU.
Because my_tensor.to(device) returns a GPU copy of my_tensor rather than overwriting it in place, remember to reassign the result:
my_tensor = my_tensor.to(device)
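Putting the pieces together, an inference sketch on the GPU; data_loader here is a hypothetical DataLoader, not something defined earlier:
# Load parameters, move the model to the GPU, and move each batch as well
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel()
model.load_state_dict(torch.load('net_params.pth', map_location=device))
model.to(device)    # move the parameters onto the GPU
model.eval()

with torch.no_grad():
    for inputs, targets in data_loader:   # hypothetical DataLoader
        inputs = inputs.to(device)        # .to(device) returns a copy, so reassign
        targets = targets.to(device)
        outputs = model(inputs)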
When multiple GPU devices are present
map_location specifies which GPU the tensors are loaded onto (note that it is an argument of torch.load, not of load_state_dict):
model.load_state_dict(torch.load('net_params.pth', map_location='cuda:0'))
Training on multiple GPUs, loading on a single GPU
def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None):
    start_epoch = 0
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
    state_dict_ = checkpoint['state_dict']
    state_dict = {}
    # strip the 'module.' prefix added by DataParallel so the keys match a plain model
    for k in state_dict_:
        if k.startswith('module') and not k.startswith('module_list'):
            state_dict[k[7:]] = state_dict_[k]
        else:
            state_dict[k] = state_dict_[k]
    model_state_dict = model.state_dict()
    # compare the loaded parameters against the model's own state_dict
    msg = 'If you see this, your model does not fully load the pre-trained weight.'
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                print('Skip loading parameter {}, required shape {}, loaded shape {}. {}'.format(
                    k, model_state_dict[k].shape, state_dict[k].shape, msg))
                state_dict[k] = model_state_dict[k]
        else:
            # the parameter exists in the pre-trained checkpoint but not in the local model: drop it
            print('Drop parameter {}.'.format(k) + msg)
    for k in model_state_dict:
        if not (k in state_dict):
            # the local model needs this parameter but the checkpoint lacks it: keep the local value
            print('No param {}.'.format(k) + msg)
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)
    # resume training from the checkpoint
    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            start_lr = lr
            for step in lr_step:
                if start_epoch >= step:
                    start_lr *= 0.1
            for param_group in optimizer.param_groups:
                param_group['lr'] = start_lr
            print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    if optimizer is not None:
        return model, optimizer, start_epoch
    else:
        return model
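A usage sketch for load_model; MyModel, the checkpoint path, and the lr_step milestones are placeholders, and the checkpoint file is assumed to use the 'state_dict'/'epoch'/'optimizer' keys the function expects:
# Hypothetical call; the path and lr_step values are placeholders
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model, optimizer, start_epoch = load_model(
    model, 'multi_gpu_checkpoint.pth',
    optimizer=optimizer, resume=True,
    lr=1e-3, lr_step=[90, 120])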