A simple TensorFlow implementation of ResNet (residual network)
Preface
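Like my previous few posts, this is a plain TensorFlow implementation that doesn't use slim. It reinvents the wheel, probably because I'm lazy.
The figures and code are adapted from this source.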
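First, the overall ResNet architecture: it is just an ordinary CNN with some residual (shortcut) paths added.
In the figure above, the residual blocks are split into CONV BLOCK and ID BLOCK, which differ as follows.
In an ID BLOCK the shortcut term is X itself, so the input is simply connected straight across to the addition, as shown in the figure below.
A CONV BLOCK instead applies a convolution to X on the shortcut path, so that its channel count matches the main path, and then adds it to the main path, as shown in the figure below.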
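In formula form, the two block types look like this. This is a sketch of the standard residual formulation. The ReLU applied after the addition matches tf.nn.relu(add) in the code. F stands for the stacked convolutions on the main path, and W_s for the shortcut convolution:

```latex
\begin{aligned}
\text{ID BLOCK:}\quad   & y = \operatorname{ReLU}\big(\mathcal{F}(x) + x\big) \\
\text{CONV BLOCK:}\quad & y = \operatorname{ReLU}\big(\mathcal{F}(x) + W_s * x\big)
\end{aligned}
```

In the ID BLOCK, F(x) must have exactly the same shape as x. That is why the block's last convolution maps back to the input channel count. The CONV BLOCK drops this constraint, because W_s * x can have a different channel count than x.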
The accuracy isn't very high, but I couldn't be bothered to tune it, so I'll leave it at that.
With that, on to the code.
ID_block definition
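def ID_block(X, channels_in, kernel_channels, is_training, name = 'ID_block'):
    # conv_layer is a helper from the full code; its arguments appear to be
    # (input, kernel size, stride, in channels, out channels, is_training, name[, use ReLU])
    # Main path: 1x1 conv maps channels_in to the bottleneck width kernel_channels
    conv1 = conv_layer(X, 1, 1, channels_in, kernel_channels, is_training, name + '/conv1')
    # 3x3 conv keeps kernel_channels
    conv2 = conv_layer(conv1, 3, 1, kernel_channels, kernel_channels, is_training, name + '/conv2')
    # 1x1 conv restores channels_in so the result can be added to X; the trailing False presumably disables ReLU
    conv3 = conv_layer(conv2, 1, 1, kernel_channels, channels_in, is_training, name + '/conv3', False)
    # Identity shortcut: add the unchanged input, then apply ReLU
    add = tf.add(conv3, X)
    result = tf.nn.relu(add)
    return result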
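The following sketch checks the shape logic of ID_block without TensorFlow by running the same bottleneck forward pass in NumPy. It makes three assumptions: stride 1, SAME padding, and no batch normalization. The real conv_layer takes is_training, so it probably includes batch norm. The helpers conv2d_same and id_block_np are hypothetical names used only for illustration.

```python
import numpy as np

def conv2d_same(x, w):
    """Stride-1 convolution with 'SAME' zero padding (cross-correlation, as in
    tf.nn.conv2d). x: (H, W, C_in); w: (k, k, C_in, C_out) with odd k."""
    k = w.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))  # zero-pad height and width
    h, wd, _ = x.shape
    out = np.zeros((h, wd, w.shape[3]))
    for i in range(h):
        for j in range(wd):
            patch = xp[i:i + k, j:j + k, :]             # (k, k, C_in)
            out[i, j] = np.tensordot(patch, w, axes=3)  # sum over k, k, C_in -> (C_out,)
    return out

def relu(x):
    return np.maximum(x, 0)

def id_block_np(x, w1, w2, w3):
    """Bottleneck identity block: 1x1 -> 3x3 -> 1x1, add x, then ReLU.
    Batch norm is omitted (assumption)."""
    h = relu(conv2d_same(x, w1))  # 1x1: channels_in -> kernel_channels
    h = relu(conv2d_same(h, w2))  # 3x3: kernel_channels -> kernel_channels
    h = conv2d_same(h, w3)        # 1x1: kernel_channels -> channels_in, no ReLU
    return relu(h + x)            # identity shortcut, ReLU after the addition

rng = np.random.default_rng(0)
channels_in, kernel_channels = 8, 4
x = rng.standard_normal((5, 5, channels_in))
w1 = 0.1 * rng.standard_normal((1, 1, channels_in, kernel_channels))
w2 = 0.1 * rng.standard_normal((3, 3, kernel_channels, kernel_channels))
w3 = 0.1 * rng.standard_normal((1, 1, kernel_channels, channels_in))

y = id_block_np(x, w1, w2, w3)
print(y.shape)  # (5, 5, 8)
```

If w3 is all zeros, the main path outputs zero and the block reduces to ReLU(x). This illustrates why an identity block can easily fall back to passing its input through.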
CONV_block definition
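def CONV_block(X, channels_in, channels_out, is_training, name = 'CONV_block'):
    # Main path: 1x1 -> 3x3 -> 1x1, ending at channels_out
    conv1 = conv_layer(X, 1, 1, channels_in, channels_out, is_training, name + '/conv1')
    conv2 = conv_layer(conv1, 3, 1, channels_out, channels_out, is_training, name + '/conv2')
    conv3 = conv_layer(conv2, 1, 1, channels_out, channels_out, is_training, name + '/conv3', False)
    # Shortcut path: a 3x3 conv maps X from channels_in to channels_out so the shapes match
    # (the original ResNet paper uses a 1x1 projection here)
    # With every stride at 1, this block changes the channel count but not the spatial size
    short_cut = conv_layer(X, 3, 1, channels_in, channels_out, is_training, name + '/short_cut', False)
    # Add the two paths, then apply ReLU
    add = tf.add(conv3, short_cut)
    result = tf.nn.relu(add)
    return result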
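One way to see why CONV_block needs the short_cut convolution is to trace static shapes. The sketch below is plain Python and assumes SAME padding and the stride-1 convolutions used above. All helper names are hypothetical.

```python
def same_conv_shape(shape, stride, c_out):
    """(H, W, C) after a 'SAME'-padded conv: spatial size becomes
    ceil(size / stride) regardless of kernel size; channels become c_out."""
    h, w, _ = shape
    return (-(-h // stride), -(-w // stride), c_out)

def checked_add(a, b):
    """Simplified stand-in for tf.add on static shapes: requires identical
    shapes (broadcasting is ignored)."""
    if a != b:
        raise ValueError(f"shape mismatch: {a} vs {b}")
    return a

def main_path_shape(x_shape, channels_out):
    # conv1 (1x1), conv2 (3x3), conv3 (1x1), all stride 1, ending at channels_out
    s = x_shape
    for _ in range(3):
        s = same_conv_shape(s, 1, channels_out)
    return s

def conv_block_shape(x_shape, channels_out):
    # short_cut conv maps channels_in -> channels_out, so the addition succeeds
    shortcut = same_conv_shape(x_shape, 1, channels_out)
    return checked_add(main_path_shape(x_shape, channels_out), shortcut)

def identity_shortcut_shape(x_shape, channels_out):
    # CONV_block's main path, but with X added unchanged (no short_cut conv)
    return checked_add(main_path_shape(x_shape, channels_out), x_shape)

print(conv_block_shape((32, 32, 64), 128))        # (32, 32, 128)
print(identity_shortcut_shape((32, 32, 64), 64))  # (32, 32, 64)
try:
    identity_shortcut_shape((32, 32, 64), 128)
except ValueError as err:
    print(err)  # shape mismatch: (32, 32, 128) vs (32, 32, 64)
```

The identity shortcut only works when the channel count does not change. That is exactly the situation ID_block is designed for.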
Full code
Since my data is fairly simple, I only trained a shallow ResNet. Depending on your own data, you can add more layers.
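The hypothetical helper below sketches one way to add layers by laying out blocks stage by stage. Each stage opens with one CONV_block that sets the new channel count, followed by ID_blocks that keep it. This stage pattern is an assumption modeled on the usual ResNet layout, not the author's exact network.

```python
def resnet_plan(stem_channels, stages):
    """Lay out residual blocks stage by stage (hypothetical helper).

    stages: list of (channels_out, num_id_blocks). Each stage starts with one
    CONV_block, which can change the channel count, followed by num_id_blocks
    ID_blocks, which keep it. Returns (block_name, channels_in, channels_out).
    """
    plan = []
    channels = stem_channels
    for channels_out, num_id_blocks in stages:
        plan.append(("CONV_block", channels, channels_out))
        channels = channels_out
        for _ in range(num_id_blocks):
            # ID_block's bottleneck width (kernel_channels) is a separate choice
            plan.append(("ID_block", channels, channels))
    return plan

shallow = resnet_plan(64, [(128, 1)])
deeper = resnet_plan(64, [(128, 2), (256, 2), (512, 2)])
for block in deeper:
    print(block)
```

To go deeper, you append stages or raise num_id_blocks. Adding an ID_block is always shape-safe, because it never changes the channel count.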
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
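import shutil
shutil.rmtree("logs", ignore_errors=True)  # delete the previous run's logs directory, if it exists
import tensorflow as tf
try:
    # %matplotlib inline; only available when running under Jupyter/IPython
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass  # plain Python script: no IPython magics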
import matplotlib.pyplot as plt
from PIL import Image
# import multiprocessing
from multiprocessing import Process
import threading
import time
# In[2]:
TrainPath = '/home/winsoul/disk/MyML/data/tfrecord/train.tfrecords'
ValPath = '/home/winsoul/disk/MyML/data/tfrecord/val.tfrecords'
# In[3]:
def read_tfrecord(TFRecordPath):
    with tf.Session() as sess:
        feature = {
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        }
        # filename_queue = tf.train.string_input_producer([TFRecordPath], num_epochs = 1)
        filename_queue = tf.train.string_input_producer([TFRecordPath]