加载pytorch格式的dcgan已训练网络并测试

建立的dcgan网络通过之后可以得到生成网络netG.pkl文件和鉴别网络netD.pkl文件,加载这些网络输入参数即可得到结果。这里显示了生成网络的加载及测试。同时也调用了网络结构显示的库,以pdf的形式显示所加载的网络的具体结构。

'''
Created on 2020年10月27日

@author: afeng
'''

import torch
import torchvision.utils as vutils 
import numpy as np
from matplotlib import pyplot as plt
from torchviz import make_dot
from dcgan_facies_model import Generator
from tensorflow.python.keras.layers import noise

def loadModel(fileName):   
    trained_netG=torch.load(fileName)
    #print(trained_netG)
    return trained_netG

def testTrainNetG(filename):    
    loaded_netG = loadModel(filename)
    #print(loaded_netG)

    b_size=64
    nz=100
    device=torch.device('cuda:0')
    noise = torch.randn(b_size, nz, 1, 1, device=device)
    pred = loaded_netG(noise)
    print(pred.shape)

    #plt.imshow(np.transpose(vutils.make_grid(pred.to(device)[:64], padding=5, normalize=True).cpu().detach().numpy(),(1,2,0)))
    plt.imshow(np.transpose(vutils.make_grid(pred[0].to(device)[:64], padding=5, normalize=True).cpu().detach().numpy(),(1,2,0)))
    plt.show()    
    #saveNet2PDFFile(loaded_netG, noise)    

def saveNet2PDFFile(loaded_netG, noise):
    #plot the net model as pdf file 
    net_plot = make_dot(loaded_netG(noise), params=dict(loaded_netG.named_parameters()))
    #net_plot = make_dot(loaded_netG(noise))
    net_plot.view("loaded_net")


if __name__ == '__main__':
    filename='trained_netG.pkl'
    testTrainNetG(filename)
    pass

 下图展示了所加载网络的结构

netG网络的输出结果不再展示,和mnist的手写图像差不多。

 

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

oceanstonetree

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值