# EMA: a moving-average copy of the model is kept only when RANK is -1 or 0
ema = ModelEMA(model) if RANK in [-1, 0] else None

# Resume
start_epoch, best_fitness = 0, 0.0
if pretrained:
    # Optimizer: restore optimizer state (and the best fitness saved with it)
    if ckpt['optimizer'] is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
        best_fitness = ckpt['best_fitness']

    # EMA
    # The checkpoint EMA may have a detection head (model.24.m.*) whose tensor
    # shapes differ from the current model's -- presumably because the number
    # of classes changed. Instead of overwriting those tensors with hard-coded
    # torch.rand values, transfer only the tensors whose name AND shape match
    # and leave every other EMA tensor at its current value.
    # NOTE(review): assumes ModelEMA(model) starts from a copy of `model`, so
    # the untouched head keeps the model's own initialisation -- confirm.
    # All work happens inside the guard, so a checkpoint without an EMA
    # (ckpt['ema'] is None) or a rank without one (ema is None) is skipped
    # instead of failing on ckpt['ema'].state_dict().
    if ema and ckpt.get('ema'):
        ema_sd = ema.ema.state_dict()
        # .float(): the checkpoint EMA may be stored in half precision
        ckpt_sd = ckpt['ema'].float().state_dict()
        matched = {k: v for k, v in ckpt_sd.items()
                   if k in ema_sd and v.shape == ema_sd[k].shape}
        # strict=False: keys removed by the shape filter are simply not loaded
        ema.ema.load_state_dict(matched, strict=False)
        ema.updates = ckpt['updates']
# 默认加载方式 (default loading method -- the unmodified YOLOv5 code):
# EMA: a moving-average copy of the model is kept only when RANK is -1 or 0
ema = ModelEMA(model) if RANK in [-1, 0] else None

# Resume
start_epoch, best_fitness = 0, 0.0
if pretrained:
    # Optimizer: restore optimizer state (and the best fitness saved with it)
    if ckpt['optimizer'] is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
        best_fitness = ckpt['best_fitness']

    # EMA: load the checkpoint's EMA weights verbatim (cast to float32).
    # load_state_dict is strict here, so every key and tensor shape must match
    # the current model; a changed detection head will raise.
    if ema and ckpt.get('ema'):
        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
        ema.updates = ckpt['updates']