Tensorflow VGG16训练cifar10

小白一枚,话不多说,直接上码,亲测能跑

预处理:

import numpy as np
import tensorflow as tf
import pickle
import os
import random


def _random_crop(batch, crop_shape, padding=None):
    oshape = batch[0].shape

    if padding:
        oshape = (oshape[0] + 2 * padding, oshape[1] + 2 * padding)
    new_batch = []
    npad = ((padding, padding), (padding, padding), (0, 0))
    for i in range(len(batch)):
        new_batch.append(batch[i])
        if padding:
            new_batch[i] = np.lib.pad(batch[i], pad_width=npad, mode='constant', constant_values=0)
        nh = random.randint(0, oshape[0] - crop_shape[0])
        nw = random.randint(0, oshape[1] - crop_shape[1])
        new_batch[i] = new_batch[i][nh:nh + crop_shape[0], nw:nw + crop_shape[1]]
    return new_batch


def _batch_random_flip_left_right(batch):
    new_batch = []
    for img in batch:
        if bool(random.randint(0, 1)):
            img = np.fliplr(img)
        new_batch.append(img)
    new_batch = np.array(new_batch)
    return new_batch


def training_data(data_dir):
    data_lst = os.listdir(data_dir)
    img = None
    labels = None
    for file in data_lst:
        if file in ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5']:
            new_path = os.path.join(data_dir, file)
            with open(new_path, 'rb') as f:
                dict = pickle.load(f)
            if img is None:
                img = dict['data']
            else:
                img = np.vstack((img, dict['data']))
            if labels is None:
                labels = dict['labels']
            else:
                labels += dict['labels']
    cnt = 0
    images = []

    for image in img:
        r = image[:1024]
        g = image[1024:2048]
        b = image[2048:]
        r = np.array([r])
        g = np.array([g])
        b = np.array([b])
        r_t = r.T
        g_t = g.T
        b_t = b.T
        new_image = np.hstack((r_t, g_t, b_t))
        new_image = new_image.reshape([32, 32, 3])
        images.append(new_image)

        cnt += 1
    images = np.array(images)
    print '========Finish loading training data========'
    return images, labels


def testing_data(data_dir):
    data_lst = os.listdir(data_dir)
    img = None
    labels = None
    for file in data_lst:
        if file == 'test_batch':
            new_path = os.path.join(data_dir, file)
            with open(new_path, 'rb') as f:
                dict = pickle.load(f)
            if img is None:
                img = dict['data'
  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值