一共有data.py、main.py、model.py三个文件
参数修改:
flag_multi_class=True
num_class=类数
一、data.py
1、adjustData函数改为:
def adjustData(img,mask,flag_multi_class,num_class):
if(flag_multi_class):#多类情况
img = img / 255
mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0]
new_mask = np.zeros(mask.shape + (num_class,))
new_mask[mask == 0,0] = 1
new_mask[mask == 50,1] = 1
new_mask[mask == 150,2] = 1
new_mask[mask == 255,3] = 1
mask = new_mask
elif(np.max(img) > 1):
img = img / 255
mask = mask /255
mask[mask > 0.5] = 1
mask[mask <= 0.5] = 0
return (img,mask)
2、未完待续
二、model.py
1、修改如下:
其中最后一行因为所做的是四分类,所以第一个数字是4
2、修改如下:
loss也可以是自己定义的损失函数
三、main.py
1、未完待续
总结: