Repost.
GitHub link: https://github.com/LiQiufu/WaveCNet
Paper link: https://openaccess.thecvf.com/content_CVPR_2020/papers/Li_Wavelet_Integrated_CNNs_for_Noise-Robust_Image_Classification_CVPR_2020_paper.pdf
import numpy as np
import math
import cv2
import pywt
import os
from PIL import Image
from tensorflow.keras.utils import to_categorical, Sequence
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.models import Model, Sequential
import seaborn as sb
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dense, Flatten, BatchNormalization, Activation, Dropout, Lambda, GlobalAveragePooling2D
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.applications.resnet50 import ResNet50
from sklearn.metrics import classification_report,confusion_matrix
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
import tensorflow as tf
K.set_image_data_format('channels_first')
print(pywt.wavelist(kind='discrete'))
#define the wavelet
#see pywt.wavelist(kind='discrete') for available wavelets
wavelet = pywt.Wavelet('haar')
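For reference, the filter banks that the pooling layer below builds its matrices from can be inspected directly; for Haar both filters have length 2, so band_length_half is 1 and no boundary trimming happens:
print(wavelet.rec_lo)  # low-pass reconstruction filter, e.g. [0.7071..., 0.7071...]
print(wavelet.rec_hi)  # high-pass reconstruction filter, e.g. [0.7071..., -0.7071...]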
class DWT_Pooling(tf.keras.layers.Layer):
    """2D discrete wavelet transform used as a pooling layer (channels_first).

    Following WaveCNet, only the low-frequency LL subband is returned, which
    halves the spatial resolution while discarding the high-frequency subbands
    that carry most of the noise.
    """
    def __init__(self, **kwargs):
        super(DWT_Pooling, self).__init__(**kwargs)

    def build(self, input_shape):
        super(DWT_Pooling, self).build(input_shape)

    @tf.function
    def call(self, inputs):
        # reconstruction filter banks of the wavelet chosen above
        band_low = wavelet.rec_lo
        band_high = wavelet.rec_hi
        assert len(band_low) == len(band_high)
        band_length = len(band_low)
        assert band_length % 2 == 0
        band_length_half = math.floor(band_length / 2)
        input_height = inputs.shape[2]
        input_width = inputs.shape[3]
        L1 = input_height
        L = math.floor(L1 / 2)
        # build the decimated convolution matrices for the low-pass (h) and
        # high-pass (g) filters, then trim the boundary columns
        matrix_h = np.zeros((L, L1 + band_length - 2), dtype=np.float32)
        matrix_g = np.zeros((L1 - L, L1 + band_length - 2), dtype=np.float32)
        end = None if band_length_half == 1 else (-band_length_half + 1)
        index = 0
        for i in range(L):
            for j in range(band_length):
                matrix_h[i, index + j] = band_low[j]
            index += 2
        matrix_h_0 = matrix_h[0:(math.floor(input_height / 2)), 0:(input_height + band_length - 2)]
        matrix_h_1 = matrix_h[0:(math.floor(input_width / 2)), 0:(input_width + band_length - 2)]
        index = 0
        for i in range(L1 - L):
            for j in range(band_length):
                matrix_g[i, index + j] = band_high[j]
            index += 2
        matrix_g_0 = matrix_g[0:(input_height - math.floor(input_height / 2)), 0:(input_height + band_length - 2)]
        matrix_g_1 = matrix_g[0:(input_width - math.floor(input_width / 2)), 0:(input_width + band_length - 2)]
        matrix_h_0 = matrix_h_0[:, (band_length_half - 1):end]
        matrix_h_1 = matrix_h_1[:, (band_length_half - 1):end]
        matrix_h_1 = np.transpose(matrix_h_1)
        matrix_g_0 = matrix_g_0[:, (band_length_half - 1):end]
        matrix_g_1 = matrix_g_1[:, (band_length_half - 1):end]
        matrix_g_1 = np.transpose(matrix_g_1)
        matrix_low_0 = tf.convert_to_tensor(matrix_h_0, dtype=tf.float32)
        matrix_low_1 = tf.convert_to_tensor(matrix_h_1, dtype=tf.float32)
        matrix_high_0 = tf.convert_to_tensor(matrix_g_0, dtype=tf.float32)
        matrix_high_1 = tf.convert_to_tensor(matrix_g_1, dtype=tf.float32)
        # filter along rows, then along columns, to get the four subbands
        L = tf.matmul(matrix_low_0, inputs)
        H = tf.matmul(matrix_high_0, inputs)
        LL = tf.matmul(L, matrix_low_1)
        LH = tf.matmul(L, matrix_high_1)
        HL = tf.matmul(H, matrix_low_1)
        HH = tf.matmul(H, matrix_high_1)
        # keep only the low-frequency LL subband (LH, HL, HH are discarded)
        return LL

    def get_config(self):
        config = super(DWT_Pooling, self).get_config()
        return config

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], input_shape[2] // 2, input_shape[3] // 2)
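A quick sanity check, my addition rather than part of the original post: the layer should halve both spatial dimensions, and for a single-channel input with the Haar wavelet its output should match the approximation (LL) coefficients returned by pywt.dwt2:
# sanity check: compare the layer's LL output with pywt.dwt2 (assumes the 'haar' wavelet chosen above)
x = np.random.rand(1, 1, 28, 28).astype(np.float32)  # (batch, channels, height, width)
ll_layer = DWT_Pooling()(tf.convert_to_tensor(x)).numpy()[0, 0]
ll_pywt, _ = pywt.dwt2(x[0, 0], 'haar')
print(ll_layer.shape)                             # (14, 14)
print(np.allclose(ll_layer, ll_pywt, atol=1e-5))  # expected True for 'haar'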
def create_model(input_shape=(1, 28, 28), num_classes=1, output_bias=None):
    if output_bias is not None:
        output_bias = tf.keras.initializers.Constant(output_bias)
    inputs = Input(shape=input_shape)
    output = Conv2D(16, (3, 3), padding='same', use_bias=False)(inputs)
    output = BatchNormalization(axis=1, scale=False, center=True)(output)  # axis=1 for channels_first
    output = Activation('relu')(output)
    # output = MaxPooling2D()(output)  # replaced by wavelet pooling below
    output = DWT_Pooling()(output)
    output = Conv2D(32, (3, 3), padding='same', use_bias=False)(output)
    output = BatchNormalization(axis=1, scale=False, center=True)(output)  # axis=1 for channels_first
    output = Activation('relu')(output)
    # output = MaxPooling2D()(output)  # replaced by wavelet pooling below
    output = DWT_Pooling()(output)
    output = Flatten()(output)
    output = Dense(256, activation='relu')(output)
    output = Dropout(0.3)(output)
    # sigmoid for a single (binary) output, softmax for multi-class
    if num_classes == 1:
        activation = 'sigmoid'
    else:
        activation = 'softmax'
    output = Dense(num_classes, activation=activation, bias_initializer=output_bias)(output)
    model = Model(inputs, output)
    return model
model = create_model(input_shape=(1,28,28),num_classes=10)
model.summary()
MNIST example: train the model above on MNIST, with the two DWT pooling layers in place of max pooling.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train.shape)  # (60000, 28, 28)
x_train = x_train / 255.
x_test = x_test / 255.
y_train_oh = to_categorical(y_train, 10)
y_test_oh = to_categorical(y_test, 10)
num_train = x_train.shape[0]
num_test = x_test.shape[0]
img_height = x_train.shape[1]
img_width = x_train.shape[2]
num_channels = 1
x_train = x_train.reshape(num_train,1,img_height,img_width)
x_test = x_test.reshape(num_test,1,img_height,img_width)
opt= Adam(learning_rate=0.01)
model.compile(optimizer = opt,loss='categorical_crossentropy',metrics=['accuracy'])
def lr_decay(epoch):
    return 0.01 * math.pow(0.666, epoch)
lr_decay_cb = LearningRateScheduler(lr_decay,verbose=True)
model_check_cb = ModelCheckpoint('mnist_dwt.h5',save_best_only=True,monitor='val_loss')
history = model.fit(x_train,y_train_oh,validation_data=(x_test,y_test_oh),epochs=10,batch_size=64,
callbacks=[lr_decay_cb,model_check_cb])
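Optionally, the training curves can be plotted from the returned History object; a small sketch using the matplotlib import at the top of the script, not part of the original post:
# plot training curves; the keys follow the metrics=['accuracy'] setting above
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()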
best_model = tf.keras.models.load_model('mnist_dwt.h5', custom_objects={'DWT_Pooling': DWT_Pooling})
best_model.evaluate(x_test, y_test_oh)
y_preds = np.argmax(best_model.predict(x_test), axis=1)
print(classification_report(y_test, y_preds))
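As a final optional step, not in the original post, the confusion_matrix and seaborn imports at the top of the script can be used to visualize per-class errors:
# confusion matrix heatmap for the MNIST test predictions
cm = confusion_matrix(y_test, y_preds)
plt.figure(figsize=(8, 6))
sb.heatmap(cm, annot=True, fmt='d', cmap='Blues')  # annotated per-class counts
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()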