因为比赛需要用到 ResNet，并且要自己写一个网络框架，所以我手写了一个 ResNet18。使用的数据集是 CIFAR-10，但最后跑出来的结果很差，准确率只有 50% 左右，并且震荡严重（有时候准确率还会下降）。有大神能帮忙指出问题出在哪里吗？感谢！
以下是我手写的 ResNet18 代码，CIFAR-10 数据集是手动下载的。
import numpy as np
import math
import cv2 as cv
import pickle
import tensorflow as tf
import matplotlib.pyplot as plt
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# mnist = input_data.read_data_sets('MNIST_data',one_hot = True)
# Parameter configuration.
batch_size = 32
train_step = 2000
def load_file(filename):
    """Unpickle and return the object stored in ``filename``.

    The file is opened in binary mode and decoded with ``encoding='latin1'``.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        Whatever object the pickle file contains.
    """
    # SECURITY: unpickling can execute arbitrary code. Only call this on
    # files obtained from a trusted source.
    with open(filename, 'rb') as fo:
        data = pickle.load(fo, encoding='latin1')
    return data
class Cifar:
    """In-memory batch provider for pickled batch files.

    Every file is read with ``load_file`` and must unpickle to a mapping
    with a ``'data'`` entry (stacked row-wise with ``np.vstack``) and a
    ``'labels'`` entry (concatenated with ``np.hstack``). The stacked data
    is rescaled with ``x / 127.5 - 1``.

    NOTE(review): rows are kept exactly as stored in the files; no reshape
    or transpose into an image layout happens here. Presumably the caller
    does that -- confirm the channel order it assumes matches the files.
    """

    def __init__(self, filename, need_shuffle):
        """Load all batch files and optionally shuffle the examples.

        Args:
            filename: Iterable of paths to pickled batch files. A single
                ``str`` path is also accepted and treated as a one-item
                list.
            need_shuffle: If true, shuffle once now and again every time
                ``next_batch`` runs out of data.

        Raises:
            ValueError: If no file is given.
        """
        if isinstance(filename, str):
            # Iterating a bare string would try to open each character.
            filename = [filename]

        all_data = []
        all_label = []
        for each in filename:
            temp = load_file(each)
            all_data.append(temp['data'])
            all_label.append(temp['labels'])
        if not all_data:
            raise ValueError('Cifar needs at least one data file')

        self._data = np.vstack(all_data) / 127.5 - 1
        self._label = np.hstack(all_label)
        self._example_num = self._data.shape[0]
        self._shuffle = need_shuffle
        self._begin = 0  # index of the first example of the next batch
        if self._shuffle:
            self._shuffle_data()

    def _shuffle_data(self):
        """Apply one random permutation to data and labels alike."""
        p = np.random.permutation(self._example_num)
        self._data = self._data[p]
        self._label = self._label[p]

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` examples and their labels.

        Batches never wrap around the end of the data: when fewer than
        ``batch_size`` examples remain and shuffling is enabled, the
        leftover tail is skipped, the data is reshuffled and the batch is
        taken from the start.

        Args:
            batch_size: Number of examples to return; must be positive and
                no larger than the number of loaded examples.

        Returns:
            Tuple ``(data, label)`` of slices with ``batch_size`` rows.

        Raises:
            ValueError: If ``batch_size`` is not positive or exceeds the
                number of loaded examples.
            IndexError: If shuffling is disabled and fewer than
                ``batch_size`` examples remain.
        """
        if batch_size <= 0:
            raise ValueError(
                'batch_size must be positive, got %r' % (batch_size,))
        if batch_size > self._example_num:
            # Checked up front so a hopeless request does not reshuffle.
            raise ValueError(
                'batch_size %d exceeds the number of examples %d'
                % (batch_size, self._example_num))

        end = self._begin + batch_size
        if end > self._example_num:
            if not self._shuffle:
                raise IndexError(
                    'no more data: fewer than %d examples remain and '
                    'shuffling is disabled' % batch_size)
            self._shuffle_data()
            self._begin = 0
            end = batch_size

        out_data = self._data[self._begin:end]
        out_label = self._label[self._begin:end]
        self._begin = end
        return out_data, out_label
def Resnet50(input,train_):
#第0层卷积
with tf.variable_scope('conv0_'):
w0 = tf.get_variable('conv0', [5,5,3,64], initializer = tf.truncated_normal_initializer(stddev = 0.1), dtype = tf.float16)
y0 = tf.nn.conv2d(input, w0, strides = [1,2,2,1], p