import tensorflow as tf
# import Ipynb_importer
from PIL import Image
import os
import matplotlib.pyplot as plt
from TFrecorder import get_batch_record
import numpy as np
from resnet import resnet18
# https://www.cnblogs.com/shiningstar/p/12758705.html
# SBSB
# 1、制作数据集
# 2、搭建网络结构,训练模型
# 3、反向传播,输入图片和标签,使得loss最小化的代码
# 4、测试集验证
# 5、单张图片预测
def weight_variable(shape):
    """Return a new trainable ``tf.Variable`` of the given shape.

    Its initial value is drawn from a truncated normal distribution with
    standard deviation 0.1.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
def bias_variable(shape):
    """Return a new trainable ``tf.Variable`` of the given shape, filled with 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))
def conv2d(input, filter, strides, padding="SAME"):
    """Thin wrapper around ``tf.nn.conv2d``.

    Args:
        input: input tensor passed straight to ``tf.nn.conv2d``.
        filter: convolution kernel tensor.
        strides: per-dimension strides, e.g. ``[1, 2, 2, 1]``.
        padding: ``"SAME"`` (default) or ``"VALID"``. Previously this argument
            was accepted but ignored and "SAME" was always used; it is now
            forwarded. Existing callers rely on the default, so they are
            unaffected.

    Returns:
        The convolution output tensor.
    """
    return tf.nn.conv2d(input, filter, strides, padding=padding)
def _basic_block(x, in_channels, out_channels, stride):
    """Build one residual "basic block": two 3x3 convs plus a shortcut, then ReLU.

    When the block changes the channel count or spatial size, the shortcut
    uses a 1x1 projection conv with the same stride. Otherwise the input is
    added back unchanged.

    Kernels are created in this order: first 3x3, second 3x3, then the 1x1
    projection. That matches the original unrolled network, so the
    auto-generated ``tf.Variable`` names (and therefore existing checkpoints)
    still line up.
    """
    block_strides = [1, stride, stride, 1]
    kernel_a = weight_variable([3, 3, in_channels, out_channels])
    branch = tf.nn.relu(conv2d(x, kernel_a, strides=block_strides))
    kernel_b = weight_variable([3, 3, out_channels, out_channels])
    branch = conv2d(branch, kernel_b, strides=[1, 1, 1, 1])
    if stride == 1 and in_channels == out_channels:
        shortcut = x
    else:
        kernel_proj = weight_variable([1, 1, in_channels, out_channels])
        shortcut = conv2d(x, kernel_proj, strides=block_strides)
    return tf.nn.relu(shortcut + branch)


def resnet18(input, num_classes=2):
    """Build a ResNet-18-style classifier graph and return its logits.

    Layout:
      * Stem: 7x7 stride-2 conv (with bias, ReLU) and 3x3 stride-2 max-pool.
      * Four stages of two basic blocks each, with 64/128/256/512 channels.
        The first block of stages 2-4 uses stride 2.
      * Global average pooling, then one fully connected layer.

    The graph contains no batch normalization.

    NOTE(review): this definition shadows the ``resnet18`` imported from
    ``resnet`` at the top of the file.

    Args:
        input: image tensor with 3 channels (the first kernel has
            in_channels=3). Presumably NHWC, ``tf.nn.conv2d``'s default
            data_format.
        num_classes: width of the output layer. Defaults to 2, the value
            previously hard-coded.

    Returns:
        Unnormalized logits of shape ``[batch, num_classes]``.
    """
    kernel_1 = weight_variable([7, 7, 3, 64])
    # Biases come from bias_variable, as for the final FC layer. The original
    # used weight_variable here; the shape (and variable name) is unchanged,
    # so old checkpoints still restore.
    bias_1 = bias_variable([64])
    stem = tf.nn.relu(conv2d(input, kernel_1, strides=[1, 2, 2, 1]) + bias_1)
    net = tf.nn.max_pool(stem, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding="SAME")

    # (in_channels, out_channels, stride) for each of the 8 basic blocks.
    block_specs = (
        (64, 64, 1), (64, 64, 1),
        (64, 128, 2), (128, 128, 1),
        (128, 256, 2), (256, 256, 1),
        (256, 512, 2), (512, 512, 1),
    )
    for in_ch, out_ch, stride in block_specs:
        net = _basic_block(net, in_ch, out_ch, stride)

    # Global average pooling over height and width.
    # For 224x224 inputs the final map is 7x7, so this equals the former
    # 7x7 VALID avg_pool. For other input sizes, the fixed window followed by
    # reshape([-1, 512]) would have folded leftover spatial positions into
    # the batch dimension.
    pooled = tf.reduce_mean(net, axis=[1, 2])

    fc_weights = weight_variable([512, num_classes])
    fc_bias = bias_variable([num_classes])
    return tf.matmul(pooled, fc_weights) + fc_bias
# Training hyper-parameters and data locations.
batch_size = 20
filename = "data/record/train.tfrecords"
filename_test = "data/record/test.tfrecords"
num_classes = 2
img_w = 224
img_h = 224
# Checkpoint prefix used both for resuming and for saving.
checkpoint_path = "net/my_resnet18.ckpt"

# NHWC: (batch, height, width, channels).
x = tf.placeholder(tf.float32, [None, img_h, img_w, 3])
y = tf.placeholder(tf.float32, [None, num_classes])
prediction = resnet18(x)  # unnormalized logits
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
# argmax picks the highest-scoring class for both one-hot labels and logits.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

image_batch, label_batch = get_batch_record(filename, batch_size, img_w, img_h)
image_batch_test, label_batch_test = get_batch_record(filename_test, batch_size, img_w, img_h)

init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    # Resume only if a previous run left a checkpoint. An unconditional
    # restore fails on the very first run.
    latest = tf.train.latest_checkpoint(os.path.dirname(checkpoint_path))
    if latest:
        saver.restore(sess, latest)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        for i in range(302):
            image, label = sess.run([image_batch, label_batch])
            # Fetch the update and the loss in one run instead of doing a
            # second forward pass. The reported loss is the one computed on
            # this batch for the update.
            _, l = sess.run([train_step, loss], feed_dict={x: image, y: label})
            if i % 20 == 0:
                print("iter: "+str(i)+" loss "+str(l))
        # Saver.save does not create missing directories.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        saver.save(sess, checkpoint_path)
    finally:
        # Always stop and join the input threads, even if training fails;
        # otherwise they can keep the process alive.
        coord.request_stop()
        coord.join(threads)
# Root directory holding one sub-directory of images per class.
# Raw string: in a normal literal "\r" is a carriage return, which corrupted
# the path "D:\code\resnet\data".
path = r"D:\code\resnet\data"
train_record_path = "data/record/train.tfrecords"
test_record_path = "data/record/test.tfrecords"
# The two manually chosen classes. A tuple (not a set) so iteration order,
# and hence the label index of each class, is the same in every process.
classes = ('bottle', 'paper')
def _byteslist(value):
    """Wrap one bytes value as a ``tf.train.Feature`` holding a BytesList."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
def _int64list(value):
    """Wrap one integer as a ``tf.train.Feature`` holding an Int64List."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
def create_train_record():
    """Write the training split of every class to ``train_record_path``.

    For each class directory ``path/<name>``, the first 70% (rounded down) of
    its files, in sorted filename order, are written as examples with:
      * ``'label'``: the class's index in ``sorted(classes)``;
      * ``'img_raw'``: raw RGB bytes of the image resized to 224x224.

    ``create_test_record`` writes the remaining 30% using the same ordering,
    so the two splits are disjoint.
    """
    num = 1
    with tf.python_io.TFRecordWriter(train_record_path) as writer:
        # sorted(): a deterministic class->label mapping in every process.
        for index, name in enumerate(sorted(classes)):
            class_path = os.path.join(path, name)
            # os.listdir order is arbitrary. Sort so train/test splits agree.
            image_names = sorted(os.listdir(class_path))
            split = int(len(image_names) * 0.7)
            print("create tf "+str(index))
            for image_name in image_names[:split]:
                image_path = os.path.join(class_path, image_name)
                with Image.open(image_path) as img:
                    # Force 3 channels so every record holds exactly
                    # 224*224*3 bytes, which read_record reshapes to.
                    img_raw = img.convert("RGB").resize((224, 224)).tobytes()
                example = tf.train.Example(
                    features=tf.train.Features(feature={
                        'label': _int64list(index),
                        'img_raw': _byteslist(img_raw)}))
                writer.write(example.SerializeToString())
                print('creat train record in ', num)
                num += 1
    print('creat_train_record success !')
def create_test_record():
    """Write the test split of every class to ``test_record_path``.

    For each class directory ``path/<name>``, the files after the first 70%
    (rounded down), in sorted filename order, are written as examples with:
      * ``'label'``: the class's index in ``sorted(classes)``;
      * ``'img_raw'``: raw RGB bytes of the image resized to 224x224.

    The ordering matches ``create_train_record``, so the splits are disjoint.
    """
    num = 1
    with tf.python_io.TFRecordWriter(test_record_path) as writer:
        # sorted(): a deterministic class->label mapping in every process.
        for index, name in enumerate(sorted(classes)):
            class_path = os.path.join(path, name)
            # os.listdir order is arbitrary. Sort so train/test splits agree.
            image_names = sorted(os.listdir(class_path))
            split = int(len(image_names) * 0.7)
            for image_name in image_names[split:]:
                image_path = os.path.join(class_path, image_name)
                with Image.open(image_path) as img:
                    # Force 3 channels so every record holds exactly
                    # 224*224*3 bytes, which read_record reshapes to.
                    img_raw = img.convert("RGB").resize((224, 224)).tobytes()
                example = tf.train.Example(
                    features=tf.train.Features(feature={
                        'label': _int64list(index),
                        'img_raw': _byteslist(img_raw)}))
                writer.write(example.SerializeToString())
                print('creat test record in', num)
                num += 1
    print('creat_test_record success !')
def read_record(filename, img_w, img_h):
    """Build queue-based ops that read one (image, label) example from a TFRecord.

    Each record must contain an int64 ``'label'`` and an ``'img_raw'`` byte
    string of 224*224*3 uint8 values (as written by the record-creation
    functions).

    Args:
        filename: path of the TFRecord file. It is read through
            ``tf.train.string_input_producer`` with no epoch limit.
        img_w: target width after center crop/pad.
        img_h: target height after center crop/pad.

    Returns:
        ``(img, label)``. ``img`` is a float32 tensor ``[img_h, img_w, 3]``
        scaled to [0, 1]; ``label`` is an int32 scalar tensor.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialize_example = reader.read(filename_queue)
    feature = tf.parse_single_example(
        serialize_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)})
    label = feature['label']
    img = tf.decode_raw(feature['img_raw'], tf.uint8)
    img = tf.reshape(img, (224, 224, 3))
    # resize_image_with_crop_or_pad takes (target_height, target_width). The
    # original passed (img_w, img_h), which transposed non-square targets.
    img = tf.image.resize_image_with_crop_or_pad(img, img_h, img_w)
    img = tf.cast(img, tf.float32) / 255
    label = tf.cast(label, tf.int32)
    return img, label
def get_batch_record(filename, batch_size, img_W, img_H, num_classes=2, min_after_dequeue=10):
    """Return shuffled batches of images and one-hot labels from a TFRecord file.

    Args:
        filename: TFRecord path, read by ``read_record``.
        batch_size: number of examples per batch.
        img_W: target image width.
        img_H: target image height.
        num_classes: depth of the one-hot labels. Defaults to 2, the value
            previously hard-coded.
        min_after_dequeue: minimum number of examples kept in the shuffle
            buffer. Defaults to 10, as before.

            NOTE(review): the record writers store examples class by class,
            so a small buffer gives poorly mixed batches. Consider raising it
            if memory allows.

    Returns:
        ``(image_batch, label_batch)``:
          * ``image_batch``: ``[batch_size, img_H, img_W, 3]`` float32.
          * ``label_batch``: ``[batch_size, num_classes]`` one-hot float32.
    """
    image, label = read_record(filename, img_W, img_H)
    # shuffle_batch can only dequeue once the queue holds
    # min_after_dequeue + batch_size elements, so a smaller capacity
    # deadlocks. The old fixed capacity of 30 hung for batch_size > 20. This
    # is the sizing recommended in the TF documentation.
    capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    label_batch = tf.one_hot(label_batch, depth=num_classes)
    return image_batch, label_batch
# Debug snippet (disabled): read a few records and display them.
# if __name__ == '__main__':
#     img, label = get_batch_record(test_record_path, 1, 224, 224)
#     print(img)
#
#     img, label = get_batch_record(test_record_path, 2, 224, 224)
#
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         sess.run(tf.local_variables_initializer())
#         coord = tf.train.Coordinator()
#         threads = tf.train.start_queue_runners(sess, coord)
#         for i in range(200):
#             image, l = sess.run([img, label])
#             print(image[0].shape)
#             # print(image[1].shape)
#             print(l[0])
#             plt.imshow(image[0])
#             plt.show()
#         coord.request_stop()
# 具体是谁就不说了。反正网上信息很多，真真假假，难以实用。也就那样，随他们去吧。
# 又更新了，库文件里好多东西又要变了。烦，各种 bug……