【Keras】基于SegNet和U-Net的遥感图像语义分割(二)

训练U_net模型:unet_train.py

代码:
https://github.com/fuyou123/Segmentation_Unet
在这里插入图片描述

U-Net网络训练

1. args = args_parse()

def args_parse():
    # construct the argument parse and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-d", "--data", default="./unet_train/road/",
                    help="training data's path")
    ap.add_argument("-m", "--model", default="Trained_Unet_Model.h5",
                    help="path to output model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output accuracy/loss plot")
    args = vars(ap.parse_args()) 
    return args
注意:对于命令行参数的设定,参考Argparse的使用。由于我的调试过程是在windows环境下,所以路径设为了默认值,便于操作。

2. model = unet() #模型初始化

调试可能出现的问题:ValueError: Negative dimension size caused by subtracting 2 from 1 for 'max_pooling2d_2/Max...
解决:https://blog.csdn.net/weixin_43723625/article/details/104918997
def unet():
    inputs = Input((3, img_w, img_h))

    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)

    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)

    conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)
    #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

在这里插入图片描述

3. get_train_val()

def get_train_val(val_rate = 0.25):
    train_url = []    
    train_set = []    # 训练数据集
    val_set  = []     # 测试训练集
    # 将待训练的图片路径存入列表
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)   # 将序列的所有元素随机排序
    total_num = len(train_url)
    # 将数据集分为2部分,3/4用于训练,1/4用于检验
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i]) 
        else:
            train_set.append(train_url[i])
    return train_set,val_set

在这里插入图片描述
由于在训练过程中,为了检验训练中的效果,一般是从原数据集中取出一小部分数据用于检验训练的效果。而不是等全部训练结束后,再检验训练模型。

4. model.fit_generator()

# 使用 Python 生成器逐批生成的数据,按批次训练模型
 H = model.fit_generator(generator=generateData(BS,train_set),steps_per_epoch=train_numb//BS,epochs=EPOCHS,verbose=1,  
                    validation_data=generateValidData(BS,val_set),validation_steps=valid_numb//BS,callbacks=callable,max_q_size=1)  

参数:
(1) generator=generateData(BS,train_set) # BS=16
在这里插入图片描述
定义训练时,选取数据的方式是,将数据集中每16个为一组进行训练
(2) steps_per_epoch=train_numb//BS # 整数,表示一次epoch需要训练的组数
(3) epochs=EPOCHS # 迭代次数

(4) validation_data=generateValidData(BS,val_set) # BS=16 , 定义验证集的生成器
(5) validation_steps=valid_numb//BS # 整数,表示一次epoch需要训练的组数

结果:
(调试过程,可以将数据集减少些,同时减少迭代的次数)
在这里插入图片描述
在这里插入图片描述

  • 3
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 20
    评论
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

代码小白的成长

计算机网络PPT下载

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值