Temp

[Pytorch] model 불러와서 resume 하는 방법

ju_young 2021. 11. 27. 21:58
728x90

Checkpoint 저장

import torch
import shutil
def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    f_path = checkpoint_dir / 'checkpoint.pt'
    torch.save(state, f_path)
    if is_best:
        best_fpath = best_model_dir / 'best_model.pt'
        shutil.copyfile(f_path, best_fpath)
checkpoint = {
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict()
}
save_ckp(checkpoint, is_best, checkpoint_dir, model_dir)

위 코드에서 보듯이 optimizer와 epoch 정보를 같이 저장해준다.

Checkpoint 불러오기

def load_ckp(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']

model = MyModel(*args, **kwargs)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
ckp_path = "path/to/checkpoint/checkpoint.pt"
model, optimizer, start_epoch = load_ckp(ckp_path, model, optimizer)​
  • 모델을 불러올때는 'state_dict' key를 지정해줘야함
  • optimizer도 load_state_dict를 사용하고 'optimizer' key를 지정해주어 이전 학습 정보를 가져올 수 있음
  • checkpoint 안에 있는 epoch부터 학습을 시작

[ref]

https://medium.com/analytics-vidhya/saving-and-loading-your-model-to-resume-training-in-pytorch-cb687352fa61

728x90