# Training loop (excerpt): names such as opt, model, optimizer, scheduler,
# criterion*, ODFA, dataloaders, dataset_sizes, use_gpu, fp16, version,
# warm_up, warm_iteration, since, y_loss, y_err, save_network and draw_curve
# are assumed to be defined earlier in the script.
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train(True)   # Set model to training mode
        else:
            model.train(False)  # Set model to evaluate mode
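        # train(True)/train(False) toggles layers such as Dropout and BatchNorm:
        # in training mode BatchNorm updates its running statistics, in eval mode
        # it uses the stored ones; model.train(False) is equivalent to model.eval().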
        running_loss = 0.0
        running_corrects = 0.0
        # Iterate over data.
        for iter, data in enumerate(dataloaders[phase]):
            # get the inputs
            inputs, labels = data
            now_batch_size, c, h, w = inputs.shape
            if now_batch_size < opt.batchsize:  # skip the last incomplete batch
                continue
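            # Skipping the final incomplete batch keeps BatchNorm statistics and
            # the /now_batch_size loss scaling below consistent; drop_last=True on
            # the DataLoader would drop it at the loader level instead.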
            # wrap them in Variable
            if use_gpu:
                inputs = Variable(inputs.cuda().detach())
                labels = Variable(labels.cuda().detach())
            else:
                inputs, labels = Variable(inputs), Variable(labels)
            # if we use low precision, the inputs also need to be fp16
            #if fp16:
            #    inputs = inputs.half()
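            # Variable has been a no-op wrapper since PyTorch 0.4 (tensors track
            # gradients directly); it is kept for compatibility with the 0.3.x
            # code path handled in the statistics section below.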
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            if phase == 'val':
                with torch.no_grad():
                    outputs = model(inputs)
            else:
                outputs = model(inputs)

            if opt.adv > 0 and iter % opt.aiter == 0:
                inputs_adv = ODFA(model, inputs)
                outputs_adv = model(inputs_adv)
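            # ODFA (Opposite-Direction Feature Attack) crafts adversarial examples
            # by pushing features away from their original direction; every
            # opt.aiter iterations their classification loss is added below with
            # weight opt.adv, training the model to resist such attacks.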
            sm = nn.Softmax(dim=1)
            log_sm = nn.LogSoftmax(dim=1)
            return_feature = opt.arcface or opt.cosface or opt.circle or opt.triplet or opt.contrast or opt.instance or opt.lifted or opt.sphere

            if return_feature:
                logits, ff = outputs
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))
                loss = criterion(logits, labels)
                _, preds = torch.max(logits.data, 1)
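                # ff is L2-normalized to unit length, so the angular-margin losses
                # below (ArcFace, CosFace, Circle, Sphere) operate on the unit
                # hypersphere where only the angle between embeddings matters.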
                if opt.adv > 0 and iter % opt.aiter == 0:
                    logits_adv, _ = outputs_adv
                    loss += opt.adv * criterion(logits_adv, labels)
                if opt.arcface:
                    loss += criterion_arcface(ff, labels) / now_batch_size
                if opt.cosface:
                    loss += criterion_cosface(ff, labels) / now_batch_size
                if opt.circle:
                    loss += criterion_circle(*convert_label_to_similarity(ff, labels)) / now_batch_size
                if opt.triplet:
                    hard_pairs = miner(ff, labels)
                    loss += criterion_triplet(ff, labels, hard_pairs)  # /now_batch_size
                if opt.lifted:
                    loss += criterion_lifted(ff, labels)  # /now_batch_size
                if opt.contrast:
                    loss += criterion_contrast(ff, labels)  # /now_batch_size
                if opt.instance:
                    loss += criterion_instance(ff) / now_batch_size
                if opt.sphere:
                    loss += criterion_sphere(ff, labels) / now_batch_size
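                # Dividing by now_batch_size converts the sum-reduced margin losses
                # to per-sample averages, matching the scale of the mean-reduced
                # cross-entropy; the triplet/lifted/contrast losses presumably
                # already average internally, so their division is commented out.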
            elif opt.PCB:  # PCB
                part = {}
                num_part = 6
                for i in range(num_part):
                    part[i] = outputs[i]
                score = sm(part[0]) + sm(part[1]) + sm(part[2]) + sm(part[3]) + sm(part[4]) + sm(part[5])
                _, preds = torch.max(score.data, 1)
                loss = criterion(part[0], labels)
                for i in range(num_part - 1):
                    loss += criterion(part[i + 1], labels)
            else:  # norm
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)
                if opt.adv > 0 and iter % opt.aiter == 0:
                    loss += opt.adv * criterion(outputs_adv, labels)
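            # Three heads are possible above: metric-learning models return
            # (logits, embedding); PCB (Part-based Convolutional Baseline) returns
            # one logit vector per horizontal body stripe, so prediction sums the
            # six softmax scores and the loss sums six cross-entropy terms; the
            # plain baseline uses a single classification head.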
            del inputs

            # use the extra DG-Market dataset (https://github.com/NVlabs/DG-Net#dg-market)
            if opt.DG and phase == 'train' and epoch > num_epochs * 0.1:
                try:
                    _, batch = next(DGloader_iter)
                except StopIteration:      # loader exhausted, restart it
                    DGloader_iter = enumerate(dataloaders['DG'])
                    _, batch = next(DGloader_iter)
                except UnboundLocalError:  # first iteration, loader not yet created
                    DGloader_iter = enumerate(dataloaders['DG'])
                    _, batch = next(DGloader_iter)
                inputs1, inputs2, _ = batch
                inputs1 = inputs1.cuda().detach()
                inputs2 = inputs2.cuda().detach()
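                # DG-Market images are synthetic, generated by DG-Net; they are
                # (presumably) held back until 10% of the epochs have passed so the
                # classifier is stable before it sees unlabeled generated data.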
                # use the memory-in-vivo loss (https://arxiv.org/abs/1912.11164)
                outputs1 = model(inputs1)
                if return_feature:
                    outputs1, _ = outputs1
                elif opt.PCB:
                    for i in range(num_part):
                        part[i] = outputs1[i]
                    outputs1 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]
                outputs2 = model(inputs2)
                if return_feature:
                    outputs2, _ = outputs2
                elif opt.PCB:
                    for i in range(num_part):
                        part[i] = outputs2[i]
                    outputs2 = part[0] + part[1] + part[2] + part[3] + part[4] + part[5]
                mean_pred = sm(outputs1 + outputs2)
                # reduction='sum' replaces the deprecated size_average=False
                kl_loss = nn.KLDivLoss(reduction='sum')
                reg = (kl_loss(log_sm(outputs2), mean_pred) + kl_loss(log_sm(outputs1), mean_pred)) / 2
                loss += 0.01 * reg
                del inputs1, inputs2
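                # nn.KLDivLoss expects log-probabilities as input and probabilities
                # as target, hence log_sm on each branch and sm on the summed
                # logits; averaging the two KL terms pulls both predictions toward
                # their mixture, a Jensen-Shannon-style consistency regularizer
                # weighted by 0.01.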
            # backward + optimize only if in training phase
            if epoch < opt.warm_epoch and phase == 'train':
                warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                loss = loss * warm_up
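            # Linear warm-up: assuming warm_up is initialized to 0.1 earlier in
            # the script, it grows by 0.9/warm_iteration per batch and saturates
            # at 1.0, damping gradients during the first opt.warm_epoch epochs.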
            if phase == 'train':
                if fp16:  # with apex, backward through the loss scaler
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                optimizer.step()
            # statistics
            if int(version[0]) > 0 or int(version[2]) > 3:  # new PyTorch versions, e.g. 0.4.0 and 1.0.0
                running_loss += loss.item() * now_batch_size
            else:  # old versions, e.g. 0.3.0 and 0.3.1
                running_loss += loss.data[0] * now_batch_size
            del loss
            running_corrects += float(torch.sum(preds == labels.data))
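            # loss is a per-sample mean, so multiplying by now_batch_size
            # accumulates a sum over samples; dividing by dataset_sizes[phase]
            # below then yields a per-sample epoch average (slightly underestimated,
            # since skipped incomplete batches contribute nothing).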
        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects / dataset_sizes[phase]

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        y_loss[phase].append(epoch_loss)
        y_err[phase].append(1.0 - epoch_acc)
        # record the latest weights (note: state_dict() returns references, not a deep copy)
        if phase == 'val':
            last_model_wts = model.state_dict()
            if epoch % 10 == 9:
                save_network(model, epoch)
            draw_curve(epoch)
        if phase == 'train':
            scheduler.step()
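        # Since PyTorch 1.1, the LR scheduler should be stepped after
        # optimizer.step(); here it advances once per epoch, at the end of the
        # training phase.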
    # elapsed time so far, printed after every epoch
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print()

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))

# load the weights recorded after the final validation phase and save them
model.load_state_dict(last_model_wts)
save_network(model, 'last')
return model