mxnet mnist 3

# -*- coding: utf-8 -*-
"""
Created on Thu Nov  8 14:32:42 2018

@author: lin
"""
import numpy as np
import os
import struct
import datetime
from tensorflow.examples.tutorials.mnist import input_data
import mxnet as mx
import logging


def make_dcgan_sym(ngf, ndf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):
    BatchNorm = mx.sym.BatchNorm
    # 生成网络G
    # 输入生成网络G的变量,这个是潜在空间
    rand = mx.sym.Variable('rand')

    g1 = mx.sym.Deconvolution(rand, name='g1', kernel=(4,4), num_filter=ngf*8, no_bias=no_bias)
    gbn1 = BatchNorm(g1, name='gbn1', fix_gamma=fix_gamma, eps=eps)
    gact1 = mx.sym.Activation(gbn1, name='gact1', act_type='relu')

    g2 = mx.sym.Deconvolution(gact1, name='g2', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf*4, no_bias=no_bias)
    gbn2 = BatchNorm(g2, name='gbn2', fix_gamma=fix_gamma, eps=eps)
    gact2 = mx.sym.Activation(gbn2, name='gact2', act_type='relu')

    g3 = mx.sym.Deconvolution(gact2, name='g3', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf*2, no_bias=no_bias)
    gbn3 = BatchNorm(g3, name='gbn3', fix_gamma=fix_gamma, eps=eps)
    gact3 = mx.sym.Activation(gbn3, name='gact3', act_type='relu')

    g4 = mx.sym.Deconvolution(gact3, name='g4', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf, no_bias=no_bias)
    gbn4 = BatchNorm(g4, name='gbn4', fix_gamma=fix_gamma, eps=eps)
    gact4 = mx.sym.Activation(gbn4, name='gact4', act_type='relu')

    g5 = mx.sym.Deconvolution(gact4, name='g5', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=nc, no_bias=no_bias)
    # 生成网络G最后得到一张相片
    gout = mx.sym.Activation(g5, name='gact5', act_type='tanh')

    # 判别网络D,这里里的结构与一般的分类网络区别不大
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')

    d1 = mx.sym.Convolution(data, name='d1', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf, no_bias=no_bias)
    dact1 = mx.sym.LeakyReLU(d1, name='dact1', act_type='leaky', slope=0.2)

    d2 = mx.sym.Convolution(dact1, name='d2', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*2, no_bias=no_bias)
    dbn2 = BatchNorm(d2, name='dbn2', fix_gamma=fix_gamma, eps=eps)
    dact2 = mx.sym.LeakyReLU(dbn2, name='dact2', act_type='leaky', slope=0.2)
    d3 = mx.sym.Convolution(dact2, name='d3', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*4, no_bias=no_bias)
    dbn3 = BatchNorm(d3, name='dbn3', fix_gamma=fix_gamma, eps=eps)
    dact3 = mx.sym.LeakyReLU(dbn3, name='dact3', act_type='leaky', slope=0.2)

    d4 = mx.sym.Convolution(dact3, name='d4', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*8, no_bias=no_bias)
    dbn4 = BatchNorm(d4, name='dbn4', fix_gamma=fix_gamma, eps=eps)
    dact4 = mx.sym.LeakyReLU(dbn4, name='dact4', act_type='leaky', slope=0.2)

    d5 = mx.sym.Convolution(dact4, name='d5', kernel=(4,4), num_filter=1, no_bias=no_bias)
    d5 = mx.sym.Flatten(d5)
    # 用逻辑回归计算最后的loss
    dloss = mx.sym.LogisticRegressionOutput(data=d5, label=label, name='dloss')
    # 返回这G与D这两个网络
    return gout, dloss
    
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)

    # =============setting============
    dataset = 'mnist'
    imgnet_path = './MNIST_data'
    ndf = 64
    ngf = 64
    nc = 3
    batch_size = 64
    Z = 100
    lr = 0.0002
    beta1 = 0.5
    ctx = mx.gpu(0)
    check_point = False

    symG, symD = make_dcgan_sym(ngf, ndf, nc)
    #mx.viz.plot_network(symG, shape={'rand': (batch_size, 100, 1, 1)}).view()
    #mx.viz.plot_network(symD, shape={'data': (batch_size, nc, 64, 64)}).view()

    # ==============data==============
    if dataset == 'mnist':
        X_train, X_test = get_mnist()
        train_iter = mx.io.NDArrayIter(X_train, batch_size=batch_size)
    elif dataset == 'imagenet':
        train_iter = ImagenetIter(imgnet_path, batch_size, (3, 64, 64))
    rand_iter = RandIter(batch_size, Z)
    label = mx.nd.zeros((batch_size,), ctx=ctx)

    # =============module G=============
    modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctx)
    modG.bind(data_shapes=rand_iter.provide_data)
    modG.init_params(initializer=mx.init.Normal(0.02))
    modG.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
        })
    mods = [modG]

    # =============module D=============
    modD = mx.mod.Module(symbol=symD, data_names=('data',), label_names=('label',), context=ctx)
    modD.bind(data_shapes=train_iter.provide_data,
              label_shapes=[('label', (batch_size,))],
              inputs_need_grad=True)
    modD.init_params(initializer=mx.init.Normal(0.02))
    modD.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
        })
    mods.append(modD)


    # ============printing==============
    def norm_stat(d):
        return mx.nd.norm(d)/np.sqrt(d.size)
    mon = mx.mon.Monitor(10, norm_stat, pattern=".*output|d1_backward_data", sort=True)
    mon = None
    if mon is not None:
        for mod in mods:
            pass
def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def fentropy(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return -(label*np.log(pred+1e-12) + (1.-label)*np.log(1.-pred+1e-12)).mean()        
    

    mG = mx.metric.CustomMetric(fentropy)
    mD = mx.metric.CustomMetric(fentropy)
    mACC = mx.metric.CustomMetric(facc)

    print('Training...')
    stamp =  datetime.now().strftime('%Y_%m_%d-%H_%M')

    # =============train===============
    for epoch in range(100):
        train_iter.reset()
        for t, batch in enumerate(train_iter):
            rbatch = rand_iter.next()

            if mon is not None:
                mon.tic()

            # 首先生成对抗样本
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()

            # update discriminator on fake
            # 这里的负样本label为0,正样本label为1,不像普遍的mnist一样。那么modG就想生成样本label为1的,modD要将modG生成的数据判定为0
            # train_iter(真实样本)中的数据判定为1。
            label[:] = 0
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            #modD.update()
            # 先Copy得到的对抗样本的梯度,要注意是复制不是引用。
            gradD = [[grad.copyto(grad.context) for grad in grads] for grads in modD._exec_group.grad_arrays]

            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])

            # update discriminator on real
            # 对真实样本的数据训练
            label[:] = 1
            batch.label = [label]
            modD.forward(batch, is_train=True)
            modD.backward()
            # 对抗样本与真实样本的梯度合到一起建行梯度更新
            for gradsr, gradsf in zip(modD._exec_group.grad_arrays, gradD):
                for gradr, gradf in zip(gradsr, gradsf):
                    gradr += gradf
            modD.update()

            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])
            # update generator
            # 更新modG的参数,这里要注意的是,modG想要生成的样本label是1的,所以在modD中用了这个label,就是想生成的样本向label=1靠近。
            # 前向和向后生成输入数据的梯度diffD
            label[:] = 1
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            diffD = modD.get_input_grads()
            # diffD就是modG的loss产生的梯度,用它来向后传播并更新参数。
            modG.backward(diffD)
            modG.update()

            mG.update([label], modD.get_outputs())


            if mon is not None:
                mon.toc_print()

            t += 1
            if t % 10 == 0:
                print('epoch:', epoch, 'iter:', t, 'metric:', mACC.get(), mG.get(), mD.get())
                mACC.reset()
                mG.reset()
                mD.reset()

                visual('gout', outG[0].asnumpy())
                diff = diffD[0].asnumpy()
                diff = (diff - diff.mean())/diff.std()
                visual('diff', diff)
                visual('data', batch.data[0].asnumpy())
                if check_point:
                    print('Saving...')
                    modG.save_params('%s_G_%s-%04d.params'%(dataset, stamp, epoch))
                    modD.save_params('%s_D_%s-%04d.params'%(dataset, stamp, epoch))

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
该资源内项目源码是个人的课程设计、毕业设计,代码都测试ok,都是运行成功后才上传资源,答辩评审平均分达到96分,放心下载使用! ## 项目备注 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。 下载后请首先打开README.md文件(如有),仅供学习参考, 切勿用于商业用途。 该资源内项目源码是个人的课程设计,代码都测试ok,都是运行成功后才上传资源,答辩评审平均分达到96分,放心下载使用! ## 项目备注 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。 下载后请首先打开README.md文件(如有),仅供学习参考, 切勿用于商业用途。
该资源内项目源码是个人的课程设计、毕业设计,代码都测试ok,都是运行成功后才上传资源,答辩评审平均分达到96分,放心下载使用! ## 项目备注 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。 下载后请首先打开README.md文件(如有),仅供学习参考, 切勿用于商业用途。 该资源内项目源码是个人的课程设计,代码都测试ok,都是运行成功后才上传资源,答辩评审平均分达到96分,放心下载使用! ## 项目备注 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。 下载后请首先打开README.md文件(如有),仅供学习参考, 切勿用于商业用途。
C语言是一种广泛使用的编程语言,它具有高效、灵活、可移植性强等特点,被广泛应用于操作系统、嵌入式系统、数据库、编译器等领域的开发。C语言的基本语法包括变量、数据类型、运算符、控制结构(如if语句、循环语句等)、函数、指针等。下面详细介绍C语言的基本概念和语法。 1. 变量和数据类型 在C语言中,变量用于存储数据,数据类型用于定义变量的类型和范围。C语言支持多种数据类型,包括基本数据类型(如int、float、char等)和复合数据类型(如结构体、联合等)。 2. 运算符 C语言中常用的运算符包括算术运算符(如+、、、/等)、关系运算符(如==、!=、、=、<、<=等)、逻辑运算符(如&&、||、!等)。此外,还有位运算符(如&、|、^等)和指针运算符(如、等)。 3. 控制结构 C语言中常用的控制结构包括if语句、循环语句(如for、while等)和switch语句。通过这些控制结构,可以实现程序的分支、循环和多路选择等功能。 4. 函数 函数是C语言中用于封装代码的单元,可以实现代码的复用和模块化。C语言中定义函数使用关键字“void”或返回值类型(如int、float等),并通过“{”和“}”括起来的代码块来实现函数的功能。 5. 指针 指针是C语言中用于存储变量地址的变量。通过指针,可以实现对内存的间接访问和修改。C语言中定义指针使用星号()符号,指向数组、字符串和结构体等数据结构时,还需要注意数组名和字符串常量的特殊性质。 6. 数组和字符串 数组是C语言中用于存储同类型数据的结构,可以通过索引访问和修改数组中的元素。字符串是C语言中用于存储文本数据的特殊类型,通常以字符串常量的形式出现,用双引号("...")括起来,末尾自动添加'\0'字符。 7. 结构体和联合 结构体和联合是C语言中用于存储不同类型数据的复合数据类型。结构体由多个成员组成,每个成员可以是不同的数据类型;联合由多个变量组成,它们共用同一块内存空间。通过结构体和联合,可以实现数据的封装和抽象。 8. 文件操作 C语言中通过文件操作函数(如fopen、fclose、fread、fwrite等)实现对文件的读写操作。文件操作函数通常返回文件指针,用于表示打开的文件。通过文件指针,可以进行文件的定位、读写等操作。 总之,C语言是一种功能强大、灵活高效的编程语言,广泛应用于各种领域。掌握C语言的基本语法和数据结构,可以为编程学习和实践打下坚实的基础。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值