本代码在 https://github.com/xuqiantong/GAN-Metrics 的基础上做了部分修改。
下面给出基于 PyTorch 实现的 GAN 评价指标计算代码，包含以下指标：
- Inception Score
- Mode Score
- Kernel MMD
- Wasserstein distance
- Frechet Inception Distance
- 1-Nearest Neighbor classifier
生成器生成的图片保存在目录：/home/wdong/PycharmProjects/GAN_metrices/results/fake/0/
真实图片保存在目录：/home/wdong/PycharmProjects/GAN_metrices/results/real/0/
main.py
import metrics
import argparse
def build_parser():
    """Build the command-line parser for the GAN metric evaluation script.

    Returns:
        argparse.ArgumentParser: parser holding dataset, data-root, loader
        and output-folder options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='folder', help='cifar10 | lsun | imagenet | folder | lfw | fake')
    # For 'imagenet' / 'folder' / 'lfw', metrics.make_dataset() builds the root
    # as os.getcwd() + dataroot (plain string concatenation), so roots must
    # start with '/'. The old fake default './fake' would have produced
    # '<cwd>./fake'; it now mirrors the '/real/' default.
    # NOTE(review): assumes compute_score_raw forwards dataroot_fake to
    # make_dataset the same way as dataroot_real — confirm.
    parser.add_argument('--dataroot_real', default='/real/', help='path to dataset real')
    parser.add_argument('--dataroot_fake', default='/fake/', help='path to dataset fake')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
    parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=256, help='the height / width of the input image to network')
    parser.add_argument('--outf', default='/home/wdong/PycharmProjects/GAN_metrices/results', help='folder to output images and model checkpoints')
    return parser


def main():
    """Parse options, compute the metrics with inception_v3 features and print them.

    Each returned score is printed as "<index> = <value>".
    """
    opt = build_parser().parse_args()
    # Argument order is positional and must match metrics.compute_score_raw.
    scores = metrics.compute_score_raw(
        opt.dataset,
        opt.imageSize,
        opt.dataroot_real,
        opt.batchSize,
        opt.outf + '/real',
        opt.outf + '/fake',
        opt.dataroot_fake,
        conv_model='inception_v3',
        workers=opt.workers,  # already an int: argparse type=int
    )
    for idx, score in enumerate(scores):
        print(idx, "=", score)


if __name__ == "__main__":
    main()
metrics.py
import math
import os
import math
import numpy as np
import ot
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision.models as models
import pdb
from tqdm import tqdm
from scipy.stats import entropy
from numpy.linalg import norm
from scipy import linalg
def giveName(iter):  # 7 digit name.
    """Return ``iter`` as a string left-padded with zeros to 7 characters.

    Uses ``str.zfill`` semantics: strings already 7 or more characters long
    are returned unchanged, and a leading sign stays in front of the padding.
    The parameter name shadows the builtin ``iter``; it is kept so keyword
    callers keep working.
    """
    name_width = 7
    as_text = str(iter)
    return as_text.zfill(name_width)
def make_dataset(dataset, dataroot, imageSize):
"""
:param dataset: must be in 'cifar10 | lsun | imagenet | folder | lfw | fake'
:return: pytorch dataset for DataLoader to utilize
"""
if dataset in ['imagenet', 'folder', 'lfw']:
print(os.getcwd()+dataroot)
# folder dataset
#dataset = dset.ImageFolder(root=dataroot,
dataset = dset.ImageFolder(root=os.getcwd()+dataroot,
transform=transforms.Compose([
transforms.Resize(imageSize),
#transforms.CenterCrop(imageSize),
transforms.ToTensor(),
transforms.Normalize(
(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
elif dataset == 'lsun':
dataset = dset.LSUN(db_path=dataroot, classes=['bedroom_train'],
transform=transforms.Compose([
transforms.Resize(imageSize),
transforms.CenterCrop(imageSize),
transforms.ToTensor(),
transforms.Normalize(