U-net网络用于多分类——坑点

一共有data.py、main.py、model.py三个文件
参数修改:
flag_multi_class=True
num_class=类数
一、data.py
1、adjustData函数改为:

def adjustData(img,mask,flag_multi_class,num_class):
    if(flag_multi_class):#多类情况
        img = img / 255
        mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0]
        new_mask = np.zeros(mask.shape + (num_class,))
        new_mask[mask == 0,0] = 1
        new_mask[mask == 50,1] = 1
        new_mask[mask == 150,2] = 1
        new_mask[mask == 255,3] = 1
        mask = new_mask
    elif(np.max(img) > 1):
        img = img / 255
        mask = mask /255
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0
    return (img,mask)

2、未完待续
二、model.py
1、修改如下:
在这里插入图片描述
其中最后一行因为所做的是四分类,所以第一个数字是4
2、修改如下:
在这里插入图片描述
loss也可以是自己定义的损失函数
三、main.py
1、未完待续

总结:
在这里插入图片描述

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值