这两天真是极其糟糕,感觉诸事不训,而且文章下面还存在这问题目前没有解决,如果那位大神刚好路回,不坊顺路解决一下,小弟不胜感激,临博涕零,不知所言……
扯远了,进入正题,先说说我艰难的心路历程,本来兴高采烈的,因为发现google的gpu(想了解的话,我推荐这篇Google Colab 免费GPU服务器使用教程,因为那些挂载云盘的命令可以直接粘贴来用),那个忘乎所以,恨不得把之前想做的项目都给搞定了,但是,饭还得一口一口的吃,写完了数据打标签程序,接下来就是做数据了,由于多个文件上传的慢,又处于速度考虑,决定使用tfrecords格式,毕竟这也是tf的标准数据格式,支持多线程的。
生成代码自然不会有什么问题了,代码如下:
def _bytes_list(value):
if not isinstance(value,list):
value = [value]
return tf.train.Feature(bytes_list = tf.train.BytesList(value = value))
def _int64_list(value):
if not isinstance(value,list):
value = [value]
return tf.train.Feature(int64_list = tf.train.Int64List(value = value))
def convert_to_example(image):
data = Image.open(image)######
img = data.resize((224,224))#######
img = img.tobytes()#######
w = 224######
h = 224########
feature = {
'image':_bytes_list(img),
'w':_int64_list(w),
'h':_int64_list(h)
}
return tf.train.Example(features = tf.train.Features(feature = feature))
def convert_to_tfrecords(examples,path):
assert type(examples) == list or type(examples) == tuple
with tf.python_io.TFRecordWriter(path) as tfrecord_writer:
i = 0
for example in examples:
tfrecord_writer.write(example.SerializeToString())
i += 1
print('成功完成%g条数据'%i)
def run(images,path):
examples = []
for image in images:
examples.append(convert_to_example(image))
convert_to_tfrecords(examples,path)
上面的基本上到处都可以百度到,但是我要说的就是标记的那几行,最开始并没有resize,而是调用size获取他的shape,保存进去,这也没什么问题,但是,当读的时候就会报各种错误,大概就是说什么不相等之类的,看下之前错误的代码吧
def convert_to_example(image):
data = Image.open(image)
w = data.width
h = data.height
img = img.tobytes()
feature = {
'image':_bytes_list(img),
'w':_int64_list(w),
'h':_int64_list(h)
}
return tf.train.Example(features = tf.train.Features(feature = feature))
然后读取的时候一执行tf.reshape(img,[h,w,3])就报错,不明所以可能是PIL的问题吧,或者我的操作不对?另外一个问题是,当我采用tf原本的读取方式的时候,像下面这个样子
def read_records_test(filename):
img,label= read_records('./train_2_of_9000_file.tfrecords')
# img_batch,label_batch = tf.train.shuffle_batch([img,label],
# batch_size=1,
# min_after_dequeue=10,
# capacity=20)
#
coord = tf.train.Coordinator()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
threads = tf.train.start_queue_runners(sess = sess,coord=coord)
for i in range(10):
val,l = sess.run([img,label])
print(val.shape,l)
coord.request_stop()
coord.join(threads)
如果去掉注释部分,就会报错,注释了还可以读取出来数据,why?上面的样例是实验在天池点击打开链接这个比赛的数据(当然比赛我是没有做),数据大概就是这样的
有些把后面的.jpg去掉就是动态图片,使用opencv读取不到,然后,我就直接用PIL了,目前成功的代码就是直接resize再存,读取实验如下:
def _parse_of_tfrecords(record):
features = {
'image':tf.FixedLenFeature([],tf.string),
'w':tf.FixedLenFeature([],tf.int64),
'h':tf.FixedLenFeature([],tf.int64)
}
parse_example = tf.parse_single_example(serialized=record,features = features)
image = tf.decode_raw(parse_example['image'],out_type=tf.uint8)
w = tf.cast(parse_example['w'],tf.int32)
h = tf.cast(parse_example['h'],tf.int32)
image = tf.reshape(image,shape=[h,w,3])
return image
def read_test(path):
dataset = tf.data.TFRecordDataset(path)
dataset = dataset.map(_parse_of_tfrecords)
dataset = dataset.batch(3).repeat(1)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(2):
image = sess.run(next_batch)
print(image)
return image
最后的结果就是这样的:
现在也大概是能用了,上面那些问题嘛,跪求大神指导!!!!