GAN 评估

该博客介绍了基于Pytorch实现的多种GAN评估指标,包括Inception Score, Mode Score, Kernel MMD, Wasserstein距离和Frechet Inception Distance,并提到了1-Nearest Neighbor classifier。文章提供了生成器生成图片和真实图片的保存目录。" 115813846,10548705,Matlab绘制十字坐标系指南,"['Matlab', '图形绘制', '坐标轴']
摘要由CSDN通过智能技术生成

https://github.com/xuqiantong/GAN-Metrics做出部分修改。

下面给出基于Pytorch实现的GAN评价指标计算,包括的GAN评价指标如下:

  • Inception Score
  • Mode Score
  • Kernel MMD
  • Wasserstein distance
  • Frechet Inception Distance
  • 1-Nearest Neighbor classifier

将生成器生成的图片保存目录为:/home/wdong/PycharmProjects/GAN_metrices/results/fake/0/

 真实图片保存目录为:/home/wdong/PycharmProjects/GAN_metrices/results/real/0/

main.py

import metrics

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--dataset',  default='folder' ,help='cifar10 | lsun | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot_real', default='/real/', help='path to dataset real')
parser.add_argument('--dataroot_fake', default='./fake', help='path to dataset fake')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
parser.add_argument('--imageSize', type=int, default=256, help='the height / width of the input image to network')
parser.add_argument('--outf', default='/home/wdong/PycharmProjects/GAN_metrices/results', help='folder to output images and model checkpoints')

opt = parser.parse_args()

#inception_v3
s = metrics.compute_score_raw(opt.dataset, opt.imageSize, opt.dataroot_real,  opt.batchSize, opt.outf+'/real', opt.outf+'/fake',
                                 opt.dataroot_fake, conv_model='inception_v3', workers=int(opt.workers))
for i in range(len(s)):
    print(i,"=", s[i])

 metrics.py

import math
import os

import math

import numpy as np
import ot
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision.models as models
import pdb
from tqdm import tqdm

from scipy.stats import entropy
from numpy.linalg import norm
from scipy import linalg


def giveName(iter):  # 7 digit name.
    ans = str(iter)
    return ans.zfill(7)


def make_dataset(dataset, dataroot, imageSize):
    """
    :param dataset: must be in 'cifar10 | lsun | imagenet | folder | lfw | fake'
    :return: pytorch dataset for DataLoader to utilize
    """
    if dataset in ['imagenet', 'folder', 'lfw']:
        print(os.getcwd()+dataroot)
        # folder dataset
        #dataset = dset.ImageFolder(root=dataroot,
        dataset = dset.ImageFolder(root=os.getcwd()+dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(imageSize),
                                       #transforms.CenterCrop(imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif dataset == 'lsun':
        dataset = dset.LSUN(db_path=dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(imageSize),
                                transforms.CenterCrop(imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize(
                     
评论 21
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值