Code/파이썬
[Pytorch] fine-tuning 시, model.load_state_dict() 모델 파라미터 로드 오류
마메프
2023. 1. 12. 17:25
반응형
case 1 : 없는 layer 무시
model.load_state_dict(torch.load(opt.saved_model), strict=False)
case 2 : layer 에서의 size mismatch
#...
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
#...
class Module:
#...
def on_load_checkpoint(self, checkpoint: dict)-> None:
state_dict = checkpoint
model_state_dict = self.state_dict()
is_change = False
for k in state_dict:
if k in model_state_dict:
if state_dict[k].shape != model_state_dict[k].shape:
logger.info(f"Skip loading parameter: {k}, "
f"required shape: {model_state_dict[k].shape}, "
f"loaded shape: {state_dict[k].shape}")
state_dict[k] = model_state_dict[k]
is_change = True
else:
logger.info(f"Dropping parameter {k}")
is_change = True
if is_change:
checkpoint.pop("optimizer_states", None)
#...
이 코드를 pytorch 의 module.py 에 기존에 존재하던 load_state_dict 함수 위에 넣어준다.
(*/lib/python3.7/site-packages/torch/nn/modules/module.py)
그리고 이렇게 사용.
model.on_load_checkpoint(torch.load(opt.saved_model))
반응형