ResNet学习笔记(四)——train_imagenet.py

最后是第四部分,作为最上层的train_imagenet.py,从顶层控制神经网络的运作。

import skimage.io  # bug. need to import this before tensorflow
import skimage.transform  # bug. need to import this before tensorflow
from resnet_train import train
import tensorflow as tf
import time
import os
import sys
import re
import numpy as np
import pdb

from synset import *
from image_processing import image_preprocessing
from resnet import inference

FLAGS = tf.app.flags.FLAGS
# Directory in which the training images are stored.
tf.app.flags.DEFINE_string(
    'data_dir',
    '/home/modric/Downloads/train',
    'imagenet dir',
)

# Collect the paths of all images named in the listing file.
def file_list(data_dir):
    """Return the entries of ``data_dir + ".txt"``, each joined onto ``data_dir``.

    The listing is read line by line. Lines whose first character is ``.``
    are skipped; every other line has its trailing whitespace stripped and
    is joined onto ``data_dir`` with ``os.path.join``.
    """
    with open(data_dir + ".txt", 'r') as listing:
        return [
            os.path.join(data_dir, entry.rstrip())
            for entry in listing
            if not entry.startswith('.')
        ]

# Load the dataset description.
def load_data(data_dir):
    """Build the list of training examples listed for ``data_dir``.

    Paths are obtained from ``file_list(data_dir)``. Entries whose extension
    is not exactly ``.jpg`` are ignored. For each remaining path the label
    name is the first ``n<digits>`` token found anywhere in the path, and it
    is looked up in ``synset_map`` / ``synset`` (presumably provided by the
    ``from synset import *`` at the top of the file).

    Args:
        data_dir: image root directory; the listing ``data_dir + ".txt"`` is
            read by ``file_list``.

    Returns:
        A list of dicts with the keys ``"filename"``, ``"label_name"``,
        ``"label_index"`` and ``"desc"``.

    Raises:
        ValueError: if a ``.jpg`` path contains no ``n<digits>`` token.
    """
    # Single-argument print(...) emits the same text on Python 2 and 3.
    print("listing files in %s" % data_dir)
    start_time = time.time()
    files = file_list(data_dir)
    duration = time.time() - start_time
    print("took %f sec" % duration)  # time spent listing the files

    data = []
    for img_fn in files:
        # Only files with a lower-case .jpg extension are loaded.
        if os.path.splitext(img_fn)[1] != '.jpg':
            continue

        match = re.search(r'(n\d+)', img_fn)
        if match is None:
            raise ValueError(
                "no label id (n<digits>) found in path: %s" % img_fn)
        label_name = match.group(1)

        # The lookup is deliberately not guarded: an unknown label still
        # surfaces as an error from synset_map.
        label_index = synset_map[label_name]["index"]

        data.append({
            # file_list() already joined data_dir onto the entry; joining it
            # again would double the prefix whenever data_dir is relative.
            "filename": img_fn,
            "label_name": label_name,
            "label_index": label_index,
            "desc": synset[label_index],
        })

    return data

# Build the distorted training input pipeline.
def distorted_inputs():
    """Assemble a batch of preprocessed training images and their labels.

    File names and label indexes come from ``load_data(FLAGS.data_dir)``.
    Each image is read, passed through ``image_preprocessing``, resized to
    300x300 and then cropped/padded to 224x224 before batching.

    Returns:
        A pair ``(images, labels)``: ``images`` is cast to float32 and
        reshaped to ``[batch_size, input_size, input_size, 3]``; ``labels``
        is reshaped to ``[batch_size]``.

    NOTE(review): ``FLAGS.batch_size`` and ``FLAGS.input_size`` are not
    defined in this file -- presumably by one of the imported modules;
    confirm. The crop is fixed at 224 while the reshape uses
    ``FLAGS.input_size``, so the two only agree when ``input_size`` is 224.
    """
    records = load_data(FLAGS.data_dir)
    paths = [rec['filename'] for rec in records]
    label_ids = [rec['label_index'] for rec in records]

    # Produces one shuffled (path, label) pair at a time from the two lists.
    path_t, label_t = tf.train.slice_input_producer(
        [paths, label_ids], shuffle=True)

    worker_count = 1  # number of preprocessing threads
    per_worker = []
    for worker in range(worker_count):
        raw_bytes = tf.read_file(path_t)  # file contents as a buffer

        # Empty bbox list, train flag set; the original author describes
        # this call as distorting the image.
        picture = image_preprocessing(raw_bytes, [], True, worker)
        # The author's inputs are not 224x224, hence resize then crop/pad.
        picture = tf.image.resize_images(picture, [300, 300])
        picture = tf.image.resize_image_with_crop_or_pad(picture, 224, 224)
        per_worker.append([picture, label_t])

    batch_images, batch_labels = tf.train.batch_join(
        per_worker,
        batch_size=FLAGS.batch_size,
        capacity=2 * worker_count * FLAGS.batch_size)

    side = FLAGS.input_size
    channels = 3

    batch_images = tf.cast(batch_images, tf.float32)
    batch_images = tf.reshape(
        batch_images, shape=[FLAGS.batch_size, side, side, channels])
    batch_labels = tf.reshape(batch_labels, [FLAGS.batch_size])

    return batch_images, batch_labels


def main(_):
    """Build the input pipeline and the network, then start training.

    The single positional argument is ignored.

    NOTE(review): the ``is_training`` placeholder is handed to ``train``
    only; ``inference`` receives the constant ``True``. Confirm that this is
    intended.
    """
    images, labels = distorted_inputs()
    is_training = tf.placeholder('bool', [], name='is_training')

    # Network configuration passed to inference().
    net_options = dict(
        num_classes=1000,
        is_training=True,
        bottleneck=False,
        num_blocks=[2, 2, 2, 2],
    )
    logits = inference(images, **net_options)

    train(is_training, logits, images, labels)


# Script entry point: hand control to TensorFlow's app runner
# (presumably parses the flags and then calls main -- see the TF 1.x docs).
if __name__ == '__main__':
    tf.app.run()

通过对ResNet的学习,可以对整个神经网络的组成结构和主要操作有一个大致的了解,有助于下一步的学习。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值