# Semantic segmentation evaluation metrics (语义分割评价指标)
import torch
import torch.optim
from dataset import SelfDataSet
import os
import torch.nn as nn
import sys
from model import UNet
from torch.utils.data import Dataset
from torch import optim, utils
import time
from tqdm import tqdm
from torchvision.utils import save_image
import numpy as np
def calculate_metrics(pred, target):
    """Compute IoU, F1, recall, accuracy and precision for one batch.

    The formulas multiply the predicted class map by ``target`` and by
    ``1 - target``, so they only count correctly when both hold 0/1 values
    (binary foreground/background segmentation).
    NOTE(review): assumes ``target`` is a 0/1 mask -- confirm against the dataset.

    Args:
        pred: Raw network output with the class scores on dimension 1;
            the predicted class is taken with ``argmax`` over that dimension.
        target: Ground-truth mask with the same shape as the argmax result.
            A mask that still carries a singleton channel axis at dimension 1
            is squeezed so that it lines up element-wise with the prediction.

    Returns:
        A tuple ``(iou, f1_score, recall, acc, precision)`` of scalar tensors.
    """
    eps = 1e-6  # keeps every ratio finite when a denominator would be zero

    labels = torch.argmax(pred, dim=1)

    # Without this, an (N, 1, H, W) target would broadcast against the
    # (N, H, W) class map into (N, N, H, W) and silently inflate every sum.
    if target.dim() == labels.dim() + 1 and target.size(1) == 1:
        target = target.squeeze(1)

    # The intersection of the two masks is also the true-positive count.
    true_positives = torch.sum(labels * target)
    false_positives = torch.sum(labels * (1 - target))
    false_negatives = torch.sum((1 - labels) * target)

    union = torch.sum(labels) + torch.sum(target) - true_positives
    iou = (true_positives + eps) / (union + eps)

    precision = true_positives / (true_positives + false_positives + eps)
    recall = true_positives / (true_positives + false_negatives + eps)
    f1_score = 2 * (precision * recall) / (precision + recall + eps)

    acc = torch.sum(labels == target) / torch.numel(labels)
    return iou, f1_score, recall, acc, precision
def Train_Unet(net, device, train_data_path, batch_size=3, epochs=40, lr=0.0001):
    """Train ``net`` with Adam and cross-entropy, saving a preview image per batch.

    For every batch the function runs a forward/backward pass, updates the
    progress-bar text with the loss and the values returned by
    ``calculate_metrics``, and writes a two-image grid (ground truth of the
    first sample, predicted class map of the first sample) to
    ``train_image/<counter>.png``.

    Args:
        net: Model to train; called as ``net(image)`` and switched to
            training mode at the start of each epoch.
        device: Device that each batch is moved to before the forward pass.
        train_data_path: Path handed to ``SelfDataSet``.
        batch_size: Batch size for the ``DataLoader``.
        epochs: Number of passes over the data.
        lr: Learning rate for the Adam optimiser.

    Returns:
        None. The model is updated in place; no checkpoint is written here.
    """
    train_dataset = SelfDataSet(train_data_path)
    train_loader = utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Pass ``lr`` explicitly; otherwise Adam silently runs at its own default
    # rate and the function argument has no effect.
    opt = optim.Adam(net.parameters(), lr=lr)
    loss_fun = nn.CrossEntropyLoss()

    save_path = 'train_image'
    # save_image does not create missing directories.
    os.makedirs(save_path, exist_ok=True)

    i = 0  # running index used as the preview file name
    for epoch in range(epochs):
        net.train()
        train_bar = tqdm(train_loader, file=sys.stdout)
        for image, label in train_bar:
            image = image.to(device=device)
            label = label.to(device=device)

            pred = net(image)
            loss = loss_fun(pred, label.long())

            opt.zero_grad()
            loss.backward()
            opt.step()

            iou, f1_score, recall, acc, precision = calculate_metrics(pred, label)
            train_bar.desc = f'[epoch {epoch + 1}/{epochs}] loss: {loss.item():.4f}, IoU: {iou.item():.4f}, F1: {f1_score.item():.4f}, Recall: {recall.item():.4f}, Accuracy: {acc.item():.4f}, Precision: {precision.item():.4f}'

            # NOTE(review): the original indentation was lost; the preview is
            # assumed to be written once per batch -- confirm.
            # Cast both maps to float so they share a dtype when stacked and
            # so save_image receives a floating-point tensor.
            _segment_image = torch.unsqueeze(label[0], 0).float() * 255
            _out_image = torch.argmax(pred[0], dim=0).unsqueeze(0).float() * 255
            img = torch.stack([_segment_image, _out_image], dim=0)
            save_image(img, f'{save_path}/{i}.png')
            i += 1
if __name__ == '__main__':
    # Script entry point: build the model, move it to the selected device and
    # start training on the images under ./data/train/image.
    train_data_path = r"./data/train/image"

    # Use CUDA when PyTorch reports it as available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): 3 and 2 are presumably input channels and class count --
    # confirm against the UNet definition in model.py.
    net = UNet(3, 2, bilinear=False)
    net.to(device=device)

    Train_Unet(
        net=net,
        device=device,
        train_data_path=train_data_path,
        batch_size=8,
        epochs=40,
    )