Using UNet + Vision Transformer to build a GAN for pix2pix image style transfer

The main function is as follows (the original snippet was cut off mid-call; the truncated `optim.Adam` call for the discriminator is closed so the code parses, and nothing beyond that point is reconstructed):

```python
def main():
    print(f"training on : {config.DEVICE}")
    # Discriminator signature:
    # __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout,
    #          num_head, activation, num_encoders, num_class)
    in_channels = 3 * 2  # channels double after concatenating input and target images
    img_size = 256
    patch_size = 16
    embed_dim = patch_size ** 2 * in_channels        # 16*16*6 = 1536
    num_patches = (img_size // patch_size) ** 2      # (256/16)^2 = 256
    dropout = 0.01
    num_head = 8
    activation = "gelu"
    num_encoders = 10
    num_classes = 1
    disc = Discriminator(in_channels, patch_size, embed_dim, num_patches,
                         dropout, num_head, activation, num_encoders,
                         num_classes).to(config.DEVICE)
    gen = Generator(in_channel=3, out_channel=3).to(config.DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE,
                          betas=(0.5, 0.999))
```
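To make the hyperparameters above concrete, here is a minimal sketch of the patch-embedding stage that a ViT-style discriminator like this one typically starts with. The `PatchEmbedding` class and its internals are assumptions for illustration, not the post's actual `Discriminator` code; only the numbers (6 input channels from the `cat`, 16x16 patches on a 256x256 image, `embed_dim = 1536`, `num_patches = 256`) come from `main()`.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Hypothetical patch-embedding front end of a ViT discriminator."""

    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout):
        super().__init__()
        # A strided conv both splits the image into non-overlapping patches
        # and linearly projects each patch to embed_dim in a single step.
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        b = x.shape[0]
        x = self.proj(x)                  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        cls = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls, x], dim=1)    # prepend the class token
        return self.dropout(x + self.pos_embed)

# With the values from main(): 6 channels, 16x16 patches on a 256x256 image.
embed = PatchEmbedding(in_channels=6, patch_size=16,
                       embed_dim=16 * 16 * 6,
                       num_patches=(256 // 16) ** 2, dropout=0.01)
tokens = embed(torch.randn(2, 6, 256, 256))
print(tuple(tokens.shape))  # (2, 257, 1536): 256 patches + 1 class token
```

The resulting token sequence is what the `num_encoders = 10` Transformer encoder layers would consume, with the class token's final state feeding a `num_classes = 1` real/fake head.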