在语义分割过程中,进行二分类和多分类时,如何使用dice系数以及相应的代码
1. dice系数的基础知识
参考链接:https://www.aiuai.cn/aifarm1159.html
2. 语义分割
参考链接:https://editor.csdn.net/md?not_checkout=1&spm=1000.2115.3001.4503&articleId=127461922
语义分割的目标是:将一张RGB图像或是灰度图作为输入,输出的是分割图,其中每一个像素包含了其类别的标签。
从图中可以发现,每一个类用相同的像素值表示,因此在设计多分类的dice系数时,可以采用对每个类(像素值),进行一次dice计算
3. 代码
二分类时,采用 Sigmoid 与 BCELoss 的组合
# Compute the Dice coefficient (binary segmentation).
def meandice(pred, label):
    """Return the Dice coefficient of a binarized prediction vs. its mask.

    Args:
        pred:  prediction tensor of shape (B, ...) with values in {0, 1}
               (e.g. after thresholding sigmoid probabilities at 0.5).
        label: ground-truth mask tensor, same shape as ``pred``, values in {0, 1}.

    Returns:
        Scalar tensor: (2*|P∩L| + eps) / (|P| + |L| + eps), aggregated over
        the whole batch. ``eps`` keeps the ratio defined (and equal to 1)
        when both prediction and label are empty.
    """
    smooth = 1e-6  # numerical-stability / empty-mask term
    # Flatten everything after the batch dimension; the sums below then
    # aggregate over the entire batch at once.
    pred_flat = pred.contiguous().view(pred.shape[0], -1)
    label_flat = label.contiguous().view(label.shape[0], -1)
    intersection = (pred_flat * label_flat).sum()
    return (2. * intersection + smooth) / (pred_flat.sum() + label_flat.sum() + smooth)
# Loss: BCELoss expects probabilities, so the model output must go through
# a Sigmoid first. (nn.BCEWithLogitsLoss on the raw logits would be the
# more numerically stable alternative.)
criterion = torch.nn.BCELoss()  # define loss and cross entropy

# Best-checkpoint bookkeeping.
bestdice = 0
bestepoch = 0

sigmoid = nn.Sigmoid()  # hoisted: no need to rebuild the module every batch
for epoch in range(1000):
    # ---- training ----
    # scheduler.step(epoch)
    cnt = 0
    losssum = 0
    model.train()
    for image, label in tqdm(dataloader):
        optimizer.zero_grad()
        image, label = image.to(device), label.to(device)
        # Collapse every non-target class to 0 so the mask is strictly {0, 1}.
        label[label != 1] = 0
        out = sigmoid(model(image))
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        cnt += 1
        losssum += loss.item()  # .item() detaches: avoids retaining the graph
    print('Epoch {0},train_loss {1}'.format(epoch, losssum / cnt))

    # ---- validation ----
    with torch.no_grad():
        model.eval()
        print("validating....")
        for image_t, label_t in tqdm(dataloader_test):
            image_t, label_t = image_t.to(device), label_t.to(device)
            label_t[label_t != 1] = 0  # keep only the foreground class
            # BUG FIX: the original thresholded the RAW logits at 0.5, while
            # training thresholds probabilities — apply the sigmoid here too.
            out_t = sigmoid(model(image_t))
            # Threshold at 0.5 to obtain a hard {0, 1} mask; >= also covers
            # the value 0.5 exactly, which the original's >/< pair missed.
            pred_t = (out_t >= 0.5).float()
            rawdice = meandice(pred_t, label_t)
多分类时,采用 Softmax 与 CrossEntropyLoss(CELoss)的组合
# Compute the Dice coefficient (multi-class segmentation).
def meandice(pred, label, num_classes=4):
    """Return the mean Dice coefficient over foreground classes 1..num_classes.

    For each class id ``i`` the prediction and label maps are binarized
    (``pred == i`` / ``label == i``) and a per-class Dice score is computed;
    the scores are averaged over the ``num_classes`` foreground classes.

    Args:
        pred:  predicted label map of integer class ids (e.g. argmax output),
               shape (B, ...).
        label: ground-truth label map, same shape as ``pred``.
        num_classes: number of foreground classes to evaluate. Defaults to 4,
            preserving the original hard-coded ``range(1, 5)``.
            NOTE(review): the training code below says labels take values
            1..5, so callers there probably want ``num_classes=5`` — confirm
            against the dataset.

    Returns:
        Scalar tensor: mean per-class Dice. A class absent from both maps
        contributes a score of 1 (smooth/smooth).
    """
    sumdice = 0
    smooth = 1e-6  # numerical-stability / empty-class term
    for i in range(1, num_classes + 1):
        pred_bin = ((pred == i) * 1).contiguous().view(pred.shape[0], -1)
        label_bin = ((label == i) * 1).contiguous().view(label.shape[0], -1)
        intersection = (pred_bin * label_bin).sum()
        sumdice += (2. * intersection + smooth) / (pred_bin.sum() + label_bin.sum() + smooth)
    return sumdice / num_classes
for epoch in range(1000):
    # ---- training ----
    # scheduler.step(epoch)
    cnt = 0
    losssum = 0
    model.train()
    for image, label in tqdm(dataloader):
        optimizer.zero_grad()
        # CrossEntropyLoss expects integer class indices, hence .long().
        image, label = image.to(device), label.to(device).long()
        out = model(image)  # [B, 6, 256, 256]: 5 foreground classes + background
        # BUG FIX: nn.CrossEntropyLoss applies log-softmax internally, so the
        # RAW logits must be passed directly. The original applied nn.Softmax
        # first, which double-normalizes and degrades the gradients.
        loss = criterion(out, label)  # label [B, 256, 256], out [B, 6, 256, 256]
        loss.backward()
        optimizer.step()
        cnt += 1
        losssum += loss.item()  # .item() detaches: avoids retaining the graph
    print('Epoch {0},train_loss {1}'.format(epoch, losssum / cnt))
    # '''
    dicecat = []
    resultcat = []
    imgcat = []
    # ---- validation ----
    with torch.no_grad():
        model.eval()
        print("validating....")
        for image_t, label_t in tqdm(dataloader_test):
            image_t, label_t = image_t.to(device), label_t.to(device)
            out_t = model(image_t)  # [B, 6, 256, 256]
            # argmax over the channel dim yields a [B, 256, 256] map of class
            # ids 0..5; softmax is monotone, so it is unnecessary before argmax.
            prediction = torch.argmax(out_t, dim=1)
            rawdice = meandice(prediction, label_t)  # labels take values 1..5