【图像生成】ACGAN

1. 框架

(Framework diagram — image not included in this text export.)

2. main

# Author:yx
# Time: 2022.07.15
# Function: ACGAN ISB(background、muscle、nerve)

import argparse
import torch
from torch.utils.data import DataLoader
from torch import autograd, optim
from torchvision.transforms import transforms
from torchvision import datasets
import numpy as np
from ACGAN import *
import os
import json
import sys
from tqdm import tqdm
from torch.autograd import Variable
from ShowFeatureMap import *

# Command-line parameter parsing.
def getArgs():
    """Build and parse the command-line arguments for an ACGAN run.

    Returns the parsed ``argparse.Namespace`` holding the training
    hyper-parameters: action, dataset paths, epochs, batch size,
    optimizer settings, latent dimension and class count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--action", type=str, default="train&val", help="train/val")
    parser.add_argument("--rootpath", type=str, default='/media/yuanxingWorkSpace/ImageAugmentation/ACGAN/ISBClassifyLabelData', help="path")
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument('--arch', '-a', metavar='ARCH', default='ACGAN', help='ACGAN')
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--shuffle", default=True)
    parser.add_argument('--dataset', default='ISBClassifyLabelData', help='ISBClassifyLabelData')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning Rate. Default=0.001')
    parser.add_argument("--weight_decay", "--wd", default=0, type=float, help="Weight decay, Default: 1e-4")
    parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
    parser.add_argument("--step", type=int, default=20, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=10")
    parser.add_argument("--latent_dim", type=int, default=100)  # dimensionality of the generator noise vector
    parser.add_argument("--n_classes", type=int, default=3)     # number of image classes (background/muscle/nerve)
    return parser.parse_args()


def getDataset(args):
    """Build the train/val DataLoaders for an ImageFolder-style dataset.

    ``args.rootpath`` must contain ``train`` and ``val`` sub-directories.
    Side effect: writes the index -> class-name mapping to
    ``class_indices.json`` in the working directory.
    """

    def _pipeline():
        # grayscale -> fixed 480x480 center crop -> float tensor in [0, 1]
        return transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.CenterCrop(480),
            transforms.ToTensor(),
        ])

    assert os.path.exists(args.rootpath), "{} path does not exist.".format(args.rootpath)

    train_set = datasets.ImageFolder(root=os.path.join(args.rootpath, "train"),
                                     transform=_pipeline())
    # Invert class->index so the dumped JSON maps index -> class name.
    idx_to_class = {idx: name for name, idx in train_set.class_to_idx.items()}
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json.dumps(idx_to_class, indent=4))

    # Worker count: 0 (load in the main process) when batch_size is 1,
    # otherwise capped by both CPU count and 4.
    nw = min([os.cpu_count(), args.batch_size if args.batch_size > 1 else 0, 4])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = DataLoader(train_set, batch_size=args.batch_size,
                              shuffle=True, num_workers=nw)

    val_set = datasets.ImageFolder(root=os.path.join(args.rootpath, "val"),
                                   transform=_pipeline())
    val_loader = DataLoader(val_set, batch_size=args.batch_size,
                            shuffle=False, num_workers=nw)

    return train_loader, val_loader


def getModel(args):
    """Instantiate the generator/discriminator pair for ``args.arch``.

    Only the 'ACGAN' architecture is supported.

    Returns:
        (Gmodel, Dmodel): the generator and discriminator modules.

    Raises:
        ValueError: if ``args.arch`` names an unknown architecture. (The
            original fell through and crashed with UnboundLocalError.)
    """
    if args.arch == 'ACGAN':
        Gmodel = Generator()
        Dmodel = Discriminator()
    else:
        raise ValueError("unknown architecture: {}".format(args.arch))
    return Gmodel, Dmodel
 

# train
def train(Dmodel, Gmodel, adversarial_loss, auxiliary_loss, Doptimizer, Goptimizer, train_loader, val_loader, args):
    """ACGAN training loop: per batch, update the generator first, then
    the discriminator on both real and detached generated images.

    After each epoch, runs the generator on the validation set (dumping
    comparison images) and saves the generator state_dict to a
    hard-coded path.

    NOTE(review): relies on the module-level ``device`` set in the
    ``__main__`` block — confirm this module is only run as a script.
    """
    num_epochs = args.epoch
    for epoch in range(1, num_epochs+1):

        Dmodel = Dmodel.train()
        Gmodel = Gmodel.train()

        # FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
        # LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

        train_bar = tqdm(train_loader, file=sys.stdout)
        for i, (imgs, labels) in enumerate(train_bar):
            batch_size = imgs.shape[0]
            
            # Adversarial ground truths, moved to the training device.
            # ``valid`` (1.0) and ``fake`` (0.0) are the target scores: the
            # discriminator should score real data near ``valid`` and
            # generated data near ``fake``.
            valid =torch.FloatTensor(batch_size, 1).fill_(1.0).to(device)
            fake = torch.FloatTensor(batch_size, 1).fill_(0.0).to(device)
            # Configure input
            real_imgs = torch.FloatTensor(imgs).to(device)
            labels = torch.LongTensor(labels).to(device)
            # -----------------
            #  Train Generator (first; produces gen_imgs/gen_labels reused below)
            # -----------------

            Goptimizer.zero_grad()

            # Sample noise and labels as generator input
            z = torch.FloatTensor(np.random.normal(0, 1, (batch_size, args.latent_dim))).to(device)  # mean 0, std 1, shape (batch, latent_dim)
            gen_labels = torch.LongTensor(np.random.randint(0, args.n_classes, batch_size)).to(device)  # random class ints in [0, n_classes), one per sample

            # Generate a batch of images
            gen_imgs = Gmodel(z, gen_labels)
            showGFeatureMap(real_imgs, gen_imgs)
            # Loss measures generator's ability to fool the discriminator
            validity, pred_label = Dmodel(gen_imgs)
            g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

            # The discriminator scores/classifies gen_imgs as (validity, pred_label);
            # the generator wants validity pushed toward the "real" target ``valid``.
            g_loss.backward()
            Goptimizer.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            Doptimizer.zero_grad()
            # Discriminator loss:
            # Loss for real images
            real_pred, real_aux = Dmodel(real_imgs)  # validity score and class prediction for real data
            d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2
            # push real_pred toward ``valid`` and real_aux toward the true labels
            # Loss for fake images
            fake_pred, fake_aux = Dmodel(gen_imgs.detach())
            d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2
            # push fake_pred toward ``fake`` and fake_aux toward the sampled gen_labels
            # Total discriminator loss
            d_loss = (d_real_loss + d_fake_loss) / 2

            # Discriminator classification accuracy over the real + fake batches
            pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)   # predicted class scores
            gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)     # ground-truth / sampled labels
            d_acc = np.mean(np.argmax(pred, axis=1) == gt)      # fraction of correct class predictions

            d_loss.backward()
            Doptimizer.step()

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, num_epochs, i, len(train_loader), d_loss.item(), 100 * d_acc, g_loss.item()))

            
        # validation: generate one image per val sample and dump comparison maps
        Gmodel.eval()
        with torch.no_grad():
            for i, (valRealImg, valLabels) in enumerate(val_loader):
                # noise (fixed shape (1, 100) — assumes val batch_size 1 and latent_dim 100)
                valNoise = torch.FloatTensor(np.random.normal(0, 1, (1, 100))).to(device)
                # labels
                valLabels = torch.LongTensor(valLabels).to(device)
                valGImg =  Gmodel(valNoise, valLabels) 
                showValGFeatureMap(valRealImg, valGImg)
        torch.save(Gmodel.state_dict(), r"/media/yuanxingWorkSpace/ImageAugmentation/ACGAN/ISB-ACGAN/Gmodel/" + str(args.arch) + '_' + str(args.batch_size) + '_' + str(args.epoch) + '.pth')
        


if __name__ == '__main__':

    # Restrict CUDA to the first two GPUs before any device queries.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    args = getArgs()

    print('**************************')
    print('models:%s,\nepoch:%s,\nbatch size:%s\ndataset:%s' % \
          (args.arch, args.epoch, args.batch_size, args.dataset))
    print('**************************')   
    Gmodel, Dmodel = getModel(args)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # NOTE(review): ``nn`` reaches this scope only via ``from ACGAN import *``
        # (the ACGAN module imports torch.nn as nn) — fragile; confirm.
        Dmodel = nn.DataParallel(Dmodel, device_ids=[0, 1])      
        Gmodel = nn.DataParallel(Gmodel, device_ids=[0, 1]) 
    Gmodel.to(device)
    Dmodel.to(device)

    train_loader, val_loader = getDataset(args)

    # Optimizers: Adam with beta1=0.5 for both networks.
    Goptimizer = optim.Adam(Gmodel.parameters(), lr=args.lr, betas=(0.5, 0.999))
    Doptimizer = optim.Adam(Dmodel.parameters(), lr=args.lr, betas=(0.5, 0.999))

    # Loss functions: BCE for the real/fake head, cross-entropy for the class head.
    adversarial_loss = torch.nn.BCELoss()
    auxiliary_loss = torch.nn.CrossEntropyLoss()

    
    if 'train' in args.action:
        train(Dmodel, Gmodel, adversarial_loss, auxiliary_loss, Doptimizer, Goptimizer, train_loader, val_loader, args)

3. ACGAN

import torch.nn as nn
import torch.nn.functional as F
import torch




class Generator(nn.Module):
    """ACGAN generator: (noise z, class label) -> single-channel image.

    The class label is embedded and multiplied element-wise into the noise
    vector (the ACGAN conditioning), projected by a fully connected layer
    to a low-resolution 128-channel feature map, then upsampled 4x by two
    Upsample+Conv stages ending in Tanh.

    Args:
        n_classes: number of class labels (default 3, the ISB classes).
        latent_dim: dimensionality of the input noise vector (default 100).
        img_size: output image side length, must be divisible by 4
            (default 480). Defaults preserve the original hard-coded sizes.
        channels: number of output image channels (default 1, grayscale).
    """

    def __init__(self, n_classes=3, latent_dim=100, img_size=480, channels=1):
        super(Generator, self).__init__()

        # Label embedding: the ACGAN conditioning signal, same width as the noise.
        self.label_emb = nn.Embedding(n_classes, latent_dim)

        self.init_size = img_size // 4  # spatial size before the two 2x upsamplings
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # NOTE(review): positional 0.8 is BatchNorm2d's ``eps`` (the value
            # added to the denominator) — kept from the original, though
            # ``momentum`` may have been intended.
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),  # 0.2 = negative slope
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate a batch of images from ``noise`` conditioned on ``labels``.

        Args:
            noise: float tensor of shape (batch, latent_dim).
            labels: long tensor of shape (batch,) with class indices.

        Returns:
            Image tensor of shape (batch, channels, img_size, img_size) in (-1, 1).
        """
        # Element-wise product of the label embedding and the noise vector.
        gen_input = torch.mul(self.label_emb(labels), noise)
        out = self.l1(gen_input)  # fully connected projection
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    """ACGAN discriminator: image -> (real/fake score, class probabilities).

    Four stride-2 conv blocks downsample the input 16x, then two linear
    heads produce the adversarial validity (Sigmoid, for BCELoss) and the
    auxiliary class distribution (Softmax).

    Args:
        n_classes: number of classes for the auxiliary head (default 3).
        img_size: input image side length, must be divisible by 16
            (default 480). Defaults preserve the original hard-coded sizes.
        channels: number of input channels (default 1, grayscale).
    """

    def __init__(self, n_classes=3, img_size=480, channels=1):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Conv(stride 2) -> LeakyReLU -> Dropout2d [-> BatchNorm2d]."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),  # no BN on the input block
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Height/width after four stride-2 convolutions.
        ds_size = img_size // 2 ** 4

        # Output heads. Softmax now has an explicit dim=1 — identical result
        # for the 2D (batch, features) input used here, but avoids the
        # implicit-dim deprecation warning.
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, n_classes), nn.Softmax(dim=1))

    def forward(self, img):
        """Return (validity in (0,1), class-probability vector) for ``img``.

        Args:
            img: tensor of shape (batch, channels, img_size, img_size).
        """
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)  # flatten per sample
        validity = self.adv_layer(out)   # adversarial (real/fake) score
        label = self.aux_layer(out)      # auxiliary class distribution
        return validity, label

4. ShowFeatureMap

import numpy as np
import imageio


def showGFeatureMap(real, featureMap, outDir="/media/yuanxingWorkSpace/ImageAugmentation/ACGAN/ISB-ACGAN/GFeatureMap/"):
    """Save side-by-side (real | generated) comparison PNGs, one per channel.

    Args:
        real: real image tensor; batch dim squeezed, so batch size 1 expected.
        featureMap: generated image tensor of the same shape as ``real``.
        outDir: destination directory (with trailing slash). Defaults to the
            original hard-coded training path for backward compatibility.
    """
    # squeeze(0) assumes batch_size == 1; the original author notes that for
    # batch_size > 1 a squeeze(1) / per-image loop would be needed instead.
    G_real = real.squeeze(0).detach().cpu().numpy()
    featureMap = featureMap.squeeze(0).detach().cpu().numpy()
    # Concatenate along the width axis: real on the left, generated on the right.
    concatImg = np.concatenate((G_real, featureMap), axis=2)
    featureMapNum = featureMap.shape[0]

    for index in range(1, featureMapNum + 1):
        # Scale [0, 1] floats to uint8 before writing "<index>.png".
        imageio.imwrite(outDir + str(index) + ".png",
                        (concatImg[index - 1] * 255).astype("uint8"))

    

def showValGFeatureMap(real, featureMap, outDir="/media/yuanxingWorkSpace/ImageAugmentation/ACGAN/ISB-ACGAN/valGFeatureMap/"):
    """Save side-by-side (real | generated) validation PNGs, one per channel.

    Args:
        real: real validation image tensor; batch dim squeezed, so batch
            size 1 expected.
        featureMap: generated image tensor of the same shape as ``real``.
        outDir: destination directory (with trailing slash). Defaults to the
            original hard-coded validation path for backward compatibility.
    """
    # squeeze(0) assumes batch_size == 1; the original author notes that for
    # batch_size > 1 a squeeze(1) / per-image loop would be needed instead.
    G_real = real.squeeze(0).detach().cpu().numpy()
    featureMap = featureMap.squeeze(0).detach().cpu().numpy()
    # Concatenate along the width axis: real on the left, generated on the right.
    concatImg = np.concatenate((G_real, featureMap), axis=2)
    featureMapNum = featureMap.shape[0]

    for index in range(1, featureMapNum + 1):
        # Scale [0, 1] floats to uint8 before writing "<index>.png".
        imageio.imwrite(outDir + str(index) + ".png",
                        (concatImg[index - 1] * 255).astype("uint8"))
  • 5
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

只搬烫手的砖

你的鼓励将是我最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值