预训练图像分类模型预测
此为 datawhale 的公开教程
教程地址:github
1. 调用 pytorch 中 model 加载模型
model = models.resnet18(pretrained=True)
model = model.eval()
model = model.to(device)
Notes:
- model.eval() 通常在对模型进行验证时需要设置的。此设置的目的是:在模型中有BatchNormal以及Dropout层时,取消BN和Dropout层的效果,以达到所有数据都进行测试的效果。
- BatchNormal 以及 Dropout 均为正则化手段,一定程度上可以处理过拟合的情况
- model.to(device) 是将模型转移到指定的设备上
2. 图像预处理以及测试图片
from torchvision import transforms
from PIL import Image
test_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
img_pil = Image.open(img_path)
Notes:
- tansformams 可以对图像进行一系列处理包括,设置图片大小,将图片对象转为 Tensor 对象等
- 使用 Image 打开图片
3. 调用摄像头获取图像(视频)
import cv2
import time
# 获取摄像头,传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(1)
# 打开cap
cap.open(0)
# 无限循环,直到break被触发
while cap.isOpened():
# 获取画面
success, frame = cap.read()
if not success:
print('Error')
break
## !!!处理帧函数
frame = process_frame(frame)
# 展示处理后的三通道图像
cv2.imshow('my_window',frame)
if cv2.waitKey(1) in [ord('q'),27]: # 按键盘上的q或esc退出(在英文输入法下)
break
# 关闭摄像头
cap.release()
# 关闭图像窗口
cv2.destroyAllWindows()
Notes:
- 使用 cv2 库调用摄像头