该章节介绍VITGAN对抗生成网络中,Discriminator鉴别器 部分的代码实现。
目录(文章发布后会补上链接):
- 网络结构简介
- Mapping Network 实现
- PositionalEmbedding 实现
- MLP 实现
- MSA多头注意力 实现
- SLN自调制 实现
- CoordinatesPositionalEmbedding 实现
- ModulatedLinear 实现
- Siren 实现
- Generator生成器 实现
- PatchEmbedding 实现
- ISN 实现
- Discriminator鉴别器 实现
- VITGAN 实现
Discriminator鉴别器 简介
Discriminator鉴别器 参考的是ViT与BERT结构,加上三项修改。
代码实现
DiscriminatorEncoder 实现
import tensorflow as tf
import sys
sys.path.append('')
from models.msa import MSA
from models.mlp import MLP
class DiscriminatorEncoderLayer(tf.keras.layers.Layer):
    """One pre-norm Transformer encoder layer for the discriminator.

    Structure: LayerNorm -> multi-head self-attention -> residual add,
    then LayerNorm -> MLP -> residual add (ViT/BERT-style block).
    """

    def __init__(self, d_model, num_heads, dropout=0.0, discriminator=True):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout = dropout
        self.discriminator = discriminator
        # Normalization is applied before each sub-block (pre-norm).
        self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.msa1 = MSA(d_model, num_heads, discriminator=discriminator)
        self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.mlp1 = MLP(d_model, discriminator=discriminator, dropout=dropout)

    def call(self, x, training):
        # Self-attention sub-block with skip connection.
        skip = x
        normed = self.ln1(x, training=training)
        attended = self.msa1(v=normed, k=normed, q=normed, mask=None)
        x = attended + skip
        # Feed-forward sub-block with skip connection.
        skip = x
        normed = self.ln2(x, training=training)
        x = self.mlp1(normed) + skip
        return x
class DiscriminatorEncoder(tf.keras.layers.Layer):
    """A stack of DiscriminatorEncoderLayer blocks applied in sequence."""

    def __init__(self, d_model, num_heads, num_layers, dropout=0.0, discriminator=True):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.discriminator = discriminator
        # Build num_layers identical encoder layers.
        self.encoder_layers = [
            DiscriminatorEncoderLayer(
                d_model, num_heads, dropout=dropout, discriminator=discriminator
            )
            for _ in range(num_layers)
        ]

    def call(self, x, training):
        # Pass the token sequence through every layer in order.
        for layer in self.encoder_layers:
            x = layer(x=x, training=training)
        return x
if __name__ == "__main__":
# layer = DiscriminatorEncoderLayer(256, 8)
layer = DiscriminatorEncoder(256, 8, 4)
x = tf.random.uniform([2,5,256], dtype=tf.float32)
o1 = layer(x, training=True)
tf.print('o1:', tf.shape(o1))
Discriminator鉴别器 实现
import tensorflow as tf
import sys
sys.path.append('')
from models.patch_embedding import PatchEmbedding
from models.discriminator_transformer_encoder import DiscriminatorEncoder
from models.mlp import MLP
from models.positional_embedding import PositionalEmbedding
class Discriminator(tf.keras.layers.Layer):
    """ViT-style GAN discriminator.

    Splits the input image into (possibly overlapping) patch embeddings,
    prepends a learnable CLS token, adds positional embeddings, runs a
    Transformer encoder, and maps the CLS position to a real/fake score
    through an MLP head.
    """

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_channels=3,
        overlapping=3,
        d_model=768,
        out_dim=1,
        dropout=0.0,
        discriminator=True,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.overlapping = overlapping
        self.d_model = d_model
        self.out_dim = out_dim
        self.dropout = dropout
        self.discriminator = discriminator
        # Patches per image side, and total patch count.
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size ** 2
        self.patch_embedding = PatchEmbedding(
            image_size=image_size,
            patch_size=patch_size,
            overlapping=overlapping,
            emb_dim=d_model,
            discriminator=discriminator,
        )
        # Positional embedding covers every patch plus the CLS token.
        self.patch_positional_embedding = PositionalEmbedding(
            sequence_length=self.num_patches + 1,
            emb_dim=self.d_model,
        )
        self.discriminator_transformer_encoder = DiscriminatorEncoder(
            self.d_model,
            num_heads=8,
            num_layers=4,
            dropout=dropout,
            discriminator=discriminator,
        )
        # Head that produces the real/fake logit(s) per token.
        self.mlp = MLP(out_dim, discriminator=discriminator, dropout=0.0)
        # Learnable CLS token; broadcast over the batch in call().
        self.cls_token = tf.Variable(
            tf.random.uniform([1, 1, self.d_model], dtype=tf.float32),
            dtype=tf.float32,
        )

    def call(self, x, training):
        batch_size = tf.shape(x)[0]
        tokens = self.patch_embedding(x)
        # Prepend one copy of the CLS token per batch element.
        cls = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
        tokens = tf.concat([cls, tokens], axis=1)
        tokens = tokens + self.patch_positional_embedding()
        tokens = self.discriminator_transformer_encoder(tokens, training=training)
        scores = self.mlp(tokens)
        # Keep only the CLS position as the output. Returned values are
        # raw logits (no sigmoid), suitable for a from_logits loss.
        return scores[:, 0, :]
if __name__ == "__main__":
layer = Discriminator(
image_size=224,
patch_size=16,
num_channels=3,
d_model=768
)
x = tf.random.uniform([2,224,224,3], dtype=tf.float32)
o1 = layer(x, training=True)
tf.print('o1:', tf.shape(o1))
o1 = layer(x, training=False)
tf.print('o1:', tf.shape(o1))