代码
import torch
from model import MattingNetwork
from torchvision import transforms
from PIL import Image
import cv2
from inference_utils import OneFrameReader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
# Build the matting network with the MobileNetV3 backbone (the original
# author notes "resnet50" as the alternative), switch it to eval mode and
# move it to the GPU, then load the pretrained weights from disk.
model = MattingNetwork('mobilenetv3')
model = model.eval()
model = model.cuda()
model.load_state_dict(
    torch.load('rvm_mobilenetv3.pth')
)
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter
# Background colour used for compositing (the original author describes it
# as green), stored as a 3x1x1 tensor on the GPU.
bgr = torch.tensor([0.47, 1, 0.6]).view(3, 1, 1).cuda()
# Four recurrent-state slots, all empty before the first frame is processed.
rec = [None for _ in range(4)]
# Value passed to the model as its downsample ratio; adjust it per video.
downsample_ratio = 0.25
import numpy as np
import cv2 as cv
from torchvision.transforms import ToTensor
# Read a video frame by frame, run the matting model on each frame, composite
# the predicted foreground over the solid background colour `bgr`, and show
# the result in an OpenCV window until the stream ends or 'q' is pressed.
totensor = ToTensor()
cap = cv.VideoCapture('videos/qihang/qihang.mp4')
try:
    # Inference only: disable autograd so the recurrent states that are fed
    # back into the model every frame do not drag along gradient history.
    with torch.no_grad():
        while cap.isOpened():
            ret, frame = cap.read()
            # `ret` is False when no frame could be read (e.g. end of stream).
            if not ret:
                print("Can't receive frame (stream end?). Exiting ...")
                break

            # OpenCV delivers BGR frames; convert to RGB before the model.
            frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
            # To a tensor, then add a leading batch dimension.
            frame = torch.unsqueeze(totensor(frame), 0)

            # Cycle the recurrent states: the outputs after `fgr` and `pha`
            # become the `rec` inputs of the next frame.
            fgr, pha, *rec = model(frame.cuda(), *rec, downsample_ratio)

            # Alpha-composite the foreground onto the background colour.
            com = fgr * pha + bgr * (1 - pha)
            com = com.mul(255).byte()
            com = torch.squeeze(com, 0)  # c h w
            com = torch.permute(com, (1, 2, 0))  # h w c
            com = com.cpu().numpy()
            # RGB -> BGR for cv.imshow; copy into a contiguous array instead
            # of handing OpenCV a negative-stride view.
            com = np.ascontiguousarray(com[..., ::-1])

            cv.imshow('com', com)
            if cv.waitKey(1) == ord('q'):
                break
finally:
    # Always free the capture device and close the window, even on error.
    cap.release()
    cv.destroyAllWindows()
备注
1,如果是 webcam,把 cv.VideoCapture('videos/qihang/qihang.mp4') 的参数改为整数 0(设备编号,不要加引号;写成字符串 '0' 会被当作文件名)。
2, opencv 的cv2.imread, cv2.imwrite, cv2.imshow 的图片的通道模式都是bgr。
3,刚接触一个代码模块时,要先用样例跑,而不是自己先敲, 可能会敲漏而跑不出结果,到处找原因很麻烦。