import argparse
import os
import sys
import torch
from PIL import Image
import torchvision.transforms as transforms
from networks.drn_seg import DRNSub
from utils.tools import*from utils.visualize import*defload_classifier(model_path, gpu_id):if torch.cuda.is_available()and gpu_id !=-1:
device ='cuda:{}'.format(gpu_id)else:
device ='cpu'
model = DRNSub(1)#本地networks 封装的模型 drn_seg
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict['model'])
model.to(device)
model.device = device
model.eval()return model
tf = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])])defclassify_fake(model, img_path, no_crop=False,
model_file='utils/dlib_face_detector/mmod_human_face_detector.dat'):# Data preprocessing
im_w, im_h = Image.open(img_path).size
if no_crop:
face = Image.open(img_path).convert('RGB')#convert()函数,用于不同模式图像之间的转换。PIL中有九种不同模式,分别为1,L,P,RGB,RGBA,CMYK,YCbCr,I,Felse:
faces = face_detection(img_path, verbose=False, model_file=model_file)#人脸检测,是检测出图片中包含的正面人脸.iflen(faces)==0:print("no face detected by dlib, exiting")
sys.exit()
face, box = faces[0]
face = resize_shorter_side(face,400)[0]#本地tools.py中的函数.缩小图片,设定边长不超过400,[0] 返回img.resize((w, h), Image.BICUBIC)
face_tens = tf(face).to(model.device)#转变为ToTensor类型并标准化,加载到GPU或者CPU上# Predictionwith torch.no_grad():
prob = model(face_tens.unsqueeze(0))[0].sigmoid().cpu().item()#unsqueeze()函数的功能是在tensor的某个维度上添加一个维数为1的维度#在0维度上添加一个维数return prob
if __name__ =='__main__':
parser = argparse.ArgumentParser()# 通过命令行运行Python脚本时,可以通过ArgumentParser来高效地接受并解析命令行参数#添加命令
parser.add_argument("--input_path", required=True,help="the model input")
parser.add_argument("--model_path", required=True,help="path to the drn model")
parser.add_argument("--gpu_id", default='0',help="the id of the gpu to run model on")
parser.add_argument("--no_crop",
action="store_true",help="do not use a face detector, instead run on the full input image")
args = parser.parse_args()#把parser中设置的所有"add_argument"给返回到args子类实例当中
model = load_classifier(args.model_path, args.gpu_id)#加载模型
prob = classify_fake(model, args.input_path, args.no_crop)#print("Probibility being modified by Photoshop FAL: {:.2f}%".format(prob*100))