# PyTorch: notes on a typical train / validation loop

def train(model, train_loader, criterion, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network being trained; assumed to already live on DEVICE.
        train_loader: iterable of (image, target) batches; target is the
            mask image — presumably (256, 256, 1) per the original note
            (the original comment said 258x256, likely a typo) — TODO confirm.
        criterion: loss callable invoked as
            ``criterion(output, target, 0.2, False)``; the two trailing
            arguments are project-specific — verify their meaning against
            the loss definition.
        optimizer: torch optimizer over ``model``'s parameters.

    Returns:
        float: mean of the per-batch loss values, or 0.0 for an empty
        loader (the original returned nan with a RuntimeWarning).
    """
    losses = []
    for image, target in train_loader:
        image = image.to(DEVICE)
        target = target.float().to(DEVICE)
        # Reset accumulated gradients before each batch.
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, target, 0.2, False)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    # Guard against an empty loader: np.mean([]) is nan + warning.
    return float(np.mean(losses)) if losses else 0.0

def np_dice_score(probability, mask, threshold=0.5):
    """Dice coefficient between a thresholded probability map and a mask.

    Both inputs are flattened and binarized at ``threshold``; the small
    0.001 in the denominator keeps the all-empty case at 0 instead of 0/0.
    """
    pred = probability.reshape(-1) > threshold
    true = mask.reshape(-1) > threshold
    intersection = (pred * true).sum()
    union = pred.sum() + true.sum()
    return 2 * intersection / (union + 0.001)

 

 

def validation(model, val_loader, criterion, threshold=0.5):
    """Score `model` on `val_loader` with the numpy dice metric.

    Note: `criterion` is accepted for signature compatibility with the
    training driver but is not used here.

    Returns the dice score over all concatenated validation batches.
    """
    probabilities = []
    masks = []
    model.eval()
    with torch.no_grad():
        for image, target in val_loader:
            image = image.to(DEVICE)
            target = target.float().to(DEVICE)
            logits = model(image)
            # Sigmoid turns logits into per-pixel probabilities.
            probabilities.append(logits.sigmoid().data.cpu().numpy())
            masks.append(target.data.cpu().numpy())
    all_probs = np.concatenate(probabilities)
    all_masks = np.concatenate(masks)
    return np_dice_score(all_probs, all_masks, threshold=threshold)

 

    # NOTE(review): fragment of a training driver — the enclosing function
    # header and the definitions of EPOCHES, model, train_loader, val_loader,
    # loss_fn, optimizer, lr_step, and fold_idx are not visible in this chunk.
    best_dice = 0
    
    for epoch in range(1, EPOCHES+1):
        # start_time is recorded but never used afterwards in this fragment.
        start_time = time.time()
        model.train()
        train_loss = train(model, train_loader, loss_fn, optimizer)
        val_dice = validation(model, val_loader, loss_fn,0.5)
        # presumably a ReduceLROnPlateau stepped on the dice metric —
        # confirm it was created with mode='max', since higher dice is better.
        lr_step.step(val_dice)
        
        # Keep only the checkpoint with the best validation dice so far.
        if val_dice > best_dice:
            best_dice = val_dice
            # BUG(review): hard-codes fold index 5 for the checkpoint filename,
            # while the print below uses fold_idx — these should agree.
            i=5
            torch.save(model.state_dict(), './fold_{}.pth'.format(i))
            print("best_savefold_{}.pth ".format(fold_idx))
        print("epoch:",epoch, "train_loss:",train_loss, "val_dice:",val_dice ,"best_dice:",best_dice)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CVer儿

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值