Tensorflow学习笔记【3】加载数据集

图像分类Dome演示

1.步骤

  1. 准备要加载的数据,可以是路劲,也可以是numpy的数据,如果是路径,则在预处理需要进一步进行转化。
  2. 使用 tf.data.Dataset.from_tensor_slices() 函数进行加载。
  3. 使用 shuffle() 打乱数据。
  4. 使用 map() 函数进行预处理。
  5. 使用 batch() 函数设置 batch size 值。
  6. 使用prefetch()设置缓冲区提高性能。
  7. 根据需要 使用 repeat() 设置是否循环迭代数据集或者使用make_one_shot_iterator().get_next()进行迭代。

2.无标签

  # Create input data pipeline.
  with tf.device("/cpu:0"):
    train_files = glob.glob(filedir) #获取目录下所有的图片存入train_files
    if not train_files:
      raise RuntimeError(
          "No training images found with glob '{}'.".format(args.train_glob))
    train_dataset = tf.data.Dataset.from_tensor_slices(train_files) #进行数据装载,
    train_dataset = train_dataset.shuffle(buffer_size=len(train_files)).repeat() #对所有的图片进行随机
    train_dataset = train_dataset.map(
        read_png, num_parallel_calls=args.preprocess_threads) #预处理,包括多线程
    train_dataset = train_dataset.map(
        lambda x: tf.random_crop(x, (patchsize, patchsize, 3)))#进行裁剪
    train_dataset = train_dataset.batch(batchsize)  #设置batchsize
    train_dataset = train_dataset.prefetch(32)   #不知道,应该是缓冲区吧

  num_pixels = batchsize * patchsize ** 2

  # Get training patch from dataset.
  x = train_dataset.make_one_shot_iterator().get_next() #设置迭代
----------------------------------------------------------------------
def read_png(filename):
  """Loads a PNG image file."""
  string = tf.read_file(filename)
  image = tf.image.decode_image(string, channels=3)
  image = tf.cast(image, tf.float32)
  image /= 255
  return image

3.有标签

def get_files(filename):
 class_train = []
 label_train = []
 for train_class in os.listdir(filename):
  for pic in os.listdir(filename+train_class):
   class_train.append(filename+train_class+'/'+pic)
   label_train.append(train_class)
 temp = np.array([class_train,label_train])
 temp = temp.transpose()
 #shuffle the samples
 np.random.shuffle(temp)
 #after transpose, images is in dimension 0 and label in dimension 1
 image_list = list(temp[:,0])
 label_list = list(temp[:,1])
 label_list = [int(i) for i in label_list]
 #print(label_list)
 return image_list,label_list
def get_batches(image,label,resize_w,resize_h,batch_size,capacity):
 #convert the list of images and labels to tensor
 image = tf.cast(image,tf.string)
 label = tf.cast(label,tf.int64)
 queue = tf.train.slice_input_producer([image,label])
 label = queue[1]
 image_c = tf.read_file(queue[0])
 image = tf.image.decode_jpeg(image_c,channels = 3)
 #resize
 image = tf.image.resize_image_with_crop_or_pad(image,resize_w,resize_h)
 #(x - mean) / adjusted_stddev
 image = tf.image.per_image_standardization(image)
  
 image_batch,label_batch = tf.train.batch([image,label],
            batch_size = batch_size,
            num_threads = 64,
            capacity = capacity)
 images_batch = tf.cast(image_batch,tf.float32)
 labels_batch = tf.reshape(label_batch,[batch_size])
 return images_batch,labels_batch
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值