728x90
Checkpoint 저장
import torch
import shutil
def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
f_path = checkpoint_dir / 'checkpoint.pt'
torch.save(state, f_path)
if is_best:
best_fpath = best_model_dir / 'best_model.pt'
shutil.copyfile(f_path, best_fpath)
checkpoint = {
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict()
}
save_ckp(checkpoint, is_best, checkpoint_dir, model_dir)
위 코드에서 보듯이 optimizer와 epoch 정보를 같이 저장해준다.
Checkpoint 불러오기
def load_ckp(checkpoint_fpath, model, optimizer):
checkpoint = torch.load(checkpoint_fpath)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
return model, optimizer, checkpoint['epoch']
model = MyModel(*args, **kwargs)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
ckp_path = "path/to/checkpoint/checkpoint.pt"
model, optimizer, start_epoch = load_ckp(ckp_path, model, optimizer)
- 모델을 불러올때는 'state_dict' key를 지정해줘야함
- optimizer도 load_state_dict를 사용하고 'optimizer' key를 지정해주어 이전 학습 정보를 가져올 수 있음
- checkpoint 안에 있는 epoch부터 학습을 시작
[ref]
728x90
'Temp' 카테고리의 다른 글
[Colab] 연결 끊김 방지 (0) | 2021.11.29 |
---|---|
[Pytorch] Optimizer learning rate 가져오기 (0) | 2021.11.27 |
[parafac] not enough values to unpack (expected 4, got 2) (0) | 2021.11.26 |
[Python] Dictionary Key Name 수정 (0) | 2021.11.26 |
[CV2] ImportError: libGL.so.1: cannot open shared object file: No such file or directory (0) | 2021.11.26 |