由于工作需要,要做海量视频检索。但是视频是一种复杂的文件形式,不能直接拿来做检索。所以,要先将视频解码成图像的形式,借用图像检索即以图搜图的思想来实现,所以如何把很长的视频文件解码并提取关键帧就显的尤为重要。
博主在查阅大量论文资料的基础上,总结出了一个切实可行的方案:即利用聚类的思想,把从视频中解码出的海量帧图像自适应的聚类,然后选取每个聚类中与聚类中心最为接近的图像最为该类的代表,从而实现从巨量的视频帧活得关键帧。
在这里要说明一下,博主的代码都是自己一行一行琢磨出来的,并且实际用到了工作中,所以本博主不会提供源代码,但是会详细记录下算法实现的过程,并且欢迎大家来讨论和交流。
一、简单的数据集的聚类大家都能做,这个网上有很多现成的代码。而图像的聚类要比数据集的聚类更为复杂,主要是图像数据都是高维的而且数据量较大,如果直接用图像数据间的欧式距离来作为聚类的标准,计算量过于庞大,而且会出现内存溢出的情况,因为所有的数据都要加载到内存中进行运算,这是非常不可取的也是效率极低的算法,虽然有实现的可能性。
二、所以对于图像数据的聚类,思想是降维,进一步说就是聚类的准则要进行改变。不能再以图像之间的直接欧式距离来做度量。论文中提到了一种很好的方法用来做聚类标准,即图像间的相似性度量。这种方法不用计算欧式距离,而是计算图像颜色直方图之间的最大相似性,然后根据这个数进行聚类。
算法过程如下:
(1)转换图像颜色空间:RGB2HSV
(2)计算图像在HSV颜色空间的颜色直方图,但是要做类似归一化的处理。即H颜色直方图分成12份,S、V分成5份,然后在分别计算他们的颜色直方图。这样做是为了简化后续的计算,当然你也可以有别的分法,只不过我建议还是直接用论文中的数据比较好,毕竟是经过大量的实验证明的。
(3)开始聚类:
这里要说一下,聚类是逐帧进行的,这样一是可以覆盖整个数据集使得聚类的结果真实可靠,而且不会漏帧,基本可以提取到数量足够而又准确的关键帧。
首先:图像数据集中的第一帧图像作为初始聚类中心,之后拿出一帧和当前所有的聚类中心进行相似性度量,如果其中最大的相似性值仍然小于你给定的一个阈值thrshold,则是表明该帧与所有聚类中心的距离都太远,所以要自成一类。然后重复此过程,直到取完所有的帧。
然后:聚类中心的计算。每一次当前类中加入一张新的图像时,就要重新计算聚类中心。即求一次平均值。
(4)聚类完成后,在每个聚类中心中计算与聚类中心最为接近的帧作为关键帧。
至此,整个关键帧的提取过程就完成了。
参考代码:
import cv2
import multiprocessing
import numpy as np
from sklearn.cluster import Birch
import matplotlib.pyplot as plt
def video2frame(video_name):
total_frames = []
cap = cv2.VideoCapture(video_name)
c = 1
fps = cap.get(cv2.CAP_PROP_FPS)
print('fps is {}'.format(fps))
if cap.isOpened() == False:
print('Error opening video stream of file')
while cap.isOpened():
ret,frame = cap.read()
if ret == True:
if c==1:
total_frames.append(frame)
elif c%int(fps)==0:
total_frames.append(frame)
else:
break
c += 1
cap.release()
return total_frames
def calc_hist(frame):
h,w,_ = frame.shape
temp = cv2.cvtColor(frame,cv2.COLOR_BGR2HSV)
hist = cv2.calcHist([temp],[0,1,2],None,[12,5,5],[0,256,0,256,0,256])
hist = hist.flatten()
hist /= h*w
return hist
def similarity(a1,a2):
## compute similarity between frames.
#temp = np.concatenate((a1,a2),axis=1)
temp = np.vstack((a1,a2))
#print(temp)
#print(temp.shape)
s = temp.min(axis=0)
#print(s.shape)
si = np.sum(s)
#print(si)
return si
def ekf(total_frames):
## extract key frames from total frames.
## First cluster.
## Second ekf.
centers_d = {}
result = []
for i in range(len(total_frames)):
temp = 0.0
if len(centers_d) < 1:
centers_d[i] = [total_frames[i],i]
else:
centers = list(centers_d.keys())
for index,each in enumerate(centers):
ind = -1
t_si = similarity(total_frames[i],centers_d[each][0])
#print(t_si)
if t_si < 0.8:
continue
elif t_si > temp:
temp = t_si
ind = index
else:
continue
if temp > 0.8 and ind != -1:
centers_d[centers[ind]].append(i)
length = len(centers_d[centers[ind]]) -1
c_old = centers_d[centers[ind]][0] * length
c_new = (c_old + total_frames[i])/(length+1)
centers_d[centers[ind]][0] = c_new
else:
centers_d[i] = [total_frames[i],i]
cks = list(centers_d.keys())
for index,each in enumerate(cks):
if len(centers_d[each]) <=6:
result.extend(centers_d[each][1:])
else:
temp = []
accordence = {}
c = centers_d[each][0]
for jindex,jeach in enumerate(centers_d[each][1:]):
accordence[jindex] = jeach
tempsi = similarity(c,total_frames[jeach])
temp.append(tempsi)
oktemp = np.argsort(temp).tolist()
print('oktemp {}'.format(oktemp))
print('accordence: {}'.format(accordence))
for i in range(5):
oktemp[i] = accordence[oktemp[i]]
result.extend(oktemp[:5])
return centers_d,sorted(result)
if __name__ == '__main__':
pool = multiprocessing.Pool(processes=10)
video_name = 'test.mp4'
total_frames = video2frame(video_name)
print("there are {} frames in video".format(len(total_frames)))
h,w,_ = total_frames[0].shape
hist = pool.map(calc_hist,total_frames)
print('hist.shape: {}'.format(hist[0].shape))
#print('hist[0]: {}'.format(hist[0]))
si = similarity(hist[80],hist[90])
print('similarity between two frames: {}'.format(si))
#print((hist[1]+hist[2]+hist[3])/3)
cents,results = ekf(hist)
print(len(cents),results)
to_show = cv2.cvtColor(total_frames[cents[0][-1]],cv2.COLOR_BGR2RGB)
plt.imshow(to_show)
plt.show()
如果发现了bug或者别的问题,可以留言或私信。