def train(model, train_loader, criterion, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Network to optimise. It is put into training mode here,
            because ``validation`` switches it to eval mode and does not
            switch it back.
        train_loader: Iterable yielding ``(image, target)`` batches, where
            ``target`` is the mask image.
        criterion: Loss callable, invoked as
            ``criterion(output, target, 0.2, False)``.
            NOTE(review): the meaning of the extra ``0.2`` / ``False``
            arguments depends on the loss defined elsewhere -- confirm.
        optimizer: Optimizer updating the model's parameters.

    Returns:
        Mean of the per-batch ``loss.item()`` values (NaN, with a NumPy
        RuntimeWarning, if ``train_loader`` yields no batches).
    """
    # Enforce training mode here instead of trusting every caller:
    # validation() calls model.eval() and leaves the model in that state.
    model.train()
    losses = []
    for image, target in train_loader:
        # target is the mask image. NOTE(review): its shape was previously
        # noted as (258, 256, 1) -- possibly meant (256, 256, 1); confirm.
        image, target = image.to(DEVICE), target.float().to(DEVICE)
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        output = model(image)
        # Compute the loss.
        loss = criterion(output, target, 0.2, False)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return np.array(losses).mean()
def np_dice_score(probability, mask, threshold=0.5, eps=0.001):
    """Compute the Dice coefficient between a probability map and a mask.

    Both inputs are flattened and binarised with the same ``threshold``
    (strictly greater-than). The score is computed over all elements at
    once, not averaged per sample.

    Args:
        probability: Array of predicted probabilities (any shape).
        mask: Ground-truth array with the same number of elements as
            ``probability``.
        threshold: Values strictly above this count as foreground in both
            arrays.
        eps: Constant added to the denominator to avoid division by zero.
            It is not added to the numerator, so when both inputs are
            entirely background the score is 0.0.

    Returns:
        ``2 * |P & T| / (|P| + |T| + eps)``.
    """
    pred = probability.reshape(-1) > threshold
    truth = mask.reshape(-1) > threshold
    total = pred.sum() + truth.sum()
    overlap = (pred & truth).sum()
    return 2 * overlap / (total + eps)
def validation(model, val_loader, criterion, threshold=0.5):
    """Evaluate ``model`` on ``val_loader`` and return the global Dice score.

    The model is switched to eval mode and is left in that mode on return.
    Sigmoid probabilities and targets for every batch are moved to the CPU,
    concatenated, and scored together with ``np_dice_score``.

    Args:
        model: Network producing logits for each image batch.
        val_loader: Iterable yielding ``(image, target)`` batches.
        criterion: Accepted for call-site compatibility; not used here.
        threshold: Binarisation threshold forwarded to ``np_dice_score``.

    Returns:
        The Dice score computed over the whole validation set.
    """
    prob_batches = []
    mask_batches = []
    model.eval()
    with torch.no_grad():
        for batch_images, batch_targets in val_loader:
            batch_images = batch_images.to(DEVICE)
            batch_targets = batch_targets.float().to(DEVICE)
            logits = model(batch_images)
            prob_batches.append(logits.sigmoid().data.cpu().numpy())
            mask_batches.append(batch_targets.data.cpu().numpy())
    all_probs = np.concatenate(prob_batches)
    all_masks = np.concatenate(mask_batches)
    return np_dice_score(all_probs, all_masks, threshold=threshold)
# Epoch loop: train, validate, step the LR scheduler, and checkpoint the
# model whenever the validation Dice improves.
best_dice = 0
for epoch in range(1, EPOCHES + 1):
    start_time = time.time()
    model.train()
    train_loss = train(model, train_loader, loss_fn, optimizer)
    val_dice = validation(model, val_loader, loss_fn, 0.5)
    # NOTE(review): the scheduler is stepped with a higher-is-better metric;
    # if lr_step is ReduceLROnPlateau it must use mode='max' -- confirm.
    lr_step.step(val_dice)
    if val_dice > best_dice:
        best_dice = val_dice
        # Previously a hard-coded i=5 always wrote ./fold_5.pth while the log
        # line below reported fold_{fold_idx}; use fold_idx for both so the
        # saved file matches the message.
        torch.save(model.state_dict(), './fold_{}.pth'.format(fold_idx))
        print("best_savefold_{}.pth ".format(fold_idx))
    # start_time was recorded but never reported; include epoch wall time.
    print("epoch:", epoch, "train_loss:", train_loss, "val_dice:", val_dice,
          "best_dice:", best_dice,
          "time(s):", round(time.time() - start_time, 1))