import torch
from torch.nn import functional as F
from model.RIFE import Model
import warnings
warnings.filterwarnings("ignore")
import argparse
import cv2
from utils.flow_viz import flow_to_image
import matplotlib.pyplot as plt
import os
import glob
import tqdm
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# This script only performs inference, so autograd tracking is disabled globally.
torch.set_grad_enabled(False)
if torch.cuda.is_available():
    # Keep cuDNN on and let it benchmark convolution algorithms for the
    # input shapes it sees, picking the fastest one.
    torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True
def viz(img1, img2, flo, outPath):
    """Save a three-panel figure: both input frames and the colour-coded flow.

    Args:
        img1, img2: image arrays; they are cast to ``int`` before display, so
            matplotlib interprets them on a 0-255 scale.
        flo: flow tensor; only its first batch element is used. It is moved
            to the CPU, transposed from channel-first to channel-last, and
            converted to an RGB picture with ``flow_to_image``.
        outPath: file path the figure is written to (e.g.
            'optical_flow_comparison.png').
    """
    # Channel-first tensor -> channel-last numpy array for flow_to_image.
    flow_hwc = flo[0].permute(1, 2, 0).cpu().numpy()
    flow_rgb = flow_to_image(flow_hwc)

    fig, axes = plt.subplots(1, 3, figsize=(20, 4))
    panels = (
        ('input image1', img1.astype(int)),
        ('input image2', img2.astype(int)),
        ('estimated optical flow', flow_rgb),
    )
    for ax, (title, picture) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(picture)

    fig.savefig(outPath, bbox_inches='tight')
    # Release the figure so repeated calls in a loop do not accumulate memory.
    plt.close(fig)
def load_frame(path):
    """Read ``path`` as an 8-bit, 3-channel BGR image.

    ``cv2.IMREAD_COLOR`` (unlike ``IMREAD_UNCHANGED``) drops an alpha channel,
    expands grayscale to three channels and converts 16-bit images to 8-bit.
    Every frame therefore has the HxWx3 uint8 layout that the channel permute
    and the ``/ 255`` normalisation rely on.

    Raises:
        FileNotFoundError: if OpenCV cannot decode the file. ``cv2.imread``
            signals this by returning ``None`` instead of raising.
    """
    frame = cv2.imread(path, cv2.IMREAD_COLOR)
    if frame is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return frame


def frame_to_tensor(frame):
    """Convert an HxWx3 uint8 array into a 3xHxW float32 tensor scaled to [0, 1]."""
    return torch.from_numpy(frame.copy()).permute(2, 0, 1).to(torch.float) / 255.0


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Estimate optical flow between consecutive frames with RIFE "
                    "and save side-by-side visualisations.")
    parser.add_argument("--dataset", default="../dataset",
                        help="directory containing the *.png / *.jpg frames")
    parser.add_argument("--model", default="RIFE_m_train_log",
                        help="directory passed to Model.load_model")
    parser.add_argument("--out_dir", default=".",
                        help="directory where the <index>_.png figures are written")
    args = parser.parse_args()

    model = Model(arbitrary=True)
    model.load_model(args.model)
    model.eval()
    model.device()

    # Frames are paired in sorted filename order: (f0, f1), (f1, f2), ...
    images = sorted(glob.glob(os.path.join(args.dataset, '*.png')) +
                    glob.glob(os.path.join(args.dataset, '*.jpg')))
    os.makedirs(args.out_dir, exist_ok=True)

    pairs = enumerate(zip(images[:-1], images[1:]))
    for i, (imfile1, imfile2) in tqdm.tqdm(pairs, total=max(len(images) - 1, 0)):
        img0_ = load_frame(imfile1)
        img1_ = load_frame(imfile2)
        if img0_.shape != img1_.shape:
            raise ValueError(
                f"frame size mismatch: {imfile1} {img0_.shape} vs {imfile2} {img1_.shape}")

        # Stack both frames along the channel axis into a (1, 6, H, W) batch.
        # `device` falls back to the CPU when CUDA is unavailable, unlike `.cuda()`.
        img = torch.cat((frame_to_tensor(img0_), frame_to_tensor(img1_)), 0).unsqueeze(0).to(device)
        _, _, h, w = img.shape
        # Zero-pad on the right/bottom so H and W become multiples of 32.
        ph = ((h - 1) // 32 + 1) * 32
        pw = ((w - 1) // 32 + 1) * 32
        img = F.pad(img, (0, pw - w, 0, ph - h))

        with torch.no_grad():
            # Per the original author's note, these two channels are the flow 1->0.
            flow = model.flownet(img, timestep=1.0, returnflow=True)[:, :2]
        # Crop off the padded border so the flow lines up with the unpadded frames.
        # NOTE(review): assumes flownet returns flow at the padded input resolution.
        flow = flow[:, :, :h, :w]
        print(f"flow size : {flow.size()}")

        # OpenCV loads BGR while matplotlib's imshow expects RGB: reverse channels.
        viz(img0_[..., ::-1], img1_[..., ::-1], flow,
            os.path.join(args.out_dir, f"{i}_.png"))
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.