CNN实现卫星图像分类(tensorflow)

本文介绍了如何使用TensorFlow库对包含飞机和湖泊两类卫星图像的数据集进行预处理、构建CNN模型,进行二分类,并展示了训练过程、模型性能评估以及预测示例。
摘要由CSDN通过智能技术生成

使用的数据集卫星图像有两类,airplane和lake,每个类别样本量各700张,大小为256*256,RGB三通道彩色卫星影像。搭建深度卷积神经网络,实现卫星影像二分类。
数据链接百度网盘地址,提取码: cq47

1、查看tensorflow版本

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、加载并显示训练数据

从文件夹中获取所有数据路径

import glob
import random

all_image_path = glob.glob('./data/air_lake_dataset/*/*.jpg')  # glob相比于pathlib更简洁
random.shuffle(all_image_path)

读取并处理图像

def load_and_preprocess_image(path):
    img_raw = tf.io.read_file(path)
    img_tensor = tf.image.decode_jpeg(img_raw,channels=3)
    img_tensor = tf.image.resize(img_tensor,[256,256])
    img_tensor = tf.cast(img_tensor,tf.float32)
    img_tensor = img_tensor/255
    return img_tensor

处理标签

label_to_index = {'airplane':0,'lake':1}
index_to_label = dict((v,k) for k,v in label_to_index.items())
labels = [label_to_index.get(img.split('/')[3]) for img in all_image_path]

显示卫星影像

import matplotlib.pyplot as plt

def plot_images_lables(all_image_path,labels,start_idx,num=5):
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    images = [load_and_preprocess_image(img_path) for img_path in all_image_path[start_idx:start_idx+5]]
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[i])
        title = 'label=' + index_to_label.get(labels[start_idx+i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

plot_images_lables(all_image_path,labels,0,5)

在这里插入图片描述

4、使用tf.data.Dataset制作训练/测试数据

制作 Dataset

img_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
img_ds = img_ds.map(load_and_preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
img_label_ds = tf.data.Dataset.zip((img_ds,label_ds))

训练集、测试集划分

test_count = int(len(labels)*0.2) 
train_count = len(labels) - test_count

train_ds = img_label_ds.skip(test_count)
test_ds = img_label_ds.take(test_count)

分批次加载数据

BATCH_SIZE = 16
train_ds = train_ds.repeat().shuffle(100).batch(BATCH_SIZE)
test_ds = test_ds.repeat().batch(BATCH_SIZE)

5、CNN模型构建

from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2D

model = tf.keras.Sequential([
    Input(shape=(256,256,3)),
    Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),  # 增加filter个数,增加模型拟合能力
    Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),  # 默认2*2. 池化层扩大视野
    Dropout(0.2),  # 防止过拟合
    Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),
    GlobalAvgPool2D(),  # 全局平均池化
    Dense(1024,activation='relu'),
    Dense(256,activation='relu'),
    Dense(1,activation='sigmoid') 
])

model.summary()

在这里插入图片描述

6、模型编译与训练

model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),  # 已经使用sigmoid激活过了
              metrics=['acc'])

steps_per_epoch = train_count//BATCH_SIZE
val_step = test_count//BATCH_SIZE

H = model.fit(train_ds,
             epochs=10,
             steps_per_epoch=steps_per_epoch,
             validation_data=test_ds,
             validation_steps=val_step,
             verbose=1)

在这里插入图片描述

7、模型评估

import matplotlib.pyplot as plt

fig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')

plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
plt.show()

在这里插入图片描述

8、模型预测

def pred_img(img_path):
    img = load_and_preprocess_image(img_path)
    img = tf.expand_dims(img, axis=0)
    pred = model.predict(img)
    pred = index_to_label.get((pred>0.5).astype('int')[0][0])
    return pred
    
img_path = './data/air_lake_dataset/airplane/airplane_240.jpg'
pred = pred_img(img_path)
img_tensor = load_and_preprocess_image(img_path)
plt.imshow(img_tensor)
title = 'label=' + img_path.split('/')[3].strip() + ', pred=' + pred
plt.title(title)
plt.show()

在这里插入图片描述

  • 19
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值