Because the dataset consists of images and the images are fairly large, the number of training samples fed in at each step cannot be set too high in a GPU environment, or GPU memory will overflow. To make the trained model more effective, I adopted mini-batch training.
I modified mnist.py to fit the needs of this experiment; a usage sketch follows the listing.
import numpy
from tensorflow.python.framework import dtypes, random_seed
class DataSet(object):
    """Container class for a dataset."""

    def __init__(self,
                 images,
                 dtype=dtypes.float32,
                 seed=None):
        """Construct a DataSet.

        `dtype` can be either `uint8` to leave the input as `[0, 255]`,
        or `float32` to rescale into `[0, 1]`.
        Seed arg provides for convenient deterministic testing.
        """
        seed1, seed2 = random_seed.get_seed(seed)
        # If op level seed is not set, use whatever graph level seed is returned
        numpy.random.seed(seed1 if seed is None else seed2)
        dtype = dtypes.as_dtype(dtype).base_dtype
        if dtype not in (dtypes.uint8, dtypes.float32):
            raise TypeError(
                'Invalid image dtype %r, expected uint8 or float32' % dtype)
        # NOTE: the rescaling step from the original mnist.py is disabled here,
        # so the images are used exactly as loaded.
        # if dtype == dtypes.float32:
        #     # Convert from [0, 255] -> [0.0, 1.0].
        #     images = images.astype(numpy.float32)
        #     images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._num_examples = images.shape[0]
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, shuffle=True):
        """Return the next `batch_size` examples from this data set."""
        start = self._index_in_epoch
        # Shuffle for the first epoch
        if self._epochs_completed == 0 and start == 0 and shuffle:
            perm0 = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm0)
            self._images = self.images[perm0]
        # Go to the next epoch
        if start + batch_size > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Get the rest examples in this epoch
            rest_num_examples = self._num_examples - start
            images_rest_part = self._images[start:self._num_examples]
            # Shuffle the data
            if shuffle:
                perm = numpy.arange(self._num_examples)
                numpy.random.shuffle(perm)
                self._images = self.images[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size - rest_num_examples
            end = self._index_in_epoch
            images_new_part = self._images[start:end]
            return numpy.concatenate(
                (images_rest_part, images_new_part), axis=0)
        else:
            self._index_in_epoch += batch_size
            end = self._index_in_epoch
            return self._images[start:end]

class Datasets(object):
    def __init__(self, train, validation):
        self._train = train
        self._validation = validation

    @property
    def train(self):
        return self._train

    @property
    def validation(self):
        return self._validation


def read_data_sets(data_dir,
                   dtype=dtypes.float32,
                   validation_size=100,
                   seed=None):
    images = numpy.load(data_dir)
    # print(type(images))  # <class 'numpy.ndarray'>
    if not 0 <= validation_size <= len(images):
        raise ValueError('Validation size should be between 0 and {}. Received: {}.'
                         .format(len(images), validation_size))
    validation_images = images[:validation_size]
    train_images = images[validation_size:]
    # print(type(train_images))  # <class 'numpy.ndarray'>
    options = dict(dtype=dtype, seed=seed)
    train = DataSet(train_images, **options)
    validation = DataSet(validation_images, **options)
    return Datasets(train=train, validation=validation)

# def load_rock(data_dir):
# return read_data_sets(data_dir)
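To show how this module drives mini-batch training, here is a minimal usage sketch. It is illustrative only: the file name rock_images.npy, the array shape, the module name dataset, the batch size, and the step count are my assumptions, not part of the experiment; substitute the real data path and hyperparameters.

import numpy
from dataset import read_data_sets  # hypothetical module name for the listing above

# Hypothetical stand-in data so the sketch runs on its own; in the real
# experiment the .npy file would already hold the prepared image array.
fake = numpy.random.randint(0, 256, size=(500, 64, 64, 1)).astype(numpy.uint8)
numpy.save('rock_images.npy', fake)

data = read_data_sets('rock_images.npy', validation_size=100, seed=1)
print(data.train.num_examples)       # 400
print(data.validation.num_examples)  # 100

# Mini-batch loop: next_batch() reshuffles at every epoch boundary and
# stitches the tail of one epoch onto the head of the next, so each call
# returns exactly batch_size examples.
batch_size = 32
for step in range(100):
    batch = data.train.next_batch(batch_size)
    # feed batch to the network here, e.g.
    # sess.run(train_op, feed_dict={x: batch})
print(data.train.epochs_completed)  # full passes over the 400 training images

The tail-stitching in next_batch is why no batch is ever smaller than batch_size, which keeps the input shape fed to the network constant; note also that the validation split is carved off before any shuffling, so the same images are held out on every run.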