一 场景描述
有时我们的数据集是按 xml → csv → record 的流程转换得到的，而 TFRecord 是序列化的二进制文件，无法直接打开查看，所以想检查一下生成的 TFRecord 文件中的标注是否有问题就比较麻烦。我搜索到了不少网友分享的可视化代码，但不知为何运行时都会报错；后来在 GitHub 上看到了这个项目：
https://github.com/EricThomson/tfrecord-view
在它的基础上做了一些修改，用起来效果还不错。
二 代码
import warnings
# tf 1.14 and np 1.17 are clashing: importing TensorFlow emits numpy FutureWarnings.
# The filter has to be installed *before* the TensorFlow import below, otherwise
# the warnings have already been printed by the time it takes effect.
warnings.filterwarnings('ignore', category = FutureWarning)

import cv2
import numpy as np
import tensorflow.compat.v1 as tf

# The viewer calls .numpy() on tensors, which needs eager execution
# (off by default on TF 1.x; harmless to request on TF 2.x).
tf.enable_eager_execution()
def cv_bbox(image, bbox, color = (255, 255, 255), line_width = 2):
    """
    Draw a rectangle outline onto ``image`` in place using OpenCV.

    Args:
        image: image array to draw on; it is modified in place, not copied.
        bbox: box in standard form (x1, y1, x2, y2) -- two opposite corners in pixels.
        color: outline color, in the same channel order as ``image``.
        line_width: outline thickness in pixels.

    Returns:
        None.
    """
    # OpenCV's argument parser wants plain ints for points; values coming out of
    # numpy arithmetic (np.int64, floats) are rejected by some OpenCV releases.
    x1, y1, x2, y2 = (int(bbox[i]) for i in range(4))
    cv2.rectangle(image, (x1, y1), (x2, y2), color, line_width)
def parse_record(data_record):
    """
    Deserialize a single ``tf.train.Example`` read from a TFRecord file.

    Args:
        data_record: one serialized example, e.g. an element produced by
            iterating a ``TFRecordDataset``.

    Returns:
        dict of feature name -> tensor. ``image/encoded`` and ``image/filename``
        are dense string tensors; the per-object class labels and bbox corner
        coordinates are variable-length and come back as sparse tensors.
    """
    # Per-object fields: one entry per annotated box, so their length varies.
    per_object = {
        'image/object/class/label': tf.VarLenFeature(tf.int64),
        'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
    }
    # Per-image fields: exactly one scalar string each.
    per_image = {
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/filename': tf.FixedLenFeature([], tf.string),
    }
    spec = dict(per_image)
    spec.update(per_object)
    return tf.parse_single_example(data_record, spec)
def _dense_annotations(parsed_example):
    """Return (labels, xmin, ymin, xmax, ymax) of a parsed example as numpy arrays."""
    def dense(key):
        # The per-object fields are sparse; densify with 0 as the fill value.
        return tf.sparse_tensor_to_dense(parsed_example[key], default_value=0).numpy()

    return (dense('image/object/class/label'),
            dense('image/object/bbox/xmin'),
            dense('image/object/bbox/ymin'),
            dense('image/object/bbox/xmax'),
            dense('image/object/bbox/ymax'))


def _draw_annotations(image_bgr, labels, boxes, id_to_name):
    """Draw each pixel-space (x1, y1, x2, y2) box and its class name onto image_bgr in place."""
    for label_id, bbox in zip(labels, boxes):
        label_id = int(label_id)
        # Draw unknown ids instead of raising: spotting bad labels is the point of this tool.
        label_name = id_to_name.get(label_id, f"unknown id {label_id}")
        cv_bbox(image_bgr, bbox, color = (250, 250, 150), line_width = 2)
        cv2.putText(image_bgr,
                    label_name,
                    (bbox[0] + 5, bbox[1] + 20),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1, (10, 10, 255), 2)  # scale, color, thickness


def view_records(file_path, class_labels, stride = 1, verbose = 1):
    """
    Peek at the data in a tfrecord file using opencv and tensorflow tools.

    Each shown record is decoded, its boxes (stored as normalized corner
    coordinates) are scaled to pixels and drawn with their class names.

    Inputs:
        file_path: path to tfrecord file (usually has 'record' extension)
        class_labels: dictionary of labels with name:number pairs (start with 1)
        stride (default 1): show every ``stride``-th record (you might have thousands so skip a few)
        verbose (default 1): print per-image details if truthy; otherwise only show images.
    Raises:
        ValueError: if ``stride`` is less than 1.
    Usage:
        Within the image window, press 'n' (or any other key) for the next image, 'esc' to stop.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    # Invert name -> id once instead of a linear search per box. setdefault keeps
    # the first name for a duplicated id, matching the old list.index() lookup.
    id_to_name = {}
    for name, class_id in class_labels.items():
        id_to_name.setdefault(class_id, name)

    dataset = tf.data.TFRecordDataset([file_path])
    # Counting costs a full pass over the file; it only feeds the summary line.
    num_records = dataset.reduce(np.int64(0), lambda x, _: x + 1).numpy()
    if verbose:
        print(f"\nGoing through {num_records} records with a stride of {stride}.")
        print("Enter 'n' to bring up next image in record.\n")

    for im_ind, data_record in enumerate(dataset):
        # Skip before parsing so strided-over records are never decoded.
        if im_ind % stride != 0:
            continue
        parsed_example = parse_record(data_record)
        fname = parsed_example['image/filename'].numpy().decode('utf-8', errors='replace')
        image_np = tf.image.decode_image(parsed_example['image/encoded'], channels=3).numpy()
        labels, x1norm, y1norm, x2norm, y2norm = _dense_annotations(parsed_example)
        num_bboxes = len(labels)

        height, width = image_np.shape[:2]
        # decode_image yields RGB; cv2.imshow expects BGR.
        image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        # Normalized corners -> pixel coordinates, truncated like the previous
        # np.int64 cast, then turned into plain Python ints for OpenCV.
        boxes = list(zip((x1norm * width).astype(np.int64).tolist(),
                         (y1norm * height).astype(np.int64).tolist(),
                         (x2norm * width).astype(np.int64).tolist(),
                         (y2norm * height).astype(np.int64).tolist()))
        _draw_annotations(image_bgr, labels, boxes, id_to_name)

        if verbose:
            print(f"\nImage {im_ind}")
            print(f" {fname}")
            print(f" Height/width: {height, width}")
            print(f" Num bboxes: {num_bboxes}")
            field_lengths = {len(a) for a in (labels, x1norm, y1norm, x2norm, y2norm)}
            if len(field_lengths) > 1:
                print(" Warning: label/bbox field lengths differ; only complete boxes are drawn.")

        cv2.imshow("bb data", image_bgr)
        # Mask to the low byte: on some platforms waitKey sets extra high bits.
        if (cv2.waitKey() & 0xFF) == 27:  # Esc
            break
    cv2.destroyAllWindows()
    if verbose:
        print("\n\ntfrecord-view: done going through the data.")
if __name__ == '__main__':
    # Class name -> integer id, exactly as used when the record was generated (ids start at 1).
    class_labels = {"Jam": 1, "Target": 2, "Clutter": 3}
    # Record file to inspect (make it using voc_to_tfr.py).
    data_path = r"train.record"
    # Show every record and print per-image details.
    stride, verbose = 1, 1
    view_records(data_path, class_labels, verbose = verbose, stride = stride)
三 修改
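将该代码放到数据集所在的目录下，并修改 `if __name__ == '__main__':` 部分：把 `class_labels` 改为你的类别名到标签 ID 的映射（ID 必须与生成 record 时使用的一致，从 1 开始），把 `data_path` 改为要检查的 record 文件路径；`stride` 可以用来跳着查看（每隔 stride 条记录显示一张），`verbose` 设为 0 时只显示图片、不打印信息。
然后在该目录下运行此脚本即可。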
四 效果展示
（此处原为效果截图：运行后会逐张弹出窗口显示图片，每张图上都画出了该图的所有标注框及对应的类别名称。）
五 其他有趣补充
https://github.com/ahmetcetin/tfrecord-viewer
还看到了这个项目：只需浏览到文件位置并选择文件，就可以一键可视化。不过它需要安装 Docker 和 npm，我就没有尝试，有兴趣的同学可以自行查看。