使用CNN对自然图像压缩重构【图像压缩感知】

基于深度学习的图像压缩感知

针对图像压缩感知,已有不少论文使用深度学习的方法实现图像的压缩采样和重构。本文主要记录复现其中一篇论文代码的过程。
分析论文:[1]Shi W, Jiang F, Zhang S, et al. Deep Networks for Compressed Image Sensing[J]. 2017:877-882.
论文题目:Deep Networks for Compressed Image Sensing
首先论文的框架是:
在这里插入图片描述
中心思想是通过带步长的卷积实现图像的压缩采样,然后利用卷积输出的通道(深度)维度实现小块图像的重构和块图像的拼接,最后通过5层卷积神经网络实现最终图像的复原。这个过程和 Gan Lu 在2007年发表的块压缩感知论文[2]的过程差不多。结合两篇论文的过程,实现代码:
[2].Gan L. Block Compressed Sensing of Natural Images[C]// International Conference on Digital Signal Processing. IEEE, 2007:403-406.
准备数据集:利用BSDS500中的400张自然图像,通过平移、旋转、镜像等方法得到约70000张128*128大小的图像,存储格式使用.h5,用起来比较方便;另外使用100张图像作为测试集。

path = '/home/train.h5'


def read_data(path):
    """Load the ``orig_image`` and ``sample_line`` datasets from an HDF5 file.

    Args:
        path: Location of an HDF5 file containing both datasets.

    Returns:
        A tuple ``(orig_image, sample_line)`` of NumPy arrays holding the
        full contents of the two datasets.

    Raises:
        KeyError: If either dataset is missing from the file.
    """
    with h5py.File(path, 'r') as hf:
        # Index by name rather than ``hf.get``: ``get`` returns None for a
        # missing dataset, which ``np.array`` would silently turn into a
        # 0-d object array instead of reporting the problem.
        orig_image = np.array(hf['orig_image'])
        sample_line = np.array(hf['sample_line'])
    return orig_image, sample_line


orig_image, sample_line = read_data(path)
# Read the validation data.
path1 = '/home/test1.h5'


def read_data1(path):
    """Load ``orig_image`` and ``sample_line`` from the HDF5 file at ``path``.

    The datasets read and the returned ``(orig_image, sample_line)`` tuple
    are the same as for ``read_data``, so the work is delegated to it.
    """
    return read_data(path)


test_image, test_line = read_data1(path1)
# Give the 100 validation images (256x256) an explicit trailing channel axis.
test_image = np.reshape(test_image, (100, 256, 256, 1))
# Load test data.
# ``.item()`` unwraps the single Python object stored in the .npy file (used
# below as a dict keyed by image name).  Object arrays are pickled, and
# NumPy >= 1.16.3 refuses to unpickle unless ``allow_pickle=True`` is given.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
read_dictionary = np.load('/home/lab30202/Juanjuan/images/h5/test_data.npy',
                          allow_pickle=True).item()
# Graph input for one 512x512 single-channel test image.
image_test = tf.placeholder(tf.float32, [1, 512, 512, 1])

整个过程代码:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
from PIL import Image 
import matplotlib.pyplot as plt
from skimage import io, img_as_float, measure
import os
import h5py
import collections
#########################################################################
# Read the original training data.
path = '/home/train.h5'


def read_data(path):
    """Load the ``orig_image`` and ``sample_line`` datasets from an HDF5 file.

    Args:
        path: Location of an HDF5 file containing both datasets.

    Returns:
        A tuple ``(orig_image, sample_line)`` of NumPy arrays holding the
        full contents of the two datasets.

    Raises:
        KeyError: If either dataset is missing from the file.
    """
    with h5py.File(path, 'r') as hf:
        # Index by name rather than ``hf.get``: ``get`` returns None for a
        # missing dataset, which ``np.array`` would silently turn into a
        # 0-d object array instead of reporting the problem.
        orig_image = np.array(hf['orig_image'])
        sample_line = np.array(hf['sample_line'])
    return orig_image, sample_line


orig_image, sample_line = read_data(path)

# Read the validation data.
path1 = '/home/test1.h5'


def read_data1(path):
    """Load ``orig_image`` and ``sample_line`` from the HDF5 file at ``path``.

    The datasets read and the returned ``(orig_image, sample_line)`` tuple
    are the same as for ``read_data``, so the work is delegated to it.
    """
    return read_data(path)


test_image, test_line = read_data1(path1)
# Give the 100 validation images (256x256) an explicit trailing channel axis.
test_image = np.reshape(test_image, (100, 256, 256, 1))
# Load test data.
# ``.item()`` unwraps the single Python object stored in the .npy file (used
# below as a dict keyed by image name).  Object arrays are pickled, and
# NumPy >= 1.16.3 refuses to unpickle unless ``allow_pickle=True`` is given.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
read_dictionary = np.load('/home/lab30202/Juanjuan/images/h5/test_data.npy',
                          allow_pickle=True).item()
# Graph input for one 512x512 single-channel test image.
image_test = tf.placeholder(tf.float32, [1, 512, 512, 1])
# CNN training configuration.
# Adam learning rates; one optimizer is built per rate further down.
learning_rate1 = 0.001
learning_rate2 = 0.0001
learning_rate3 = 0.00001

batch_size = 64
num_samples = len(orig_image)
# Number of full mini-batches per epoch.  Floor division keeps this an int on
# both Python 2 and Python 3 so it can be passed to range(); a trailing
# partial batch is never visited.
batch_index = num_samples // batch_size
# Epoch numbers at which the successive training loops stop.
training_epochs = 50
training_epochs1 = 80
training_epochs2 = 100
# Evaluate and checkpoint every ``display_step`` epochs.
display_step = 1

# Graph inputs: batches of 128x128 single-channel images.
x = tf.placeholder(tf.float32, [None, 128, 128, 1])
# NOTE(review): ``y`` and ``keep_prob`` are not consumed by the graph built in
# this file; they are kept so existing references keep working.
y = tf.placeholder(tf.float32, [None, 128, 128, 1])
keep_prob = tf.placeholder(tf.float32)

def tensor_concat(f, axis):
    """Join the 2-D slices of ``f`` taken along its first axis.

    Args:
        f: Rank-3 tensor whose leading dimension is statically known; each
            ``f[i, :, :]`` is one piece.
        axis: Axis of the 2-D pieces along which they are concatenated.

    Returns:
        The pieces ``f[0], f[1], ...`` concatenated in index order along
        ``axis``.
    """
    # A single concat over all pieces avoids building one graph op (and one
    # intermediate copy of the partial result) per piece.
    pieces = [f[i, :, :] for i in range(f.shape[0])]
    return tf.concat(pieces, axis=axis)

def block_to_image(f, batch_size, block_size=32):
    """Rearrange per-block pixel vectors into full single-channel images.

    Every spatial position ``(r, c)`` of ``f`` holds the
    ``block_size * block_size`` pixels of one square image block, stored row
    by row along the channel axis.  The blocks are tiled onto the output
    image in the same grid layout as the spatial positions of ``f``.

    Args:
        f: Tensor of shape ``[batch, rows, cols, block_size * block_size]``
            with statically known ``rows`` and ``cols``.
        batch_size: Number of images in ``f``; the static batch dimension of
            the result is fixed to this value.
        block_size: Side length in pixels of one block.  The default of 32
            corresponds to 1024 channels.

    Returns:
        Tensor of shape
        ``[batch_size, rows * block_size, cols * block_size, 1]``.
    """
    f = tf.reshape(
        f, [batch_size, f.shape[1], f.shape[2], block_size * block_size])
    # depth_to_space moves channel ``i * block_size + j`` of position (r, c)
    # to output pixel (r * block_size + i, c * block_size + j), which is
    # exactly the row-major block layout described above.  One op handles the
    # whole batch and works for non-square block grids as well.
    return tf.depth_to_space(f, block_size)

def conv2dcs1(name, x, W, strides):
    """Apply a linear 2-D convolution with 'VALID' padding.

    Args:
        name: Not used by this function.
        x: Input tensor passed to ``tf.nn.conv2d``.
        W: Convolution filter tensor.
        strides: Stride applied to both the second and third axes.

    Returns:
        The raw convolution output; no activation is applied.
    """
    stride_spec = [1, strides, strides, 1]
    return tf.nn.conv2d(x, W, strides=stride_spec, padding='VALID')

def conv2dcs2(name, x, W, strides):
    """Apply a linear 2-D convolution with 'SAME' padding.

    Args:
        name: Not used by this function.
        x: Input tensor passed to ``tf.nn.conv2d``.
        W: Convolution filter tensor.
        strides: Stride applied to both the second and third axes.

    Returns:
        The raw convolution output; no activation is applied.
    """
    stride_spec = [1, strides, strides, 1]
    return tf.nn.conv2d(x, W, strides=stride_spec, padding='SAME')

def conv2d(name, x, W, strides=1):
    """Apply a 'SAME'-padded 2-D convolution followed by ReLU.

    Args:
        name: Name given to the ReLU op.
        x: Input tensor passed to ``tf.nn.conv2d``.
        W: Convolution filter tensor.
        strides: Stride applied to both the second and third axes
            (default 1).

    Returns:
        The rectified convolution output.
    """
    stride_spec = [1, strides, strides, 1]
    convolved = tf.nn.conv2d(x, W, strides=stride_spec, padding='SAME')
    return tf.nn.relu(convolved, name=name)

# Filters of the sampling sub-network, drawn from N(0, 0.02^2).
# 'wcs1': 32x32 kernels mapping 1 channel to 102 channels.
# 'wcs2': 1x1 kernels mapping 102 channels to 1024 channels.
# The variables are created in the listed order.
weight1 = {
    key: tf.Variable(tf.random_normal(shape, stddev=0.02))
    for key, shape in [
        ('wcs1', [32, 32, 1, 102]),
        ('wcs2', [1, 1, 102, 1024]),
    ]
}
  
# 3x3 filters of the five-layer reconstruction CNN, drawn from N(0, 0.1^2).
# Channel counts go 1 -> 64 -> 64 -> 64 -> 64 -> 1.
# The variables are created in the listed order.
weight2 = {
    key: tf.Variable(tf.random_normal(shape, stddev=0.1))
    for key, shape in [
        ('wc1', [3, 3, 1, 64]),
        ('wc2', [3, 3, 64, 64]),
        ('wc3', [3, 3, 64, 64]),
        ('wc4', [3, 3, 64, 64]),
        ('wc5', [3, 3, 64, 1]),
    ]
}

def cs_net(x, weight1, batch_size):
    """Sample an image batch block-wise and build an initial reconstruction.

    Args:
        x: Image batch; it is cast to float32 before use.
        weight1: Dict providing the filters ``'wcs1'`` and ``'wcs2'``.
        batch_size: Number of images in ``x``, forwarded to
            ``block_to_image``.

    Returns:
        The tensor produced by ``block_to_image`` from the second
        convolution's output.
    """
    images = tf.cast(x, dtype=tf.float32)
    # Sampling matrix: stride-32 'VALID' convolution with the 'wcs1' filters.
    measurements = conv2dcs1('conv1', images, weight1['wcs1'], strides=32)
    # Reshape and concatenation layer: stride-1 convolution with 'wcs2',
    # whose output is then rearranged into images.
    block_pixels = conv2dcs2('conv2', measurements, weight1['wcs2'], strides=1)
    return block_to_image(block_pixels, batch_size)
    
def cs_cnn_net(x, weight1, weight2, batch_size):
    """Run the sampling network followed by the five-layer reconstruction CNN.

    Args:
        x: Image batch with statically known second and third dimensions; it
            is reshaped to have a single trailing channel.
        weight1: Filters of the sampling sub-network (see ``cs_net``).
        weight2: Dict providing the filters ``'wc1'`` .. ``'wc5'``.
        batch_size: Number of images in ``x``.

    Returns:
        Output of the last convolution layer (``'conv7'``).
    """
    # Sampling-matrix network.
    x = tf.reshape(x, shape=([-1, x.shape[1], x.shape[2], 1]))
    features = cs_net(x, weight1, batch_size)

    # Deep reconstruction sub-network: five conv + ReLU layers in sequence.
    # NOTE(review): the last layer also goes through ReLU, so the output is
    # never negative -- confirm this matches the value range of the images.
    layer_plan = (
        ('conv3', 'wc1'),
        ('conv4', 'wc2'),
        ('conv5', 'wc3'),
        ('conv6', 'wc4'),
        ('conv7', 'wc5'),
    )
    for op_name, weight_key in layer_plan:
        features = conv2d(op_name, features, weight2[weight_key])
    return features

def _mean_squared_error(prediction, target):
    """Return the mean over all elements of ``(prediction - target) ** 2``."""
    return tf.reduce_mean(tf.pow(prediction - target, 2))


def _adam_train_op(learning_rate):
    """Return an Adam update op minimising ``mse_loss`` at ``learning_rate``."""
    adam = tf.train.AdamOptimizer(
        learning_rate=learning_rate, beta1=0.9, beta2=0.999)
    return adam.minimize(mse_loss)


# Training graph: the network reconstructs its own input, so the target is x.
pred = cs_cnn_net(x, weight1, weight2, batch_size)
y_true = x
mse_loss = _mean_squared_error(pred, y_true)

# One Adam update op per learning rate.
optimizer = _adam_train_op(learning_rate1)
optimizer1 = _adam_train_op(learning_rate2)
optimizer2 = _adam_train_op(learning_rate3)

# Evaluation on the validation set (all 100 images in one batch).
pred1 = cs_cnn_net(test_image, weight1, weight2, batch_size=100)
y_pred1 = test_image
y_pred1_mse = _mean_squared_error(pred1, y_pred1)

# Evaluation on a single test image fed through ``image_test``.
pred2 = cs_cnn_net(image_test, weight1, weight2, batch_size=1)
y_pred2 = image_test
y_pred2_mse = _mean_squared_error(pred2, y_pred2)

# Model checkpointing.
model_save_path = '/home/model_cnn_rs/'
model_name = 'model.ckpt'
checkpoint_path = os.path.join(model_save_path, model_name)
# tf.global_variables() is the non-deprecated spelling of tf.all_variables().
saver = tf.train.Saver(tf.global_variables())
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

# Training schedule as (update op, first epoch, end epoch): epochs 0-49 use
# ``optimizer``, 50-79 use ``optimizer1`` and 80-99 use ``optimizer2``.
training_stages = [
    (optimizer, 0, training_epochs),
    (optimizer1, training_epochs, training_epochs1),
    (optimizer2, training_epochs1, training_epochs2),
]
# Test images (keys of ``read_dictionary``) evaluated after each epoch.
keys = ['pepper', 'baby', 'lenna']

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True)) as sess:
    sess.run(init_op)
    # NOTE(review): no input queues are created in the visible graph, so this
    # presumably starts no threads; kept so behaviour is unchanged if any exist.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    mse_loss1 = []  # training MSE, sampled every 100 iterations
    mse_loss2 = []  # validation MSE, one value per evaluated epoch
    # Per-image metric histories, keyed by test image name.
    psnr_test = collections.defaultdict(list)
    mse_test = collections.defaultdict(list)
    ssim_test = collections.defaultdict(list)

    # int() keeps range() happy even if ``batch_index`` arrives as a float.
    num_batches = int(batch_index)

    # Start training.
    for train_op, first_epoch, end_epoch in training_stages:
        for epoch in range(first_epoch, end_epoch):
            for i in range(num_batches):
                batch_orig = orig_image[i*batch_size:(i+1)*batch_size]
                _, c = sess.run([train_op, mse_loss], feed_dict={x: batch_orig})
                if i % 100 == 0:
                    print("iter: %04d mse= %.9f" % (i + 1, c))
                    mse_loss1.append(c)

            # Report, checkpoint and evaluate once per ``display_step`` epochs.
            if epoch % display_step != 0:
                continue

            print("epoch: %04d mse= %.9f" % (epoch + 1, c))
            saver.save(sess, checkpoint_path, global_step=epoch)

            # Validation set: its MSE is what the "validation_image" line
            # reports and what ``mse_loss2`` records.
            validation_mse = sess.run(y_pred1_mse)
            mse_loss2.append(validation_mse)
            print("validation_image: mse= %.9f" % validation_mse)

            for key in keys:
                image = read_dictionary[key]
                image_test1 = image.reshape((1, image.shape[0], image.shape[1], 1))
                recon = sess.run(pred2, feed_dict={image_test: image_test1})
                # Drop the batch and channel axes, then clamp to [-1, 1].
                recon = recon.reshape((recon.shape[1], recon.shape[2]))
                recon = np.clip(recon, -1, 1)

                psnr = measure.compare_psnr(image, recon, data_range=1)
                ssim = measure.compare_ssim(image, recon, data_range=1)
                mse = measure.compare_mse(image, recon)
                mse_test[key].append(mse)
                psnr_test[key].append(psnr)
                ssim_test[key].append(ssim)

                print("test_image: %s mse= %.9f psnr= %.9f ssim= %.9f"
                      % (key, mse, psnr, ssim))

                # Save the reconstruction and its difference from the original.
                error = recon - image
                io.imsave(model_save_path+key+'pred'+str(epoch)+'.jpg', recon)
                io.imsave(model_save_path+key+'error'+str(epoch)+'.jpg', error)

    print("Optimaization Finished")
    np.savetxt(model_save_path+'mse_loss1', mse_loss1)
    np.savetxt(model_save_path+'mse_loss2', mse_loss2)
    np.save(model_save_path+'psnr_test.npy', psnr_test)
    np.save(model_save_path+'mse_test.npy', mse_test)
    np.save(model_save_path+'ssim_test.npy', ssim_test)

    coord.request_stop()
    coord.join(threads)

训练对应的验证集的损失函数的变化:
在这里插入图片描述
图像的测试结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
这个通过卷积神经网络实现图像压缩重构的方法真实的效果非常好!赞叹!复现整个论文代码的过程也非常有趣!

  • 10
    点赞
  • 96
    收藏
    觉得还不错? 一键收藏
  • 37
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 37
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值