批训练——根据mnist数据集的批训练源代码改编

因为数据集为图像,且图像尺寸比较大,在 GPU 环境下一次送入的训练样本数不能过大,否则会造成显存溢出。为了让训练出的模型更有效,我采用了批训练(mini-batch)方式。

我在mnist.py的基础上,针对实验的需求进行了修改。

import numpy
from tensorflow.python.framework import dtypes, random_seed


class DataSet(object):
  """
  In-memory container for an array of images that serves mini-batches.

  The images are indexed along their first axis; `next_batch` walks through
  them in order, optionally reshuffling every time an epoch is finished.
  """

  def __init__(self,
               images,
               dtype=dtypes.float32,
               seed=None):
    """
    Construct a DataSet.

    Args:
      images: array exposing `.shape` and supporting slicing / integer-array
        indexing along its first axis (a `numpy.ndarray` in practice).
      dtype: must be `uint8` or `float32`. It is only validated; the images
        are stored exactly as passed in and are NOT rescaled.
      seed: forwarded to `random_seed.get_seed`; the result is used to seed
        the global `numpy.random` generator for deterministic testing.

    Raises:
      TypeError: if `dtype` is neither `uint8` nor `float32`.
    """
    seed1, seed2 = random_seed.get_seed(seed)
    # If op level seed is not set, use whatever graph level seed is returned.
    # NOTE: this reseeds the process-wide numpy generator, not a private one.
    numpy.random.seed(seed1 if seed is None else seed2)
    dtype = dtypes.as_dtype(dtype).base_dtype
    if dtype not in (dtypes.uint8, dtypes.float32):
      raise TypeError(
          'Invalid image dtype %r, expected uint8 or float32' % dtype)

    # NOTE: unlike the MNIST loader this class was adapted from, no
    # [0, 255] -> [0.0, 1.0] conversion is applied for float32.

    self._images = images
    self._num_examples = images.shape[0]
    self._epochs_completed = 0
    self._index_in_epoch = 0

  @property
  def images(self):
    """The stored images, in their current (possibly shuffled) order."""
    return self._images

  @property
  def num_examples(self):
    """Number of examples, i.e. the size of the first axis of `images`."""
    return self._num_examples

  @property
  def epochs_completed(self):
    """How many full passes over the data `next_batch` has finished."""
    return self._epochs_completed

  def _shuffle_images(self):
    """Reorder the stored images with a fresh random permutation."""
    perm = numpy.arange(self._num_examples)
    numpy.random.shuffle(perm)
    self._images = self._images[perm]

  def next_batch(self, batch_size, shuffle=True):
    """Return the next `batch_size` examples from this data set.

    When the current epoch does not hold enough examples, the remainder of
    the epoch is returned first and the batch is completed from the following
    epoch(s). `batch_size` may exceed `num_examples`; the data is then cycled
    (and reshuffled per epoch when `shuffle` is true) as often as needed.

    Args:
      batch_size: number of examples to return.
      shuffle: if true, shuffle before the very first batch and again each
        time an epoch is completed.

    Returns:
      An array with `batch_size` examples along the first axis. For an empty
      data set an empty array is returned.
    """
    start = self._index_in_epoch
    # Shuffle for the first epoch
    if self._epochs_completed == 0 and start == 0 and shuffle:
      self._shuffle_images()

    end = start + batch_size
    if end <= self._num_examples:
      # The whole batch fits in the current epoch.
      self._index_in_epoch = end
      return self._images[start:end]

    if self._num_examples == 0:
      # Nothing to draw from: hand back an empty batch (and count the epoch)
      # instead of spinning forever in the wrap-around loop below.
      self._epochs_completed += 1
      return self._images[:0]

    # Take what is left of this epoch *before* reshuffling; the shuffle
    # rebinds self._images to a new array, so this slice stays intact.
    parts = [self._images[start:self._num_examples]]
    remaining = batch_size - (self._num_examples - start)
    while remaining > 0:
      # Finished an epoch; start the next one.
      self._epochs_completed += 1
      if shuffle:
        self._shuffle_images()
      take = min(remaining, self._num_examples)
      parts.append(self._images[:take])
      remaining -= take
    # `take` is how far into the newest epoch the batch reached.
    self._index_in_epoch = take
    return numpy.concatenate(parts, axis=0)


class Datasets(object):
  """Read-only pair bundling a training split with a validation split."""

  def __init__(self, train, validation):
    """Store the two splits exactly as given."""
    self._validation = validation
    self._train = train

  # The accessors are declared with explicit property() calls so that both
  # splits stay read-only from the outside.
  def _get_train(self):
    return self._train

  def _get_validation(self):
    return self._validation

  train = property(_get_train, doc="The training split.")
  validation = property(_get_validation, doc="The validation split.")

def read_data_sets(data_dir,
                   dtype=dtypes.float32,
                   validation_size=100,
                   seed=None):
  """Load an image array from disk and split it into train/validation sets.

  Args:
    data_dir: passed straight to `numpy.load`, so despite the name it is the
      path of the saved array file rather than a directory.
    dtype: forwarded to each `DataSet`.
    validation_size: how many leading examples go to the validation split;
      everything after them becomes the training split.
    seed: forwarded to each `DataSet`.

  Returns:
    A `Datasets` holding the `train` and `validation` `DataSet` objects.

  Raises:
    ValueError: if `validation_size` is not within `[0, len(images)]`.
  """
  all_images = numpy.load(data_dir)
  total = len(all_images)

  size_is_valid = 0 <= validation_size <= total
  if not size_is_valid:
    message = 'Validation size should be between 0 and {}. Received: {}.'
    raise ValueError(message.format(total, validation_size))

  # The first `validation_size` examples are held out; the rest are trained on.
  held_out = all_images[:validation_size]
  remainder = all_images[validation_size:]

  # Build the training set first, then the validation set (each constructor
  # reseeds numpy's global generator, so the order is kept as-is).
  train_set = DataSet(remainder, dtype=dtype, seed=seed)
  validation_set = DataSet(held_out, dtype=dtype, seed=seed)

  return Datasets(train=train_set, validation=validation_set)


# def load_rock(data_dir):
#   return read_data_sets(data_dir)

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值