반응형
case 1 : 없는 layer 무시
model.load_state_dict(torch.load(opt.saved_model), strict=False)
case 2 : layer 에서의 size mismatch
#...
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
#...
class Module:
#...
def on_load_checkpoint(self, checkpoint: dict)-> None:
state_dict = checkpoint
model_state_dict = self.state_dict()
is_change = False
for k in state_dict:
if k in model_state_dict:
if state_dict[k].shape != model_state_dict[k].shape:
logger.info(f"Skip loading parameter: {k}, "
f"required shape: {model_state_dict[k].shape}, "
f"loaded shape: {state_dict[k].shape}")
state_dict[k] = model_state_dict[k]
is_change = True
else:
logger.info(f"Dropping parameter {k}")
is_change = True
if is_change:
checkpoint.pop("optimizer_states", None)
#...
이 코드를 pytorch 의 module.py 에 기존에 존재하던 load_state_dict 함수 위에 넣어준다.
(*/lib/python3.7/site-packages/torch/nn/modules/module.py)
그리고 이렇게 사용.
model.on_load_checkpoint(torch.load(opt.saved_model))
반응형
'Code > 파이썬' 카테고리의 다른 글
[Python] Selenium 을 이용한 뉴스 크롤링 해오기 (feat. Beutiful Soup) (0) | 2021.10.22 |
---|---|
[파이썬] 자주쓰이는 파이썬 자료구조 & 메소드 시간복잡도 (0) | 2021.10.04 |
[파이썬]Opencv Mat 를 PIL image 포맷으로 변환하기 및 PIL image -> Opencv Mat (0) | 2021.06.10 |
[Python]아나콘다 가상환경 기본 사용법 (0) | 2021.05.20 |
[Python]Selenium 을 이용한 이미지 크롤링 해오기(+추가 21.07.30) (0) | 2021.05.11 |