1.SGAN简介
半监督学习只为训练数据集的一小部分提供类别标签。通过内部数据中的隐藏结构,半监督学习从标注数据点的小子集中归纳,以有效对从未见过的新样本进行分类。
要使半监督学习有效,标签数据和无标签数据必须来自相同分布。
半监督生成对抗网络是一种生成对抗网络,其鉴定器是多分类器,不止区分真假两个类,而是学会区分N+1类,其中N是训练数据集中的类数,生成器 生成的伪样本为一个类。
SGAN主要关心的是鉴别器。训练过程的目标是使该网络成为仅使用一部分标签数据的半监督分类器,其准确率接近全监督的分类器(就是其训练数据集中的每个样本都有标签)。
1.2架构图
生成器将随机噪音转换为伪样本;鉴别器输入有标签的真实图像(x,y),无标签的真实图像(x)和生成器生成的伪图像(x*)。先用sigmoid区分真伪,然后使用softmax区分类别。
2.代码实现
2.1导入声明
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from keras import backend as K
from keras.datasets import mnist
from keras.layers import (Activation, BatchNormalization, Concatenate, Dense,
Dropout, Flatten, Input, Lambda, Reshape)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model, Sequential
from keras.optimizer_v2 import adam as Adam
from tensorflow.keras.utils import to_categorical
2.2模型输入维度
img_rows = 28
img_cols = 28
channels = 1
# 输入图像维度
img_shape = (img_rows, img_cols, channels)
# 噪声向量的大小,用作生成器的输入
z_dim = 100
# 数据集中类别的数量
num_classes = 10
2.3数据集
此处使用的是MNIST数据集,里面包含50000张含有标签的图片,但是我们只取其中一部分用于训练,其他的都是假设其没有标签。
class Dataset:
def __init__(self, num_labeled):
# 训练中使用的有标签图像的数量
self.num_labeled = num_labeled
# 加载MINST数据集
(self.x_train, self.y_train), (self.x_test,