python tensorflow学习(四)加大难度!AlexNet模型实现

本文介绍了如何在TensorFlow中实现AlexNet模型,详细讲解了从猫狗大战数据集的预处理,转换为TFRecord格式,到网络模型架构的搭建,包括模型训练、存储及结果可视化的过程。通过实例展示了深度卷积神经网络在实际问题中的应用。
摘要由CSDN通过智能技术生成

加大难度!AlexNet模型实现

Alex 网络[8]是现代意义上的深度卷积神经网络的起源,是在 2012 年被推出
的一个经典网络模型,取得了同年 ImageNet 比赛的最优成绩。相比于上一章的 LeNet-5,AlexNet 的层次显著加深,参数规模也显著变大。
大纲

  • 猫狗大战数据集处理
    • 数据集的加工
    • 转换为TFRecord格式
  • Alex网络模型实现
    • 网络模型架构
    • 模型的训练存储及结果可视化

猫狗大战数据集处理

猫狗大战的数据集来源于Kaggle上的一个竞赛:Dogs vs. Cats。该数据集包括12500张猫的图片以及12500张狗的图片,是一个二分类问题。官方提供了免费下载:下载地址点这里.如果不想注册账号,还有微软的版本:下载地址二
不同于MNIIST数据集,这个数据集均来自于真实的照片,tensorflow中也没有封装好的函数来读取该数据,所以需要对该数据进行预处理。

数据集的加工

首先,AlexNet模型的输入图像大小为227×227×3,所以需要把数据集的分辨率调整为该大小,这里使用opencv进行处理

# 把图片大小转换为227x227x3
def rebuild(dir):
    for root, dirs, files in os.walk(dir):
        for file in files:
            try:
                filepath = os.path.join(root, file)
                image = cv2.imread(filepath)
                dim = (227, 227)
                resized = cv2.resize(image, dim)
                print(file)
                path = "D:/cat_and_dog/Cat/"+file
                cv2.imwrite(path, resized)
            except IOError:
                print(filepath)
                os.remove(filepath)
        cv2.waitKey(0)   # 退出

rebuild("D:/PetImages/Cat")

对于损坏的数据,这里使用os.remove()直接删除。

转换为TFRecord格式

在第二章介绍过了TFRecord文件的创建和读取,传送门->python tensorflow学习(二) tensorflow数据的生成与读取 这里不再介绍,直接贴源代码:
获取数据集和标签:

# 设置标签
def get_file(file_dir):
    images = []
    temp = []
    for root, sub_folders, files in os.walk(file_dir):
        for name in files:
            images.append(os.path.join(root, name))

        for name in sub_folders:
            temp.append(os.path.join(root, name))
    labels = []
    for one_folder in temp:
        n_img = len(os.listdir(one_folder))
        letter = one_folder.split("\\")[-1]
        if letter=='Cat':
            labels = np.append(labels, n_img*[0])
        else:
            labels = np.append(labels, n_img*[1])

    temp = np.array([images, labels])
    temp = temp.transpose()
    np.random.shuffle(temp)

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(float(i)) for i in label_list]

    return image_list, label_list

imagelist, labellist = get_file("D:/cat_and_dog")

转换为TFRecord文件:

# 生成TFRecord文件
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def ToTFRecord(image_list, label_list, save_dir, name):
    filename = os.path.join(save_dir, name+'.tfrecords')
    n_samples = len(imagelist)
    writer = tf.python_io.TFRecordWriter(filename)
    print("transform start...")
    for i in np.arange(0, n_samples):
        try:
            image = cv2.imread(image_list[i])
            image_raw = image.tostring()
            label = [int(label_list[i])]
            example = tf.train.Example(features=tf.train.Features(feature={
   
                'label':int64_feature(label),
                'image_raw':bytes_feature(image_raw)
            }))
            writer.write(example.SerializeToString())
        except IOError as e:
            print('could not read:', image_list[i])
    writer.close()
    print('transform done!')

其中save_dirname分别是存储路径和文件名。
至此TFRecord文件已经生成,还需要一个能读取该文件的函数来获取数据:

#  读取数据
def read_and_decode(tfrecord_file, batch_size):
    filename_queue = tf.train.string_input_producer([tfrecord_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
        serialized_example,
        features={
   
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)
    image = tf.reshape(image, [227, 227, 3])
    label = tf.cast(img_features['label'], tf.int32)
    image_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                      batch_size=batch_size,
                                                      min_after_dequeue=100,
                                                      num_threads=64,
                                                      capacity=200
                                                      )
    return image_batch, tf.reshape(label_batch, [batch_size])

image_batch, label_batch = read_and_decode('cat_and_dog.tfrecords', 25)

该函数每次读取指定数目的数据集以便训练时提供。

网络模型架构

在实现每一层的架构之前,先对参数进行集中管理,这是一个很好的习惯:

# 集中管理参数
learning_rate = 1e-4  # 学习速率
training_iters = 200  # 迭代次数
batch_size = 50       # 每批的大小 
n_classes = 2         # 种类
n_fc1 = 4096          
n_fc2 = 2048          

# 构建模型
x = tf.placeholder(tf.float32, [None, 227, 227, 3])
y = tf.placeholder(tf.int32, [None, n_classes])

W_conv = {
   
    'conv1': tf.Variable(tf.truncated_normal([11, 11, 3, 96], stddev=0.0001)),
    'conv2': tf.Variable(tf.truncated_normal
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值