theano-xnor-net代码注释9 pylearn2/cifar10.py

"""
CIFAR-10 dataset for pylearn2.

Defines :class:`CIFAR10`, a ``DenseDesignMatrix`` backed by the
python-pickled CIFAR-10 batch files found under
``${PYLEARN2_DATA_PATH}/cifar10/cifar-10-batches-py``.
"""
import os
import logging

import numpy
from theano.compat.six.moves import xrange

from pylearn2.datasets import cache, dense_design_matrix
from pylearn2.expr.preprocessing import global_contrast_normalize
from pylearn2.utils import contains_nan
from pylearn2.utils import serial
from pylearn2.utils import string_utils


# Module-level logger; used to report which batch files are being loaded.
_logger = logging.getLogger(__name__)


class CIFAR10(dense_design_matrix.DenseDesignMatrix):

    """
    The CIFAR-10 image classification dataset as a dense design matrix.

    Data are read from the python-pickled batches in
    ``${PYLEARN2_DATA_PATH}/cifar10/cifar-10-batches-py``:
    ``data_batch_1`` ... ``data_batch_5`` make up the 50000-example training
    set and ``test_batch`` the 10000-example test set. Each example is one
    row of ``3 * 32 * 32 = 3072`` pixel values (stored as uint8 on disk,
    converted to float32 here); each label is an index into
    ``label_names``. There is no official validation set; carve one out of
    the training set with ``start`` / ``stop``.

    Parameters
    ----------
    which_set : str
        One of 'train', 'test'
    center : bool, optional
        If True, subtract 127.5 from every pixel value.
    rescale : bool, optional
        If True, divide every pixel value by 127.5 (applied after
        ``center``).
    gcn : float, optional
        Multiplicative constant to use for global contrast normalization.
        No global contrast normalization is applied, if None
    start : int, optional
        First example (inclusive) to keep. Defaults to 0 when only ``stop``
        is given.
    stop : int, optional
        Last example (exclusive) to keep. Defaults to the number of examples
        when only ``start`` is given. Slicing is done after all of the
        preprocessing above.
    axes : tuple, optional
        Axis order of the topological view, passed to
        ``DefaultViewConverter``.
    toronto_prepro : bool, optional
        If True, divide pixels by 255 and subtract the per-pixel mean of the
        full training set (the test set is also centered with the training
        mean). Cannot be combined with ``center`` or ``gcn``.
    preprocessor : object, optional
        If given, ``preprocessor.apply(self)`` is called once the dataset is
        built.

    Raises
    ------
    ValueError
        If ``which_set`` is not 'train' or 'test'.
    IOError
        If a required batch file is missing.
    """

    # Sizes of the official splits and of each pickled batch file.
    _N_TRAIN = 50000
    _N_TEST = 10000
    _BATCH_SIZE = 10000
    _TRAIN_FILES = tuple('data_batch_%i' % i for i in range(1, 6))
    _TEST_FILES = ('test_batch',)

    def __init__(self, which_set, center=False, rescale=False, gcn=None,
                 start=None, stop=None, axes=('b', 0, 1, 'c'),
                 toronto_prepro=False, preprocessor=None):
        # note: there is no such thing as the cifar10 validation set;
        # pylearn1 defined one but really it should be user-configurable
        # (as it is here, via start / stop)
        if which_set not in ('train', 'test'):
            raise ValueError("which_set must be 'train' or 'test', got %r"
                             % (which_set,))

        self.axes = axes

        # Geometry and label metadata exposed to callers.
        self.img_shape = (3, 32, 32)
        self.img_size = numpy.prod(self.img_shape)  # 3 * 32 * 32 = 3072
        self.n_classes = 10
        self.label_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                            'dog', 'frog', 'horse', 'ship', 'truck']

        raw_X, y = self._load_raw(which_set)
        # numpy.array always copies, so the in-place preprocessing below
        # never touches the loaded batch data. (numpy.cast, used here
        # previously, no longer exists in NumPy 2.0.)
        X = numpy.array(raw_X, dtype='float32')

        # Check the full split size before any start/stop slicing, so that
        # slicing the test set is allowed.
        if which_set == 'test':
            assert X.shape[0] == self._N_TEST

        if center:
            X -= 127.5
        self.center = center

        if rescale:
            X /= 127.5
        self.rescale = rescale

        if toronto_prepro:
            assert not center
            assert not gcn
            X = X / 255.
            if which_set == 'test':
                # The test set is centered with the *training* pixel means.
                train_X = CIFAR10(which_set='train').X / 255.
                X = X - train_X.mean(axis=0)
            else:
                X = X - X.mean(axis=0)
        self.toronto_prepro = toronto_prepro

        self.gcn = gcn
        if gcn is not None:
            gcn = float(gcn)
            X = global_contrast_normalize(X, scale=gcn)

        if start is not None or stop is not None:
            # This needs to come after the prepro so that it doesn't
            # change the pixel means computed above for toronto_prepro.
            start = 0 if start is None else start
            stop = X.shape[0] if stop is None else stop
            assert 0 <= start < stop <= X.shape[0]
            X = X[start:stop, :]
            y = y[start:stop, :]
            assert X.shape[0] == y.shape[0]

        view_converter = dense_design_matrix.DefaultViewConverter((32, 32, 3),
                                                                  axes)

        super(CIFAR10, self).__init__(X=X, y=y, view_converter=view_converter,
                                      y_labels=self.n_classes)

        assert not contains_nan(self.X)

        if preprocessor:
            preprocessor.apply(self)

    def _load_raw(self, which_set):
        """
        Load the raw pixels and labels of one split.

        Only the batch files belonging to `which_set` are cached and read.

        Returns
        -------
        X : ndarray
            One row of ``self.img_size`` pixel values per example.
        y : ndarray
            Labels with shape ``(n_examples, 1)``.
        """
        datapath = os.path.join(
            string_utils.preprocess('${PYLEARN2_DATA_PATH}'),
            'cifar10', 'cifar-10-batches-py')
        names = self._TRAIN_FILES if which_set == 'train' else self._TEST_FILES

        # Fail fast, before any loading, if a needed file is missing.
        paths = []
        for name in names:
            fname = os.path.join(datapath, name)
            if not os.path.exists(fname):
                raise IOError(fname + " was not found. You probably need to "
                              "download the CIFAR-10 dataset by using the "
                              "download script in "
                              "pylearn2/scripts/datasets/download_cifar10.sh "
                              "or manually from "
                              "http://www.cs.utoronto.ca/~kriz/cifar.html")
            paths.append(cache.datasetCache.cache_file(fname))

        if which_set == 'train':
            # Fill preallocated arrays one batch at a time so that only one
            # unpickled batch is held in memory at once.
            X = numpy.zeros((self._N_TRAIN, self.img_size), dtype='uint8')
            y = numpy.zeros((self._N_TRAIN, 1), dtype='uint8')
            for i, path in enumerate(paths):
                _logger.info('loading file %s', path)
                # Each unpickled batch is a dict; 'data' and 'labels' are the
                # keys used here.
                data = serial.load(path)
                rows = slice(i * self._BATCH_SIZE, (i + 1) * self._BATCH_SIZE)
                X[rows, :] = data['data']
                y[rows, 0] = data['labels']
            return X, y

        _logger.info('loading file %s', paths[0])
        data = serial.load(paths[0])
        X = data['data'][0:self._N_TEST]
        y = data['labels'][0:self._N_TEST]
        # The labels may be stored as a plain python list.
        if isinstance(y, list):
            y = numpy.asarray(y).astype('uint8')
        assert y.shape[0] == self._N_TEST
        return X, y.reshape((y.shape[0], 1))

    def _patch_old_pickle(self):
        """
        Give instances unpickled from old files the display attributes.

        Missing ``gcn`` is patched to None ("no GCN"). Patching it to False
        would make the ``self.gcn is not None`` checks treat the data as
        contrast-normalized.
        """
        if not hasattr(self, 'center'):
            self.center = False
        if not hasattr(self, 'rescale'):
            self.rescale = False
        if not hasattr(self, 'gcn'):
            self.gcn = None

    def _to_unit_range(self, rval):
        """
        Map pixel values from the non-GCN range into [-1, 1], in place.

        Raw pixels lie in [0, 255]. Centering shifts them to
        [-127.5, 127.5]; rescaling alone gives [0, 2]; both give [-1, 1].
        The result is clipped to [-1, 1].
        """
        if not self.center:
            # An un-centered but rescaled image lies in [0, 2].
            rval -= 1. if self.rescale else 127.5
        if not self.rescale:
            rval /= 127.5
        return numpy.clip(rval, -1., 1.)

    def adjust_for_viewer(self, X):
        """
        Return a copy of `X` with values mapped into [-1, 1] for display.

        With GCN, each example is divided by its own maximum absolute
        value. Otherwise the ``center`` / ``rescale`` settings are undone as
        needed. Other preprocessing (e.g. ``toronto_prepro``, preprocessors)
        is not accounted for.
        """
        self._patch_old_pickle()
        rval = X.copy()

        if self.gcn is not None:
            for i in xrange(rval.shape[0]):
                rval[i, :] /= numpy.abs(rval[i, :]).max()
            return rval

        return self._to_unit_range(rval)

    def __setstate__(self, state):
        super(CIFAR10, self).__setstate__(state)
        # Patch old pkls: older versions stored y as a 1-d array and had no
        # y_labels attribute.
        if self.y is not None and self.y.ndim == 1:
            self.y = self.y.reshape((self.y.shape[0], 1))
        if 'y_labels' not in state:
            self.y_labels = 10

    def adjust_to_be_viewed_with(self, X, orig, per_example=False):
        """
        Like `adjust_for_viewer`, but with the GCN scale taken from `orig`.

        This lets `X` (e.g. a reconstruction) be displayed on the same scale
        as `orig`. With GCN and ``per_example=True``, row ``i`` of `X` is
        divided by the maximum absolute value of row ``i`` of `orig`.
        Otherwise a single global maximum over `orig` is used. The result is
        clipped to [-1, 1].
        """
        self._patch_old_pickle()
        rval = X.copy()

        if self.gcn is not None:
            if per_example:
                for i in xrange(rval.shape[0]):
                    rval[i, :] /= numpy.abs(orig[i, :]).max()
            else:
                rval /= numpy.abs(orig).max()
            return numpy.clip(rval, -1., 1.)

        return self._to_unit_range(rval)

    def get_test_set(self):
        """
        Return the CIFAR-10 test set, built with this dataset's settings.

        ``center``, ``rescale``, ``gcn``, ``toronto_prepro`` and ``axes``
        are carried over. ``start``, ``stop`` and ``preprocessor`` are not.
        """
        return CIFAR10(which_set='test', center=self.center,
                       rescale=self.rescale, gcn=self.gcn,
                       toronto_prepro=self.toronto_prepro,
                       axes=self.axes)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值