1.4半监督生成对抗网络(SGAN)

本文介绍了半监督生成对抗网络(SGAN),它在半监督学习中利用少量标签数据进行分类。SGAN的鉴别器作为多分类器,区分真实与生成样本,并学习类别。文章详细阐述了SGAN的架构,包括生成器和鉴别器的设计,并提供了MNIST数据集上的代码实现。训练结果显示,即使在仅有100个标签样本的情况下,SGAN也能达到较高的分类准确率。
摘要由CSDN通过智能技术生成

  1.SGAN简介   

       半监督学习只为训练数据集的一小部分提供类别标签。通过内部数据中的隐藏结构,半监督学习从标注数据点的小子集中归纳,以有效对从未见过的新样本进行分类。

       要使半监督学习有效,标签数据和无标签数据必须来自相同分布。

       半监督生成对抗网络是一种生成对抗网络,其鉴定器是多分类器,不止区分真假两个类,而是学会区分N+1类,其中N是训练数据集中的类数,生成器 生成的伪样本为一个类。

         SGAN主要关心的是鉴别器。训练过程的目标是使该网络成为仅使用一部分标签数据的半监督分类器,其准确率接近全监督的分类器(就是其训练数据集中的每个样本都有标签)。

1.2架构图

 生成器将随机噪音转换为伪样本;鉴别器输入有标签的真实图像(x,y),无标签的真实图像(x)和生成器生成的伪图像(x*)。先用sigmoid区分真伪,然后使用softmax区分类别。

2.代码实现

2.1导入声明

%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

from keras import backend as K

from keras.datasets import mnist
from keras.layers import (Activation, BatchNormalization, Concatenate, Dense,
                          Dropout, Flatten, Input, Lambda, Reshape)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model, Sequential
from keras.optimizer_v2 import adam as Adam
from tensorflow.keras.utils import to_categorical

2.2模型输入维度

img_rows = 28
img_cols = 28
channels = 1

# 输入图像维度
img_shape = (img_rows, img_cols, channels)

# 噪声向量的大小,用作生成器的输入
z_dim = 100

# 数据集中类别的数量
num_classes = 10

2.3数据集

此处使用的是MNIST数据集,里面包含50000张含有标签的图片,但是我们只取其中一部分用于训练,其他的都是假设其没有标签。

class Dataset:
    def __init__(self, num_labeled):

        # 训练中使用的有标签图像的数量
        self.num_labeled = num_labeled

        # 加载MINST数据集
        (self.x_train, self.y_train), (self.x_test,
                       
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值