Python/PyTorch 공부

[Pytorch] Load state_dict로 인한 out of memory

AI 꿈나무 2021. 12. 14. 01:50
반응형
back = Back().cuda()
state_dict = torch.load(config['back_init_model'])
back.load_state_dict(state_dict['state_dict'])

모델을 GPU에 올린뒤에

load 명령어로 state_dict 파일을 불러오고

모델.load_state_dict로 불러온 state_dict을 적용하면

GPU 메모리가 낭비된다. 즉 CUDA out of memory가 발생할 수 있다는 것.

 

 

load_state_dict이 적용된 모델과 처음에 정의한 모델이 함께 GPU에 올라가서 발생한 문제인 것 같다.

 

먼저 CPU 에 모델을 올리고 나서 load_state_dict으로 state_dict을 적용한 후에 GPU에 올리면 된다.

back = Back().cpu()
state_dict = torch.load(config['back_init_model'], map_location='cpu')
back.load_state_dict(state_dict['state_dict'])
back = back.cuda()

 

출처

https://discuss.pytorch.org/t/load-state-dict-causes-memory-leak/36189

 

Load_state_dict causes memory leak

If you store a state_dict using torch.save, and then load that state_dict (or another), it doesn’t just replace the weights in your current model. It loads the new values into GPU memory and then maybe releases the old GPU memory. If I set my vector leng

discuss.pytorch.org

 

반응형