import os

# Demo: the class label of an image is the name of the folder it lives in.
filenames = r'C:\Users\Administrator\Desktop\test3\hymenoptera_data\train\ants\0013035.jpg'
parent_dir = os.path.dirname(filenames)      # directory containing the image
class_name = os.path.basename(parent_dir)    # folder name doubles as the label
print(class_name)
# ants
# 2、解析所有图片路径 (step 2: collect the paths of all images)
def_get_filenames_and_classes(dataset_dir):"""Returns a list of filenames and inferred class names.
返回所有图片对应的路径以及文件名(对应每一类)
Args:
dataset_dir: A directory containing a set of subdirectories representing
class names. Each subdirectory should contain PNG or JPG encoded images.
Returns:
A list of image file paths, relative to `dataset_dir` and the list of
subdirectories, representing class names.
"""
flower_root = os.path.join(dataset_dir, 'flower_photos')
directories = []
class_names = []
for filename in os.listdir(flower_root):
path = os.path.join(flower_root, filename)
if os.path.isdir(path): # 如果是文件夹
directories.append(path)
class_names.append(filename) # 文件名对应类名 用于做标签
photo_filenames = [] # 存储了所有图片的路径for directory in directories:
for filename in os.listdir(directory):
path = os.path.join(directory, filename) #图片完整路径
photo_filenames.append(path)
return photo_filenames, sorted(class_names)
# 3、图像转成tfr数据 (step 3: convert the images into TFRecord files)
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
    """Converts the given filenames to a TFRecord dataset.

    Args:
      split_name: The name of the dataset, either 'train' or 'validation'.
      filenames: A list of absolute paths to png or jpg images.
      class_names_to_ids: A dictionary from class names (strings) to ids
        (integers).
      dataset_dir: The directory where the converted datasets are stored.
    """
    assert split_name in ['train', 'validation']

    # Split the files into _NUM_SHARDS shards of num_per_shard images each.
    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

    with tf.Graph().as_default():  # fresh graph for the decode ops
        image_reader = ImageReader()

        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                output_filename = _get_dataset_filename(
                    dataset_dir, split_name, shard_id)

                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    # Each shard covers filenames[start_ndx:end_ndx].
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        # Progress indicator on a single rewritten line.
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                            i + 1, len(filenames), shard_id))
                        sys.stdout.flush()

                        # Read the raw encoded bytes of the image.
                        image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
                        height, width = image_reader.read_image_dims(sess, image_data)

                        # The parent folder name is the class label; map it to
                        # its integer id (e.g. 0, 1, ...).
                        class_name = os.path.basename(os.path.dirname(filenames[i]))
                        class_id = class_names_to_ids[class_name]

                        example = dataset_utils.image_to_tfexample(
                            image_data, b'jpg', height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())

    sys.stdout.write('\n')
    sys.stdout.flush()
# 4、数据下载及解压 (step 4: download the data and extract it)
def download_and_uncompress_tarball(tarball_url, dataset_dir):
    """Downloads the `tarball_url` and uncompresses it locally.

    Args:
      tarball_url: The URL of a tarball file.
      dataset_dir: The directory where the temporary files are stored.
    """
    filename = tarball_url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)

    def _progress(count, block_size, total_size):
        # Progress callback invoked by urlretrieve after each chunk.
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (
            filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    # NOTE(review): extractall on an untrusted archive can write outside
    # dataset_dir (path traversal) — only use with trusted URLs.
    with tarfile.open(filepath, 'r:gz') as tar:  # close the archive handle
        tar.extractall(dataset_dir)