初学UNET Day1
因为学习方面的需要,开始学习Unet网络,因为自身基础较差,能卡懂一点十一点
基本概念
个人感觉新手入门的话如果对于基本盖帘不理解,及时各位大牛将原理讲的有多简单易懂,自己还是简单不了,因此,我认为要学习神经网络,首先要了解基本概念
参考文献:卷积
卷积:卷积的概念是计算连续变化的事物和事物本身的变化对于总体的影响…(因为后来觉得目前是初步学习UNET,对于这种只要知道作用就行了,因此打算之后再去细究)
Unet网络的工作原理:
在查阅过一定的播客以后,得知Unet网络可以解释为一个编码-解码器,Unet网络的大致工作原理是通过左半边对于输入图像的进行卷积与池化进行下采样,通过卷积,图像可以缩小其大小,此后,对于图像依次进行反卷积,通道拼接、上采样进行特征提取,恢复至原图像大小。在此过程中,也完成了相应图像的分类。
Unet代码解析
对于Unet网络代码的剖析,我参照了许多代码,代码所完成的功能都大同小异,但是我还是参照了“研志必有功”、“忽逢桃花林”两位博主的播客,感觉两位博主的播客比较完整,容易理解一点
import torch
import torch.nn as nn
import torch.nn.functional as F
class double_conv2d_bn(nn.Module):#此处是卷积类
def __init__(self,in_channels,out_channels,kernel_size=3,strides=1,padding=1):#这部分是对于卷积的基本值的设定,比如输入与输出图片的通道数,如RGB为3,卷积核大小,移步大小以及补充宽度大小等。
super(double_conv2d_bn,self).__init__()#此处的继承超类应该是指预训练模型??
self.conv1 = nn.Conv2d(in_channels,out_channels,#对于卷积核函数的定义域初始化
kernel_size=kernel_size,
stride = strides,padding=padding,bias=True)
self.conv2 = nn.Conv2d(out_channels,out_channels,#对于卷积核函数的定义域初始化
kernel_size = kernel_size,
stride = strides,padding=padding,bias=True)
self.bn1 = nn.BatchNorm2d(out_channels)#我查了一下BatchNorm2d函数的使用,这是对数据的进行归一化处理,能够防止数据过大而倒数卷积时出现问题
self.bn2 = nn.BatchNorm2d(out_channels)#参数都为out——channel,因此可以看出这个函数是用于解码时的图片分类用的
def forward(self,x):#这个函数应该就是此卷积类的实现函数
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
return out
class deconv2d_bn(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size=2,strides=2):#这也是一个卷积,在参考了该博主的播客以后,此处应该是用于上采样的卷积设定
super(deconv2d_bn,self).__init__()
self.conv1 = nn.ConvTranspose2d(in_channels,out_channels,
kernel_size = kernel_size,
stride = strides,bias=True)
self.bn1 = nn.BatchNorm2d(out_channels)
def forward(self,x):#此反卷积的实现函数
out = F.relu(self.bn1(self.conv1(x)))
return out
class Unet(nn.Module):#这部分应该是Unet网络的整体框架
def __init__(self):
super(Unet,self).__init__()
self.layer1_conv = double_conv2d_bn(1,8)
self.layer2_conv = double_conv2d_bn(8,16)
self.layer3_conv = double_conv2d_bn(16,32)
self.layer4_conv = double_conv2d_bn(32,64)
self.layer5_conv = double_conv2d_bn(64,128)#此处应该是不断改变卷积核的通道数与数量,从而控制对图片的转换
self.layer6_conv = double_conv2d_bn(128,64)
self.layer7_conv = double_conv2d_bn(64,32)
self.layer8_conv = double_conv2d_bn(32,16)
self.layer9_conv = double_conv2d_bn(16,8)
self.layer10_conv = nn.Conv2d(8,1,kernel_size=3,
stride=1,padding=1,bias=True)#此处是对于图片的卷积过程,但至于为什么是1-128再到8,目前还不太理解,还在看
self.deconv1 = deconv2d_bn(128,64)
self.deconv2 = deconv2d_bn(64,32)
self.deconv3 = deconv2d_bn(32,16)
self.deconv4 = deconv2d_bn(16,8)#这是反卷积的函数
self.sigmoid = nn.Sigmoid()#次函数应该是用于控制神经网络的输出取值为(0,1)之间
def forward(self,x):#这里是对于图片不断池化与卷积的函数
conv1 = self.layer1_conv(x)
pool1 = F.max_pool2d(conv1,2)
conv2 = self.layer2_conv(pool1)
pool2 = F.max_pool2d(conv2,2)
conv3 = self.layer3_conv(pool2)
pool3 = F.max_pool2d(conv3,2)
conv4 = self.layer4_conv(pool3)
pool4 = F.max_pool2d(conv4,2)
conv5 = self.layer5_conv(pool4)#这里为解码器的池化与卷积部分
convt1 = self.deconv1(conv5)#这里应该为解码器部分,对于之前已经池化卷积提取出的特征进行反卷积并且与之前的卷积图片进行通道的凭借再通过上采样提取其特征
concat1 = torch.cat([convt1,conv4],dim=1)
conv6 = self.layer6_conv(concat1)
convt2 = self.deconv2(conv6)
concat2 = torch.cat([convt2,conv3],dim=1)
conv7 = self.layer7_conv(concat2)
convt3 = self.deconv3(conv7)
concat3 = torch.cat([convt3,conv2],dim=1)
conv8 = self.layer8_conv(concat3)
convt4 = self.deconv4(conv8)
concat4 = torch.cat([convt4,conv1],dim=1)
conv9 = self.layer9_conv(concat4)
outp = self.layer10_conv(conv9)
outp = self.sigmoid(outp)#控制最后的输出为(0,1)之间的函数。
return outp
model = Unet()
inp = torch.rand(10,1,224,224)
outp = model(inp)
print(outp.shape)
==> torch.Size([10, 1, 224, 224])
原文链接:https://blog.csdn.net/qq_34107425/article/details/110184747
看完这位大佬的Unet网络基本模型以后,我还参照了这一位博主的“研志必有功”这位博主的播客,发现Unet网络构建阶段真的是大同小异,但是目前对于除网络方面的地方还在学习与研究。
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
import numpy as np
import cv2
import itertools
import glob
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K
class UNet():#Unet网络的构建,大家都知道
def __init__(self,
input_width,
input_height,
num_classes,
train_class,
train_images,
train_instances,
val_images,
val_instances,
epochs,
lr,
lr_decay,
batch_size,
model_path,
save_path,
train_mode
):
self.input_width=input_width
self.input_height=input_height
self.num_classes=num_classes
self.train_class=train_class
self.train_images=train_images
self.train_instances=train_instances
self.val_images=val_images
self.val_instances=val_instances
self.epochs=epochs
self.lr=lr
self.lr_decay=lr_decay
self.batch_size=batch_size
self.model_path=model_path
self.save_path=save_path
self.train_mode=train_mode
#--------------------------------------------------------------定义U—net网络结构
def leftNetwork(self, inputs): # U-net网络左侧下采样结构
x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inputs)
o_1 = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(o_1)
x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
o_2 = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(o_2)
x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
o_3 = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(o_3)
x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
o_4 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(o_4)
x = layers.Conv2D(512, (3, 3), padding='same', activation='relu')(x)
o_5 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')(x)
print(o_1, o_2, o_3, o_4, o_5)
return [o_1, o_2, o_3, o_4, o_5]
def rightNetwork(self, inputs, num_classes, activation): # U-net网络右侧上采样结构
c_1, c_2, c_3, c_4, c_5 = inputs
x = layers.UpSampling2D((2, 2))(c_5)
print('1', x)
x = layers.concatenate([c_4, x], axis=3)
x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
x = layers.UpSampling2D((2, 2))(x)
print('2', x)
x = layers.concatenate([c_3, x], axis=3)
x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
x = layers.UpSampling2D((2, 2))(x)
print('3', x)
x = layers.concatenate([c_2, x], axis=3)
x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
x = layers.UpSampling2D((2, 2))(x)
print('4', x)
x = layers.concatenate([c_1, x], axis=3)
x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
x = layers.Conv2D(num_classes, (1, 1), strides=(1, 1), padding='same')(x)
x= layers.Reshape([self.input_height, self.input_width])(x)
x = layers.Activation(activation)(x)
return (x)
def U_net(self, inputs, num_classes, activation): # U-net网络结构
leftout = self.leftNetwork(inputs)
outputs = self.rightNetwork(leftout, num_classes, activation)
return outputs
#---------------------------------------------------------------
def build_mode(self):#定义建立结构的方法
inputs = keras.Input(shape=[self.input_height,self.input_width,3])
outputs= self.U_net(inputs,num_classes=self.num_classes,activation='sigmoid')
model =keras.Model(inputs=inputs , outputs=outputs)
return model
def dataGenerator(self,mode):#定义 数据生成器
zeroMat=np.zeros(shape=[self.input_height,self.input_width])
if mode =='training':#训练的数据
images = glob.glob(self.train_images+'/*.jpg')
images.sort()
instances= glob.glob(self.train_instances +'/*.png')
instances.sort()
zipped = itertools.cycle(zip(images,instances))
while True :
x_train=[]
y_train=[]
for _ in range(self.batch_size):
img,seg = next(zipped)
img = cv2.imread(img,1)/255
#----------------------------------------------------------------------------------------改变的地方
seg = cv2.imread(seg, 0)
if (self.train_class):
seg = np.where(seg == self.train_class, 1, 0)
# ----------------------------------------------------------------------------------------
# seg = keras.utils.to_categorical(seg,num_classes=self.num_classes)
x_train.append(img)
y_train.append(seg)
yield np.array(x_train),np.array(y_train)
if mode == 'validation':#测试的数据
images = glob.glob(self.val_images + '/*.jpg')#17年的数据用Jpg存放
images.sort()
instances = glob.glob(self.val_instances + '/*.png')#标签用PNG存放
instances.sort()
zipped = itertools.cycle(zip(images,instances))
while True:
x_eval = []
y_eval = []
img,seg = next(zipped)
img = cv2.resize(cv2.imread(img, 1), (self.input_width, self.input_height))/255
#----------------------------------------------------------------------------------------
seg = cv2.imread(seg, 0)
if (self.train_class):
seg = np.where(seg == self.train_class, 1, 0)
# ----------------------------------------------------------------------------------------
# seg = keras.utils.to_categorical(seg,num_classes=self.num_classes)
x_eval.append(img)
y_eval.append(seg)
yield np.array(x_eval), np.array(y_eval)
def multi_category_focal_loss(self,y_true, y_pred):
epsilon = 1.e-7
gamma = 2.0
# alpha = tf.constant([[2],[1],[1],[1],[1]], dtype=tf.float32)
alpha = tf.constant([[1], [1], [1], [1], [1]], dtype=tf.float32)
y_true = tf.cast(y_true, tf.float32)
y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
y_t = tf.multiply(y_true, y_pred) + tf.multiply(1 - y_true, 1 - y_pred)
ce = -K.log(y_t)
weight = tf.pow(tf.subtract(1., y_t), gamma)
fl = tf.matmul(tf.multiply(weight, ce), alpha)
loss = tf.reduce_mean(fl)
return loss
def focal_loss(self,y_true, y_pred): # 定义损失函数
gamma = 1.5
alpha = 0.9
pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
pt_1 = K.clip(pt_1, 1e-3, .999)
pt_0 = K.clip(pt_0, 1e-3, .999)
return -tf.reduce_mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)) - tf.reduce_mean(
(1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0))
def train(self):#定义训练过程
G_train =self.dataGenerator(mode='training')
G_eval =self.dataGenerator(mode='validation')
if (self.train_mode):
model=keras.models.load_model(self.model_path,custom_objects={'focal_loss': self.focal_loss})
else:
model =self.build_mode()#实例化
model.summary()
model.compile(
optimizer=keras.optimizers.Adam(self.lr,self.lr_decay),
loss ='binary_crossentropy',#构造损失函数
metrics=['binary_accuracy', 'Recall','AUC']#构造评价函数
)
checkpoint = keras.callbacks.ModelCheckpoint(self.save_path, monitor='val_Recall', verbose=1,
save_best_only=True, mode='max')
callbacks = [checkpoint]
model.fit_generator(G_train,2000,validation_data=G_eval,validation_steps=30,epochs=self.epochs,callbacks=callbacks)
model.save(self.save_path)#保存模型
def modelPred(self):#模型预测函数
model = keras.models.load_model(self.model_path,custom_objects={'multi_category_focal_loss1': self.multi_category_focal_loss})
model.summary()
images = glob.glob(self.val_images + '/*.jpg')#17年的数据用Jpg格式存放
images.sort()
instances = glob.glob(self.val_instances + '/*.png')#标签用tif存放
instances.sort()
zipped = itertools.cycle(zip(images,instances))
for _ in range(10):
img,seg = next(zipped)
img = cv2.resize(cv2.imread(img, -1), (self.input_width, self.input_height))/255
seg = cv2.imread(seg, 0)
x1_eval=np.expand_dims(img,0)
pred=tf.squeeze(tf.argmax(model.predict(x1_eval),axis=-1))
plt.subplot(121)
plt.title("pred")
plt.imshow(pred)
plt.subplot(122)
plt.title("pred")
plt.imshow(seg)
plt.show()
if __name__ == '__main__':
unet=UNet(#开始模型的实例化,每个类别训练一个网络
input_width=256,#图片resize成这个大小
input_height=256,
num_classes=1,#检测类别
train_class=4,#训练第几个类别
train_images=r'F:\BaiduNetdiskDownload\remoteSensing\U_net\train_x',#训练数据存放的地方
train_instances=r'F:\BaiduNetdiskDownload\remoteSensing\U_net\train_y',
val_images=r'F:\BaiduNetdiskDownload\remoteSensing\U_net\test_x',#测试数据存放的地方
val_instances=r'F:\BaiduNetdiskDownload\remoteSensing\U_net\test_y',
epochs=200,
lr=0.0001,
lr_decay=0.000001,
batch_size=4,
model_path=r'F:\BaiduNetdiskDownload\remoteSensing\U_net\U_netClass3.h5',
save_path=r'F:\BaiduNetdiskDownload\remoteSensing\U_net\U_netClass4.h5',#模型存储绝对路径
train_mode=0
)
unet.train()#开始训练
# unet.modelPred()
————————————————
版权声明:本文为CSDN博主「研志必有功」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_44930937/article/details/105443188
收集了一些有助于我新手入门的概念性的东西
① Conv2d函数
⑥反向卷积