import torch
from models.models import Model
import cv2
from PIL import Image
import numpy as np
from matplotlib.animation import FFMpegWriter
import time
import matplotlib.pyplot as plt
from torchvision.transforms import functional
# NOTE(review): the 'xxxx' parts look like placeholders -- fill them in before running.
# Checkpoint file read with torch.load(); its 'model' entry is loaded into Model.
model_path = "./checkpoint_best.pth"
# Input video file, opened with cv2.VideoCapture.
dataRoot = "xxxx.mp4"
# Directory that receives one visualization PNG per processed frame.
exp_name = "./xxxx_results"
def pre_image(image, multiple=16):
    """Convert an OpenCV BGR frame into a normalized, batched model input.

    Args:
        image: BGR ``numpy`` frame, as produced by ``cv2.VideoCapture.read``.
        multiple: each spatial dimension is rounded to the nearest multiple of
            this value before resizing (default 16, the value the script has
            always used). Dimensions are never rounded below one ``multiple``,
            so very small frames no longer collapse to a zero-sized image.

    Returns:
        A ``(display_image, tensor)`` pair: the RGB PIL image at its original
        size (for visualization) and the resized, mean/std-normalized image as
        a ``(1, 3, H, W)`` float tensor.
    """
    rgb_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    width, height = rgb_image.size
    # Presumably the network needs sizes divisible by `multiple` -- TODO confirm
    # against Model. max() guards tiny frames, which would otherwise round to 0
    # and make resize() fail.
    height = max(multiple, round(height / multiple) * multiple)
    width = max(multiple, round(width / multiple) * multiple)

    # resize() returns a new image, so rgb_image keeps its original size and
    # can be returned for display without making an extra copy.
    resized = rgb_image.resize((width, height), Image.BILINEAR)
    tensor = functional.to_tensor(resized)
    # These are the widely used ImageNet channel statistics; presumably the
    # backbone was trained with them.
    tensor = functional.normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return rgb_image, torch.unsqueeze(tensor, 0)
def _load_model(path, device):
    """Build ``Model``, load its weights from checkpoint *path*, and prepare it for inference.

    The checkpoint is expected to be a mapping holding the state dict under the
    ``'model'`` key.
    """
    model = Model()
    # map_location lets a checkpoint saved from GPU tensors be restored on a
    # CPU-only machine as well.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()
    return model


def _save_visualization(src_image, dense_map, pred_count, out_path):
    """Save *src_image* side by side with *dense_map*, annotated with *pred_count*.

    Draws into the current pyplot figure and clears it afterwards so the same
    figure can be reused for the next frame.
    """
    plt.subplot(121)
    plt.imshow(src_image)
    plt.axis('off')
    plt.subplot(122)
    plt.imshow(dense_map)
    plt.text(25, 25, 'pred crowd count:%.4f ' % pred_count,
             fontdict={'size': 10, 'color': 'red'})
    plt.axis('off')
    plt.tight_layout(pad=0.3, w_pad=0, h_pad=1)
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0, dpi=150)
    plt.clf()


def main():
    """Run the model on every frame of ``dataRoot`` and save one image per frame.

    Each output, written to ``exp_name`` as ``<5-digit index>_<int count>.png``
    (indices start at 1), shows the original frame next to the predicted
    density map scaled to a peak of 1 and labelled with the predicted count.

    Raises:
        IOError: if the video cannot be opened.
    """
    import os  # local import keeps this script-entry edit self-contained

    # Fall back to CPU instead of crashing on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = _load_model(model_path, device)

    # plt.savefig does not create missing directories.
    os.makedirs(exp_name, exist_ok=True)

    figure = plt.figure()
    video = cv2.VideoCapture(dataRoot)
    try:
        if not video.isOpened():
            raise IOError('Could not open video: %s' % dataRoot)
        print(video.get(cv2.CAP_PROP_FPS))
        print(video.get(cv2.CAP_PROP_FRAME_COUNT))

        index = 1
        # Inference only: skip building an autograd graph for every frame.
        with torch.no_grad():
            while True:
                success, frame = video.read()
                if not success:
                    break
                src_image, images = pre_image(frame)
                dense_map, _ = model(images.to(device))
                # Take element [0, 0] of the 4-D output (batch size is 1 here;
                # presumably a single density channel).
                dense_map = dense_map[0, 0].cpu().numpy()
                dense_pred_count = np.sum(dense_map)
                # Scale to a peak of 1 for display; the epsilon avoids 0/0 on
                # an all-zero map.
                dense_map = dense_map / np.max(dense_map + 1e-20)
                out_name = '%05d_%d.png' % (index, int(dense_pred_count))
                _save_visualization(src_image, dense_map, dense_pred_count,
                                    os.path.join(exp_name, out_name))
                index += 1
    finally:
        # Always free the capture handle and the figure, even on errors.
        video.release()
        plt.close(figure)


if __name__ == '__main__':
    main()