#
# Author: 韦访
# Blog: https://blog.csdn.net/rookie_wei
# WeChat: 1007895847
# When adding me on WeChat, please mention you came from CSDN
# Everyone is welcome to learn together
#
------韦访 20190627
1. Overview
In the previous article we got the training code running; in this one we start analyzing the source code for real. You have to read the code alongside the paper to make sense of it.
2. Computing the network's input and output sizes
Open the train.py file and start reading from the main entry,
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training codes for Openpose using Tensorflow')
    parser.add_argument('--model', default='cmu', help='model name')
    parser.add_argument('--datapath', type=str, default='G:/tensorflow/post-estimation/datasets/annotations')
    parser.add_argument('--imgpath', type=str, default='G:/tensorflow/post-estimation/datasets/')
    parser.add_argument('--batchsize', type=int, default=16)
    parser.add_argument('--gpus', type=int, default=1)
    parser.add_argument('--max-epoch', type=int, default=600)
    parser.add_argument('--lr', type=str, default='0.001')
    parser.add_argument('--tag', type=str, default='test')
    parser.add_argument('--checkpoint', type=str, default='')
    parser.add_argument('--input-width', type=int, default=432)
    parser.add_argument('--input-height', type=int, default=368)
    parser.add_argument('--quant-delay', type=int, default=-1)
    args = parser.parse_args()

    modelpath = logpath = './models/train/'

    if args.gpus <= 0:
        raise Exception('gpus <= 0')
These are just the default argument settings, which were covered last time, so let's keep going.
    # define input placeholder
    # Set the image width and height; used later by the data augmentation
    set_network_input_wh(args.input_width, args.input_height)
    scale = 4

    if args.model in ['cmu', 'vgg'] or 'mobilenet' in args.model:
        # CMU uses the VGG19 network, which has 3 max_pool ops with stride 2, i.e. an 8x downscale
        scale = 8

    # Set scale; also used by the data augmentation
    set_network_scale(scale)

    output_w, output_h = args.input_width // scale, args.input_height // scale
The set_network_input_wh and set_network_scale functions pass the input width, height, and scale from the arguments into the pose_augment module, as covered last time. If the model is the cmu network, it uses the first 10 layers of VGG19, which include 3 max-pool operations with stride 2; each max-pool halves the feature map, so three of them shrink it by a factor of 8, hence scale = 8. The output width and height are therefore
output_w, output_h = args.input_width // scale, args.input_height // scale
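With the default arguments this works out as follows (a quick plain-Python sanity check, not code from the repo):
input_width, input_height, scale = 432, 368, 8
output_w, output_h = input_width // scale, input_height // scale
print(output_w, output_h)   # 54 46, the same 54 and 46 that appear in the placeholder shapes below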
3. Defining the placeholders
Continuing,
    logger.info('define model+')
    with tf.device(tf.DeviceSpec(device_type="CPU")):
        # define the placeholders
        # input image, shape=(16, 368, 432, 3)
        input_node = tf.placeholder(tf.float32, shape=(args.batchsize, args.input_height, args.input_width, 3), name='image')
        # vector map, shape=(16, 46, 54, 38)
        vectmap_node = tf.placeholder(tf.float32, shape=(args.batchsize, output_h, output_w, 38), name='vectmap')
        # heatmap, shape=(16, 46, 54, 19)
        heatmap_node = tf.placeholder(tf.float32, shape=(args.batchsize, output_h, output_w, 19), name='heatmap')
Three placeholders are defined above: the input image input_node, the vector map vectmap_node (the limb vector field, i.e. the part affinity fields in the paper; my loose translation), and the heatmap heatmap_node (the keypoint confidence maps in the paper, again loosely translated). The depths of 19 and 38 follow the paper: the 19 heatmap channels are 18 keypoints plus one background channel, and the 38 vector-map channels are 19 limb connections times 2 (an x and a y component each). The input image has a depth of 3.
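These placeholders are later fed from the batched dataflow. The following is not the repo's training loop, just a self-contained TF 1.x sketch of feeding arrays with exactly these shapes; the placeholder names and the dummy tensor are illustrative only:
import numpy as np
import tensorflow as tf

image_ph = tf.placeholder(tf.float32, shape=(16, 368, 432, 3), name='image')
heatmap_ph = tf.placeholder(tf.float32, shape=(16, 46, 54, 19), name='heatmap')
vectmap_ph = tf.placeholder(tf.float32, shape=(16, 46, 54, 38), name='vectmap')
dummy = tf.reduce_sum(image_ph) + tf.reduce_sum(heatmap_ph) + tf.reduce_sum(vectmap_ph)

with tf.Session() as sess:
    out = sess.run(dummy, feed_dict={
        image_ph: np.zeros((16, 368, 432, 3), np.float32),
        heatmap_ph: np.zeros((16, 46, 54, 19), np.float32),
        vectmap_ph: np.zeros((16, 46, 54, 38), np.float32),
    })
    print(out)   # 0.0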
4. Data augmentation
Continuing,
    # prepare data
    # initialize the data
    # args.datapath: the annotations directory
    # batchsize: batch size
    # imgpath: the dataset directory
    # parse person_keypoints_train2017.json and feed the training data into a queue
    df = get_dataflow_batch(args.datapath, True, args.batchsize, img_path=args.imgpath)
Let's take a look at the get_dataflow_batch function,
def get_dataflow_batch(path, is_train, batchsize, img_path=None):
    logger.info('dataflow img_path=%s' % img_path)
    ds = get_dataflow(path, is_train, img_path=img_path)
    ds = BatchData(ds, batchsize)
    # if is_train:
    #     ds = PrefetchData(ds, 10, 2)
    # else:
    #     ds = PrefetchData(ds, 50, 2)

    return ds
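A quick note on what BatchData contributes here: it groups batchsize consecutive datapoints and stacks each component into one array, which is where the (16, 368, 432, 3) / (16, 46, 54, 19) / (16, 46, 54, 38) batch shapes fed to the placeholders come from. A toy sketch, assuming tensorpack is installed; the zero arrays just stand in for pose_to_img's output:
import numpy as np
from tensorpack.dataflow import DataFromList, BatchData

# fake datapoints shaped like pose_to_img's output: [img, heatmap, vectormap]
points = [[np.zeros((368, 432, 3), np.float16),
           np.zeros((46, 54, 19), np.float16),
           np.zeros((46, 54, 38), np.float16)] for _ in range(32)]

ds = BatchData(DataFromList(points, shuffle=False), 16)
ds.reset_state()
for img, heat, vect in ds.get_data():          # `for ... in ds:` on newer tensorpack versions
    print(img.shape, heat.shape, vect.shape)   # (16, 368, 432, 3) (16, 46, 54, 19) (16, 46, 54, 38)
    break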
Let's look at the get_dataflow function first,
def get_dataflow(path, is_train, img_path=None):
    # the json to read is person_keypoints_train2017.json; the image files are in train2017
    ds = CocoPose(path, img_path, is_train)     # read data from lmdb

    if is_train:
        ds = MapData(ds, read_image_url)
        # data augmentation
        ds = MapDataComponent(ds, pose_random_scale)
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        # augs = [
        #     imgaug.RandomApplyAug(imgaug.RandomChooseAug([
        #         imgaug.GaussianBlur(max_size=3)
        #     ]), 0.7)
        # ]
        # ds = AugmentImageComponent(ds, augs)
        ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 1)
    else:
        ds = MultiThreadMapData(ds, nr_thread=16, map_func=read_image_url, buffer_size=1000)
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)
        # second argument: size of the queue to hold prefetched datapoints
        ds = PrefetchData(ds, 100, multiprocessing.cpu_count() // 4)

    return ds
Let's start with CocoPose's __init__ function,
def __init__(self, path, img_path=None, is_train=True, decode_img=True, only_idx=-1):
    self.is_train = is_train
    self.decode_img = decode_img
    self.only_idx = only_idx

    # for training, use the data in train2017; otherwise use val2017
    if is_train:
        whole_path = os.path.join(path, 'person_keypoints_train2017.json')
    else:
        whole_path = os.path.join(path, 'person_keypoints_val2017.json')

    # full image path
    self.img_path = (img_path if img_path is not None else '') + ('train2017/' if is_train else 'val2017/')
    # COCO API
    self.coco = COCO(whole_path)

    logger.info('%s dataset %d' % (path, self.size()))
In the code above, if we are training, person_keypoints_train2017.json is parsed, otherwise person_keypoints_val2017.json; then the COCO class is instantiated. I haven't studied the COCO class in detail; it lives in the pycocotools module we installed, and you can simply treat it as a tool for parsing the COCO dataset. It is not the focus of this analysis.
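For reference, a minimal sketch of how the pycocotools COCO class is typically used to pull per-image annotations (the json path is just an example; adjust it to your own dataset layout):
from pycocotools.coco import COCO

coco = COCO('annotations/person_keypoints_train2017.json')   # example path
img_ids = coco.getImgIds()
img_meta = coco.loadImgs(img_ids[0])[0]          # dict with 'file_name', 'height', 'width', ...
anns = coco.loadAnns(coco.getAnnIds(imgIds=img_ids[0]))
# each annotation carries 'num_keypoints' plus a flat 'keypoints' list of 17 (x, y, v) triplets
print(img_meta['height'], img_meta['width'], len(anns))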
Back in get_dataflow: space is limited, so let's pick out the key points. read_image_url reads the image files that were parsed; pose_random_scale, pose_rotation, pose_flip, pose_resize_shortestedge_random, and pose_crop_random all perform data augmentation, such as random scaling, random rotation, and random cropping. These augmentation functions are implemented by hand rather than with off-the-shelf ones because when an image is augmented, the keypoint coordinates in its labels change as well, so the keypoint coordinates after augmentation also have to be recomputed. That is not our focus either; the key function is pose_to_img,
def pose_to_img(meta_l):
    global _network_w, _network_h, _scale
    # print('wf====>>>>pose_to_img _scale:', _scale, ' _network_w:', _network_w, ' _network_h:', _network_h, 'os.getpid():', os.getpid(), 'os.getppid():', os.getppid())
    return [
        # input image
        meta_l[0].img.astype(np.float16),
        # generate the heatmap
        meta_l[0].get_heatmap(target_size=(_network_w // _scale, _network_h // _scale)),
        # generate the vector map
        meta_l[0].get_vectormap(target_size=(_network_w // _scale, _network_h // _scale))
    ]
Let's look at get_heatmap first. target_size is the input width and height divided by _scale; for the network in the paper, _scale = 8 here,
# generate the heatmap
@jit
def get_heatmap(self, target_size):
    heatmap = np.zeros((CocoMetadata.__coco_parts, self.height, self.width), dtype=np.float32)

    # iterate over each person's keypoint coordinates
    for joints in self.joint_list:
        # iterate over one person's keypoints
        for idx, point in enumerate(joints):
            # a negative coordinate means this keypoint does not exist
            if point[0] < 0 or point[1] < 0:
                continue
            CocoMetadata.put_heatmap(heatmap, idx, point, self.sigma)

    heatmap = heatmap.transpose((1, 2, 0))

    # background
    heatmap[:, :, -1] = np.clip(1 - np.amax(heatmap, axis=2), 0.0, 1.0)

    # resize the heatmap
    if target_size:
        heatmap = cv2.resize(heatmap, target_size, interpolation=cv2.INTER_AREA)

    return heatmap.astype(np.float16)
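The background channel at the end deserves a note: for every pixel it stores 1 minus the strongest keypoint response there, so pixels far from any joint score close to 1. A tiny NumPy sketch (not from the repo) of that one line:
import numpy as np

heatmap = np.zeros((46, 54, 19), dtype=np.float32)
heatmap[20, 30, 0] = 0.9                       # pretend one joint responds strongly at (20, 30)
heatmap[:, :, -1] = np.clip(1 - np.amax(heatmap, axis=2), 0.0, 1.0)   # last channel is still zero here
print(heatmap[20, 30, -1])                     # 0.1 where a joint is present
print(heatmap[0, 0, -1])                       # 1.0 in empty background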
In the code above, self.joint_list holds the keypoint coordinates of everyone in one image, possibly several people. The keypoints are stored into self.joint_list by CocoMetadata's __init__ function,
# idx: index
# img_url: image url
# img_meta: info about the current image: width, height, links, id, etc.
# anns: the image's annotations: segmentation, number of keypoints, keypoints, image_id, etc.
# sigma = 8.0
def __init__(self, idx, img_url, img_meta, annotations, sigma):
    self.idx = idx
    self.img_url = img_url
    self.img = None
    self.sigma = sigma

    self.height = int(img_meta['height'])
    self.width = int(img_meta['width'])

    # collect the joint coordinates parsed from 'keypoints'; one image may contain several people
    joint_list = []
    # print('======================')
    for ann in annotations:
        # skip annotations without keypoints
        if ann.get('num_keypoints', 0) == 0:
            continue

        # the keypoints
        kp = np.array(ann['keypoints'])
        # x values: start at 0, take every 3rd element, 17 in total
        xs = kp[0::3]
        # y values: start at 1, take every 3rd element
        ys = kp[1::3]
        # visibility flags: start at 2, take every 3rd element
        # v has 3 states: 0 = not labeled, 1 = labeled but occluded, 2 = labeled and visible
        vs = kp[2::3]

        # print('xs:', xs)
        # print('ys:', ys)
        # print('vs:', vs)

        # v >= 1 means (x, y) is a valid keypoint; missing keypoints are replaced with (-1000, -1000)
        joint_list.append([(x, y) if v >= 1 else (-1000, -1000) for x, y, v in zip(xs, ys, vs)])
        # print('joint_list:', joint_list)

    # the joint coordinates, possibly for several people
    self.joint_list = []
    # print('-----------------------')

    # The pair (6, 7) yields the neck/spine point, computed from the left shoulder (index 6)
    # and right shoulder (index 7); the dataset has no annotation for that point.
    # COCO's keypoint indices differ from the order used in this code, so this acts as a remapping.
    transform = list(zip(
        [1, 6, 7, 9, 11, 6, 8, 10, 13, 15, 17, 12, 14, 16, 3, 2, 5, 4],
        [1, 7, 7, 9, 11, 6, 8, 10, 13, 15, 17, 12, 14, 16, 3, 2, 5, 4]
    ))
    # print('transform:', transform)

    # rebuild the keypoints in the order and layout we need
    for prev_joint in joint_list:
        new_joint = []
        for idx1, idx2 in transform:
            # print('idx1:', idx1)
            # print('idx2:', idx2)
            j1 = prev_joint[idx1-1]
            j2 = prev_joint[idx2-1]

            # print('j1:', j1)
            # print('j2:', j2)

            if j1[0] <= 0 or j1[1] <= 0 or j2[0] <= 0 or j2[1] <= 0:
                new_joint.append((-1000, -1000))
            else:
                new_joint.append(((j1[0] + j2[0]) / 2, (j1[1] + j2[1]) / 2))

        # the 19th joint, which the dataset never labels
        new_joint.append((-1000, -1000))
        self.joint_list.append(new_joint)
        # print('self.joint_list:', self.joint_list)

    # print('joint size=%d' % len(self.joint_list))
    # logger.debug('joint size=%d' % len(self.joint_list))
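To make the remapping concrete, here is a standalone, simplified sketch (validity checks omitted, dummy coordinates) of how the 17 COCO keypoints become the 18-point order the network expects, with the extra neck point taken as the midpoint of the two shoulders, plus a 19th dummy point:
# 17 dummy COCO keypoints (x, y); 1-based indices 6 and 7 are the left and right shoulders
coco_joints = [(float(i * 10), float(i * 10)) for i in range(1, 18)]

transform = list(zip(
    [1, 6, 7, 9, 11, 6, 8, 10, 13, 15, 17, 12, 14, 16, 3, 2, 5, 4],
    [1, 7, 7, 9, 11, 6, 8, 10, 13, 15, 17, 12, 14, 16, 3, 2, 5, 4]
))

new_joint = []
for idx1, idx2 in transform:
    j1, j2 = coco_joints[idx1 - 1], coco_joints[idx2 - 1]
    new_joint.append(((j1[0] + j2[0]) / 2, (j1[1] + j2[1]) / 2))   # averaging a point with itself keeps it
new_joint.append((-1000, -1000))   # 19th point, never annotated

print(len(new_joint))   # 19
print(new_joint[1])     # (65.0, 65.0): the neck, midway between shoulders (60, 60) and (70, 70)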
Back in get_heatmap, let's see what put_heatmap does,
@staticmethod
@jit(nopython=True)
def put_heatmap(heatmap, plane_idx, center, sigma):
    # keypoint coordinates
    center_x, center_y = center
    # height and width of the confidence map
    _, height, width = heatmap.shape[:3]

    th = 4.6052
    # take the square root
    delta = math.sqrt(th * 2)

    # Build a box centered at (center_x, center_y), with (x0, y0) as the top-left corner and
    # (x1, y1) as the bottom-right corner; the heatmap is written only inside this box
    x0 = int(max(0, center_x - delta * sigma))
    y0 = int(max(0, center_y - delta * sigma))

    x1 = int(min(width, center_x + delta * sigma))
    y1 = int(min(height, center_y + delta * sigma))

    for y in range(y0, y1):
        for x in range(x0, x1):
            # Gaussian kernel
            d = (x - center_x) ** 2 + (y - center_y) ** 2
            exp = d / 2.0 / sigma / sigma
            if exp > th:
                continue
            heatmap[plane_idx][y][x] = max(heatmap[plane_idx][y][x], math.exp(-exp))
            heatmap[plane_idx][y][x] = min(heatmap[plane_idx][y][x], 1.0)
The function above is exactly point 5 of the blog post that walks through the paper; if it's unclear, see the link below,
https://blog.csdn.net/rookie_wei/article/details/90705880
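If you want to play with the numbers yourself, here is a standalone NumPy re-implementation of the Gaussian peak that put_heatmap writes (sigma = 8.0 matches the value passed to CocoMetadata; th = 4.6052 is roughly -ln(0.01), so responses below 1% are simply dropped):
import math
import numpy as np

def gaussian_peak(height, width, center, sigma=8.0, th=4.6052):
    plane = np.zeros((height, width), dtype=np.float32)
    delta = math.sqrt(th * 2)
    x0, y0 = int(max(0, center[0] - delta * sigma)), int(max(0, center[1] - delta * sigma))
    x1, y1 = int(min(width, center[0] + delta * sigma)), int(min(height, center[1] + delta * sigma))
    for y in range(y0, y1):
        for x in range(x0, x1):
            d = ((x - center[0]) ** 2 + (y - center[1]) ** 2) / (2.0 * sigma * sigma)
            if d > th:          # exp(-4.6052) is about 0.01, so tiny values are skipped
                continue
            plane[y, x] = min(max(plane[y, x], math.exp(-d)), 1.0)
    return plane

peak = gaussian_peak(368, 432, center=(200, 150))
print(peak[150, 200])   # 1.0 at the keypoint itself
print(peak[150, 208])   # about exp(-0.5) ~= 0.61 one sigma away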
With that, the heatmaps are generated from the dataset. Next, let's see how the vector maps are generated. Back in pose_to_img, look at what get_vectormap does,
@jit
def get_vectormap(self, target_size):
    vectormap = np.zeros((CocoMetadata.__coco_parts*2, self.height, self.width), dtype=np.float32)
    countmap = np.zeros((CocoMetadata.__coco_parts, self.height, self.width), dtype=np.int16)

    # iterate over each person
    for joints in self.joint_list:
        for plane_idx, (j_idx1, j_idx2) in enumerate(CocoMetadata.__coco_vecs):
            # __coco_vecs indices are 1-based; subtract 1 to index the joints we want
            j_idx1 -= 1
            j_idx2 -= 1

            # the joint the limb starts from
            center_from = joints[j_idx1]
            # the joint the limb ends at
            center_to = joints[j_idx2]

            if center_from[0] < -100 or center_from[1] < -100 or center_to[0] < -100 or center_to[1] < -100:
                continue

            CocoMetadata.put_vectormap(vectormap, countmap, plane_idx, center_from, center_to)

    vectormap = vectormap.transpose((1, 2, 0))
    nonzeros = np.nonzero(countmap)
    for p, y, x in zip(nonzeros[0], nonzeros[1], nonzeros[2]):
        if countmap[p][y][x] <= 0:
            continue
        # divide by the number of people whose limb covers this pixel, otherwise the vectors would just add up
        vectormap[y][x][p*2+0] /= countmap[p][y][x]
        vectormap[y][x][p*2+1] /= countmap[p][y][x]

    if target_size:
        vectormap = cv2.resize(vectormap, target_size, interpolation=cv2.INTER_AREA)

    return vectormap.astype(np.float16)
It again reads each person's keypoints from self.joint_list; the key part is the put_vectormap function,
@staticmethod
@jit(nopython=True)
def put_vectormap(vectormap, countmap, plane_idx, center_from, center_to, threshold=8):
    # height and width of the vector map
    _, height, width = vectormap.shape[:3]

    # the limb vector's x and y components, i.e. (vec_x, vec_y) is the limb vector
    vec_x = center_to[0] - center_from[0]
    vec_y = center_to[1] - center_from[1]

    # top-left corner of the limb's bounding box, with threshold subtracted on both axes
    min_x = max(0, int(min(center_from[0], center_to[0]) - threshold))
    min_y = max(0, int(min(center_from[1], center_to[1]) - threshold))

    # bottom-right corner of the limb's bounding box, with threshold added on both axes
    max_x = min(width, int(max(center_from[0], center_to[0]) + threshold))
    max_y = min(height, int(max(center_from[1], center_to[1]) + threshold))

    # length of the limb vector
    norm = math.sqrt(vec_x ** 2 + vec_y ** 2)

    # a zero-length limb is skipped
    if norm == 0:
        return

    # x component of the unit vector
    vec_x /= norm
    # y component of the unit vector
    vec_y /= norm

    # print('=============================')

    for y in range(min_y, max_y):
        for x in range(min_x, max_x):
            # (bec_x, bec_y) is the vector from (center_from_x, center_from_y) to (x, y)
            bec_x = x - center_from[0]
            bec_y = y - center_from[1]
            # perpendicular distance from the point (x, y) to the limb direction (vec_x, vec_y)
            dist = abs(bec_x * vec_y - bec_y * vec_x)

            if dist > threshold:
                continue

            countmap[plane_idx][y][x] += 1

            # store the cosine of the limb direction relative to the horizontal axis
            vectormap[plane_idx*2+0][y][x] = vec_x
            # store the sine of the limb direction relative to the horizontal axis
            vectormap[plane_idx*2+1][y][x] = vec_y
            # print('vec_x:', vec_x)
            # print('vec_y:', vec_y)
The function above does exactly what point 6 of the paper-analysis blog describes; if it's unclear, go back to the paper.
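As a sanity check of the geometry, here is a standalone sketch (not from the repo) of the same idea for a single limb: every pixel whose perpendicular distance to the limb segment is within threshold pixels stores the limb's unit direction vector. put_vectormap does this per limb and per person, and get_vectormap then averages overlapping contributions using countmap:
import math
import numpy as np

def limb_paf(height, width, p_from, p_to, threshold=8):
    paf = np.zeros((2, height, width), dtype=np.float32)
    vec_x, vec_y = p_to[0] - p_from[0], p_to[1] - p_from[1]
    norm = math.sqrt(vec_x ** 2 + vec_y ** 2)
    if norm == 0:
        return paf
    vec_x, vec_y = vec_x / norm, vec_y / norm        # unit direction of the limb
    min_x = max(0, int(min(p_from[0], p_to[0]) - threshold))
    min_y = max(0, int(min(p_from[1], p_to[1]) - threshold))
    max_x = min(width, int(max(p_from[0], p_to[0]) + threshold))
    max_y = min(height, int(max(p_from[1], p_to[1]) + threshold))
    for y in range(min_y, max_y):
        for x in range(min_x, max_x):
            # perpendicular distance from (x, y) to the limb line through p_from
            dist = abs((x - p_from[0]) * vec_y - (y - p_from[1]) * vec_x)
            if dist > threshold:
                continue
            paf[0, y, x], paf[1, y, x] = vec_x, vec_y
    return paf

paf = limb_paf(368, 432, p_from=(100, 100), p_to=(200, 100))   # a horizontal limb
print(paf[0, 100, 150], paf[1, 100, 150])                      # 1.0 0.0 on the limb
print(paf[0, 120, 150], paf[1, 120, 150])                      # 0.0 0.0 outside the band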
5. Visualization
Next, let's visualize what the code above produces, i.e. the results described in points 5 and 6 of the paper-analysis blog. We use the main block at the bottom of pose_dataset.py,
if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

    from pose_augment import set_network_input_wh, set_network_scale
    # set_network_input_wh(368, 368)
    set_network_input_wh(480, 320)
    set_network_scale(8)

    # df = get_dataflow('/data/public/rw/coco/annotations', True, '/data/public/rw/coco/')
    df = _get_dataflow_onlyread('/data/public/rw/coco/annotations', True, '/data/public/rw/coco/')
    # df = get_dataflow('/root/coco/annotations', False, img_path='http://gpu-twg.kakaocdn.net/braincloud/COCO/')

    from tensorpack.dataflow.common import TestDataSpeed
    TestDataSpeed(df).start()
    sys.exit(0)

    with tf.Session() as sess:
        df.reset_state()
        t1 = time.time()
        for idx, dp in enumerate(df.get_data()):
            if idx == 0:
                for d in dp:
                    logger.info('%d dp shape={}'.format(d.shape))
            print(time.time() - t1)
            t1 = time.time()
            CocoPose.display_image(dp[0], dp[1].astype(np.float32), dp[2].astype(np.float32))
            print(dp[1].shape, dp[2].shape)
            pass

    logger.info('done')
First change
df = _get_dataflow_onlyread('/data/public/rw/coco/annotations', True, '/data/public/rw/coco/')
to
df = _get_dataflow_onlyread('G:/tensorflow/post-estimation/datasets/annotations', True, 'G:/tensorflow/post-estimation/datasets/')
These two paths are where my dataset is stored; if that's unclear, see the previous article,
https://blog.csdn.net/rookie_wei/article/details/93658329
Then delete the line
sys.exit(0)
All right, let's just run it and see whether it gets through. Execute the following commands,
set PYTHONPATH=G:\tensorflow\post-estimation\tf-pose-estimation-master
python tf_pose\pose_dataset.py
Here G:\tensorflow\post-estimation\tf-pose-estimation-master is my project's root directory; change it to match your own setup. The result:
from ._conv import register_converters as _register_converters
loading annotations into memory...
Done (t=9.33s)
creating index...
index created!
[2019-07-03 22:06:49,618] [pose_dataset] [INFO] G:/tensorflow/post-estimation/datasets/annotations dataset 118287
2019-07-03 22:06:49,618 INFO G:/tensorflow/post-estimation/datasets/annotations dataset 118287
0%| |0/5000[00:00<?,?it/s]
Traceback (most recent call last):
  File "tf_pose\pose_dataset.py", line 558, in <module>
    TestDataSpeed(df).start()
  File "D:\Anaconda3\lib\site-packages\tensorpack\dataflow\common.py", line 56, in start
    for idx, dp in enumerate(itr):
  File "D:\Anaconda3\lib\site-packages\tensorpack\dataflow\common.py", line 292, in __iter__
    ret = self.func(copy(dp))  # shallow copy the list
  File "G:\tensorflow\post-estimation\tf-pose-estimation-master\tf_pose\pose_augment.py", line 273, in pose_to_img
    meta_l[0].get_heatmap(target_size=(_network_w // _scale, _network_h // _scale)),
cv2.error: OpenCV(4.1.0) C:\projects\opencv-python\opencv\modules\imgproc\src\resize.cpp:3555: error: (-215:Assertion failed) func != 0 && cn <= 4 in function 'cv::hal::resize'
Sure enough, our luck is as bad as ever and it errors out. Let's fix it. From the message, the problem is in OpenCV's resize call inside get_heatmap, reached via pose_to_img, and the traceback starts from TestDataSpeed(df).start(). So let's look at the source of TestDataSpeed(df).start() first,
class TestDataSpeed(ProxyDataFlow):
    """ Test the speed of some DataFlow """
    def __init__(self, ds, size=5000, warmup=0):
        """
        Args:
            ds (DataFlow): the DataFlow to test.
            size (int): number of datapoints to fetch.
            warmup (int): warmup iterations
        """
        super(TestDataSpeed, self).__init__(ds)
        self.test_size = int(size)
        self.warmup = int(warmup)

    def __iter__(self):
        """ Will run testing at the beginning, then produce data normally. """
        self.start()
        for dp in self.ds:
            yield dp

    def start(self):
        """
        Start testing with a progress bar.
        """
        self.ds.reset_state()
        itr = self.ds.__iter__()
        if self.warmup:
            for _ in tqdm.trange(self.warmup, **get_tqdm_kwargs()):
                next(itr)
        # add smoothing for speed benchmark
        with get_tqdm(total=self.test_size,
                      leave=True, smoothing=0.2) as pbar:
            for idx, dp in enumerate(itr):
                pbar.update()
                if idx == self.test_size - 1:
                    break
Nothing obvious there. So how does pose_to_img get called? TestDataSpeed(df).start() receives a df argument; where does that df come from? Looking back up,
df = _get_dataflow_onlyread('G:/tensorflow/post-estimation/datasets/annotations', True, 'G:/tensorflow/post-estimation/datasets/')
OK, so what does this _get_dataflow_onlyread function do?
def _get_dataflow_onlyread(path, is_train, img_path=None):
    ds = CocoPose(path, img_path, is_train)   # read data from lmdb
    ds = MapData(ds, read_image_url)
    ds = MapData(ds, pose_to_img)
    # ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 4)
    return ds
So pose_to_img is called here. This problem is genuinely hard to pin down because I'm not familiar with how MapData is used. Looking at the get_dataflow function right above _get_dataflow_onlyread,
def get_dataflow(path, is_train, img_path=None):
    # the json to read is person_keypoints_train2017.json; the image files are in train2017
    ds = CocoPose(path, img_path, is_train)     # read data from lmdb

    if is_train:
        ds = MapData(ds, read_image_url)
        # data augmentation
        ds = MapDataComponent(ds, pose_random_scale)
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        # augs = [
        #     imgaug.RandomApplyAug(imgaug.RandomChooseAug([
        #         imgaug.GaussianBlur(max_size=3)
        #     ]), 0.7)
        # ]
        # ds = AugmentImageComponent(ds, augs)
        ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 1)
    else:
        ds = MultiThreadMapData(ds, nr_thread=16, map_func=read_image_url, buffer_size=1000)
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)
        # second argument: size of the queue to hold prefetched datapoints
        ds = PrefetchData(ds, 100, multiprocessing.cpu_count() // 4)

    return ds
Here, between the two MapData calls, it also applies MapDataComponent to run the augmentation steps on the image data. So let's imitate that usage and see: change
def _get_dataflow_onlyread(path, is_train, img_path=None):
    ds = CocoPose(path, img_path, is_train)   # read data from lmdb
    ds = MapData(ds, read_image_url)
    ds = MapData(ds, pose_to_img)
    # ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 4)
    return ds
to the following,
def _get_dataflow_onlyread(path, is_train, img_path=None):
    print('CocoPose-------------')
    ds = CocoPose(path, img_path, is_train)   # read data from lmdb
    print('CocoPose======')
    ds = MapData(ds, read_image_url)
    ds = MapDataComponent(ds, pose_random_scale)
    ds = MapDataComponent(ds, pose_rotation)
    ds = MapDataComponent(ds, pose_flip)
    ds = MapDataComponent(ds, pose_resize_shortestedge_random)
    ds = MapDataComponent(ds, pose_crop_random)
    print('MapData-------------')
    ds = MapData(ds, pose_to_img)
    print('MapData======')
    # ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 4)
    return ds
Run it again, and this time there are no errors. Go grab a cup of tea while the progress bar climbs to 100%...
Once the progress bar finally reaches 100%, it just prints a pile of numbers, like these
0.09376311302185059
(46, 54, 19) (46, 54, 38)
0.26555824279785156
(46, 54, 19) (46, 54, 38)
0.2968161106109619
(46, 54, 19) (46, 54, 38)
0.1718287467956543
(46, 54, 19) (46, 54, 38)
Those numbers by themselves aren't very interesting; they come from the print calls inside the for loop, and the shapes (46, 54, 19) and (46, 54, 38) simply confirm the heatmap and vector-map sizes computed earlier. Let's see what CocoPose.display_image does,
@staticmethod
def display_image(inp, heatmap, vectmap, as_numpy=False):
    global mplset
    # if as_numpy and not mplset:
    #     import matplotlib as mpl
    #     mpl.use('Agg')
    mplset = True
    import matplotlib.pyplot as plt

    fig = plt.figure()
    a = fig.add_subplot(2, 2, 1)
    a.set_title('Image')
    plt.imshow(CocoPose.get_bgimg(inp))

    a = fig.add_subplot(2, 2, 2)
    a.set_title('Heatmap')
    plt.imshow(CocoPose.get_bgimg(inp, target_size=(heatmap.shape[1], heatmap.shape[0])), alpha=0.5)
    tmp = np.amax(heatmap, axis=2)
    plt.imshow(tmp, cmap=plt.cm.gray, alpha=0.5)
    plt.colorbar()

    tmp2 = vectmap.transpose((2, 0, 1))
    tmp2_odd = np.amax(np.absolute(tmp2[::2, :, :]), axis=0)
    tmp2_even = np.amax(np.absolute(tmp2[1::2, :, :]), axis=0)

    a = fig.add_subplot(2, 2, 3)
    a.set_title('Vectormap-x')
    plt.imshow(CocoPose.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5)
    plt.imshow(tmp2_odd, cmap=plt.cm.gray, alpha=0.5)
    plt.colorbar()

    a = fig.add_subplot(2, 2, 4)
    a.set_title('Vectormap-y')
    plt.imshow(CocoPose.get_bgimg(inp, target_size=(vectmap.shape[1], vectmap.shape[0])), alpha=0.5)
    plt.imshow(tmp2_even, cmap=plt.cm.gray, alpha=0.5)
    plt.colorbar()

    if not as_numpy:
        plt.show()
    else:
        fig.canvas.draw()
        data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        fig.clear()
        plt.close()
        return data
The code above is meant to display the images with matplotlib, but nothing shows up. This happens fairly often and I don't feel like fighting it; the simplest workaround is to save the figure to a local file and open it from there. Change
if not as_numpy:
    plt.show()
to
if not as_numpy:
    plt.show()
    plt.savefig('h.png')
Run it again; the result: