class Extract_Feature(object):
    """Load NetVLAD video/audio models (TF1 checkpoints) and classify clips.

    Typical usage:
        ef = Extract_Feature()
        ef.load_model_video()
        ef.load_model_audio()
        name, v_out, a_out = ef.classification(video_feature, audio_feature)
        ef.release_model()

    Each modality gets its own ``tf.Session``; sessions are created lazily
    by the ``load_model_*`` methods and closed by ``release_model``.
    """

    def __init__(self, video_model_path='savevideo/netvladmodel.ckpt',
                 audio_model_path='saveaudio/netvladmodelforaudio.ckpt'):
        """Store checkpoint paths; no TF objects are created here.

        Args:
            video_model_path: path to the video NetVLAD checkpoint.
            audio_model_path: path to the audio NetVLAD checkpoint.
        """
        self.video_model_path = video_model_path
        self.audio_model_path = audio_model_path
        # None until the corresponding load_model_* method is called.
        # (Original used '' as the sentinel; None is idiomatic and keeps
        # the same truthiness behavior for existing callers.)
        self.videosess = None
        self.audiosess = None

    def __del__(self):
        # Best-effort cleanup of any sessions still open at GC time.
        self.release_model()

    def release_model(self):
        """Close any open TF sessions.

        Bug fix: the original used ``elif``, so when a video session
        existed the audio session was never closed (session leak).
        Both are now closed independently, and the attributes are reset
        to None so a later __del__ does not double-close.
        """
        if self.videosess:
            self.videosess.close()
            self.videosess = None
        if self.audiosess:
            self.audiosess.close()
            self.audiosess = None

    def set_model_path(self, video_model_path, audio_model_path):
        """Override both checkpoint paths (takes effect on next load)."""
        self.video_model_path = video_model_path
        self.audio_model_path = audio_model_path

    def load_model_video(self):
        """Build the video graph and restore weights into a new session.

        Relies on module-level VIDEO_MAX_SAMPLES / FEATURE_SIZE_VIDEO and
        the VladNet helper defined elsewhere in this file.
        """
        self.video_data = tf.placeholder(
            dtype=tf.float32,
            shape=[VIDEO_MAX_SAMPLES, FEATURE_SIZE_VIDEO],
            name='video_data')
        self.video_out = VladNet.network_video(self.video_data)
        # Per-class probabilities rather than raw logits.
        self.video_out = tf.nn.softmax(self.video_out)
        self.videorestorer = tf.train.Saver()
        self.videosess = tf.Session()
        with self.videosess.as_default():
            self.videorestorer.restore(self.videosess, self.video_model_path)

    def load_model_audio(self):
        """Build the audio graph and restore weights into a new session."""
        self.audio_data = tf.placeholder(
            dtype=tf.float32,
            shape=[AUDIO_MAX_SAMPLES, FEATURE_SIZE_AUDIO],
            name='audio_data')
        self.audio_out = VladNet.network_audio(self.audio_data)
        self.audio_out = tf.nn.softmax(self.audio_out)
        self.audiorestorer = tf.train.Saver()
        self.audiosess = tf.Session()
        with self.audiosess.as_default():
            self.audiorestorer.restore(self.audiosess, self.audio_model_path)

    # NOTE: separate videoclassify()/audioclassify() helpers were merged
    # into the single classification() method below (late fusion).
    def classification(self, video_feature, audio_feature):
        """Run both models and fuse their softmax outputs by averaging.

        Args:
            video_feature: array of shape [VIDEO_MAX_SAMPLES, FEATURE_SIZE_VIDEO].
            audio_feature: array of shape [AUDIO_MAX_SAMPLES, FEATURE_SIZE_AUDIO].

        Returns:
            (class_name, video_out, audio_out): the fused top-1 class name
            from the module-level ``Classes`` list, plus each modality's
            raw softmax output.

        Requires load_model_video() and load_model_audio() to have run.
        """
        with self.videosess.as_default():
            video_out = self.videosess.run(
                self.video_out, feed_dict={self.video_data: video_feature})
        with self.audiosess.as_default():
            audio_out = self.audiosess.run(
                self.audio_out, feed_dict={self.audio_data: audio_feature})
        # Late fusion: average the two probability distributions.
        fusion_out = np.mean((video_out, audio_out), axis=0)
        class_index = int(np.argmax(np.array(fusion_out), axis=1))
        class_name = Classes[class_index]
        return class_name, video_out, audio_out
# Model loading and feature processing are split into separate steps.