代码结构
这里主要是参考了github上面一个u-net的程序的结构。链接
在项目中,程序员将整个神经网络分成了网络结构和训练两个类,并定义了一些函数来完成类似混淆矩阵生成这样的操作。
在这里,也是模仿了他的写法,这样会显得清晰一些。
代码
这里使用Keras库,采用的是Functional API
的搭建网络方式:
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 15 10:03:52 2017
a simple cnn classifier for mnist data using Functional API
@author: huijian
"""
from __future__ import print_function
import numpy as np
# set the seed for reproducibility
np.random.seed(1234)
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
# the libraries of keras
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.models import model_from_json
from keras.layers import Input
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.utils import np_utils
from keras import backend as K
K.set_image_dim_ordering("tf")
"""
for a picture of a form (128,128,3) (img_row, img_col, channels)
th: (3,128,128) (channels,img_row,img_col)
tf: (128,128,3) (img_row, img_col, channels)
"""
def create_conv(img_shape, num_classes):
"""
param:
img_shape: a 1-D tensor, [img_row, img_col, channels]
"""
inputs = Input(shape=(img_shape[0], img_shape[1], img_shape[2]))
conv1 = Conv2D(filters = 30, kernel_size = [5,