TensorFlow实现小波池化层

转载自:
GitHub链接:https://github.com/LiQiufu/WaveCNet
论文链接:https://openaccess.thecvf.com/content_CVPR_2020/papers/Li_Wavelet_Integrated_CNNs_for_Noise-Robust_Image_Classification_CVPR_2020_paper.pdf

import numpy as np
import math
import cv2
import pywt
import os
from PIL import Image
from tensorflow.keras.utils import to_categorical, Sequence
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.models import Model, Sequential
import seaborn as sb
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dense, Flatten, BatchNormalization, Activation, Dropout, Lambda, GlobalAveragePooling2D
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.applications.resnet50 import ResNet50
from sklearn.metrics import classification_report,confusion_matrix
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.layers import Layer
# DWT_Pooling reads the spatial size from axes 2 and 3, so Keras has to run
# in (batch, channels, height, width) layout.
K.set_image_data_format('channels_first')

# Show every discrete wavelet name PyWavelets accepts; any of them can be
# used as _WAVELET_NAME below.
_discrete_wavelets = pywt.wavelist(kind='discrete')
print(_discrete_wavelets)

# Wavelet whose filter banks (rec_lo / rec_hi) DWT_Pooling reads.
_WAVELET_NAME = 'haar'
wavelet = pywt.Wavelet(_WAVELET_NAME)

class DWT_Pooling(tf.keras.layers.Layer):
    """Pooling layer that keeps the low-frequency (LL) sub-band of a 2-D DWT.

    Instead of taking a max over 2x2 windows, the feature map is filtered
    along height and width with the low-pass filter of the module-level
    ``wavelet`` (``wavelet.rec_lo``) and subsampled by 2. Both steps are
    expressed as dense matrix products, ``LL = A_h @ x @ A_w^T``, where
    ``A_n`` is the ``(n // 2, n)`` low-pass analysis matrix for an axis of
    length ``n``.

    Expects channels-first inputs ``(batch, channels, height, width)`` whose
    height and width are statically known; produces
    ``(batch, channels, height // 2, width // 2)``.
    """

    def __init__(self, **kwargs):
        super(DWT_Pooling, self).__init__(**kwargs)

    def build(self, input_shape):
        super(DWT_Pooling, self).build(input_shape)

    @staticmethod
    def _lowpass_matrix(band_low, size):
        """Return the float32 ``(size // 2, size)`` low-pass analysis matrix.

        Row ``i`` holds ``band_low`` starting at column ``2 * i`` (a stride-2
        filter). The ``len(band_low) - 2`` extra columns needed to fit the
        filter taps of the last rows are then trimmed, ``len(band_low) // 2 - 1``
        from each side, so the matrix applies to a length-``size`` signal.
        For a length-2 filter (e.g. Haar) nothing is trimmed.
        """
        band_length = len(band_low)
        band_length_half = band_length // 2
        rows = size // 2
        matrix = np.zeros((rows, size + band_length - 2), dtype=np.float32)
        for i in range(rows):
            matrix[i, 2 * i:2 * i + band_length] = band_low
        end = None if band_length_half == 1 else (-band_length_half + 1)
        return matrix[:, (band_length_half - 1):end]

    @tf.function
    def call(self, inputs):
        """Return the LL sub-band of ``inputs`` (channels-first, 4-D)."""
        band_low = wavelet.rec_lo
        band_high = wavelet.rec_hi
        # Explicit raises (not asserts) so the checks survive `python -O`.
        if len(band_low) != len(band_high):
            raise ValueError(
                'wavelet low- and high-pass filters must have the same length')
        if len(band_low) % 2 != 0:
            raise ValueError('wavelet filter length must be even')

        input_height = inputs.shape[2]
        input_width = inputs.shape[3]
        if input_height is None or input_width is None:
            raise ValueError(
                'DWT_Pooling needs static height and width, got shape %s'
                % (inputs.shape,))

        # One operator per axis, each sized from its own axis length, so
        # non-square inputs (width != height) get correctly shaped matrices.
        # Only LL is returned, so the high-pass operators are not built.
        rows_op = tf.convert_to_tensor(
            self._lowpass_matrix(band_low, input_height), dtype=tf.float32)
        cols_op = tf.convert_to_tensor(
            np.transpose(self._lowpass_matrix(band_low, input_width)),
            dtype=tf.float32)

        low = tf.matmul(rows_op, inputs)  # (batch, channels, H // 2, W)
        return tf.matmul(low, cols_op)    # (batch, channels, H // 2, W // 2)

    def get_config(self):
        """Return the base Layer config; this layer has no extra arguments."""
        config = super(DWT_Pooling, self).get_config()
        return config

    def compute_output_shape(self, input_shape):
        """Halve (floor) the height and width axes of a channels-first shape."""
        return (input_shape[0], input_shape[1], input_shape[2]//2, input_shape[3]//2)
def _conv_bn_relu_dwt(x, filters):
    """Apply one 3x3 conv -> batch-norm -> ReLU -> DWT_Pooling stage to ``x``."""
    x = Conv2D(filters, (3, 3), padding='same', use_bias=False)(x)
    # The conv has no bias; BatchNormalization's learned shift (center=True)
    # provides the per-channel offset instead.
    x = BatchNormalization(scale=False, center=True)(x)
    x = Activation('relu')(x)
    # Wavelet pooling in place of MaxPooling2D.
    return DWT_Pooling()(x)


def create_model(input_shape=(1, 28, 28), num_classes=1, output_bias=None):
    """Build a small CNN that downsamples with DWT_Pooling.

    Architecture: two conv/BN/ReLU/DWT_Pooling stages (16 and 32 filters),
    then Flatten -> Dense(256, relu) -> Dropout(0.3) -> output Dense.

    Args:
        input_shape: Per-sample input shape in channels-first order
            ``(channels, height, width)``.
        num_classes: Number of output units. ``1`` gives a sigmoid head;
            any other value gives a softmax head.
        output_bias: Optional initial value(s) for the output layer's bias,
            wrapped in a ``Constant`` initializer. ``None`` starts the bias
            at zero.

    Returns:
        An uncompiled Keras ``Model``.
    """
    # Passing bias_initializer=None to Dense does NOT give Dense's default
    # 'zeros': Keras then uses add_weight's generic default (glorot_uniform)
    # and the output bias would start random. Use 'zeros' explicitly.
    if output_bias is None:
        bias_initializer = 'zeros'
    else:
        bias_initializer = tf.keras.initializers.Constant(output_bias)

    inputs = Input(shape=input_shape)

    output = _conv_bn_relu_dwt(inputs, 16)
    output = _conv_bn_relu_dwt(output, 32)

    output = Flatten()(output)
    output = Dense(256, activation='relu')(output)
    output = Dropout(0.3)(output)

    activation = 'sigmoid' if num_classes == 1 else 'softmax'
    output = Dense(num_classes, activation=activation,
                   bias_initializer=bias_initializer)(output)
    return Model(inputs, output)
# 10-way softmax classifier for single-channel 28x28 images, built in
# channels-first layout; print the layer-by-layer summary.
model = create_model(num_classes=10, input_shape=(1, 28, 28))
model.summary()

MNIST example

# Train the DWT-pooling CNN on MNIST, keep the best checkpoint by val_loss,
# and report test metrics plus a per-class classification report.
INITIAL_LR = 0.01  # shared by the Adam optimizer and the decay schedule

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train.shape)

# Scale pixels to [0, 1]. Casting to float32 first avoids the float64 arrays
# that dividing uint8 data by a Python float would otherwise produce.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

y_train_oh = to_categorical(y_train, 10)
y_test_oh = to_categorical(y_test, 10)

num_train = x_train.shape[0]
num_test = x_test.shape[0]
img_height = x_train.shape[1]
img_width = x_train.shape[2]
num_channels = 1
# Insert the channel axis first to match the channels_first data format.
x_train = x_train.reshape(num_train, num_channels, img_height, img_width)
x_test = x_test.reshape(num_test, num_channels, img_height, img_width)

opt = Adam(learning_rate=INITIAL_LR)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])


def lr_decay(epoch):
    """Return INITIAL_LR * 0.666 ** epoch (exponential per-epoch decay)."""
    return INITIAL_LR * math.pow(0.666, epoch)


lr_decay_cb = LearningRateScheduler(lr_decay, verbose=True)
model_check_cb = ModelCheckpoint('mnist_dwt.h5', save_best_only=True, monitor='val_loss')
# NOTE: the test split doubles as validation data here, so checkpoint
# selection has already seen the data used for the final evaluation.
history = model.fit(x_train, y_train_oh, validation_data=(x_test, y_test_oh),
                    epochs=10, batch_size=64,
                    callbacks=[lr_decay_cb, model_check_cb])

best_model = tf.keras.models.load_model(
    'mnist_dwt.h5', custom_objects={'DWT_Pooling': DWT_Pooling})
test_metrics = best_model.evaluate(x_test, y_test_oh)
print('Test loss / accuracy:', test_metrics)

# Predict with the restored best checkpoint.
y_preds = np.argmax(best_model.predict(x_test), axis=1)
from sklearn.metrics import classification_report, roc_auc_score
print(classification_report(y_test, y_preds))
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值