반응형
back = Back().cuda()
state_dict = torch.load(config['back_init_model'])
back.load_state_dict(state_dict['state_dict'])
모델을 GPU에 올린뒤에
load 명령어로 state_dict 파일을 불러오고
모델.load_state_dict로 불러온 state_dict을 적용하면
GPU 메모리가 낭비된다. 즉 CUDA out of memory가 발생할 수 있다는 것.
load_state_dict이 적용된 모델과 처음에 정의한 모델이 함께 GPU에 올라가서 발생한 문제인 것 같다.
먼저 CPU 에 모델을 올리고 나서 load_state_dict으로 state_dict을 적용한 후에 GPU에 올리면 된다.
back = Back().cpu()
state_dict = torch.load(config['back_init_model'], map_location='cpu')
back.load_state_dict(state_dict['state_dict'])
back = back.cuda()
출처
https://discuss.pytorch.org/t/load-state-dict-causes-memory-leak/36189
반응형