StyleMapGAN Code Walkthrough

Stepping through generate.py

  1. Load the model.
if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--mixing_type",
        choices=[
            "local_editing",
            "transplantation",
            "w_interpolation",
            "reconstruction",
            "stylemixing",
            "random_generation",
        ],
        required=True,
    )
    parser.add_argument("--ckpt", metavar="CHECKPOINT", required=True)
    parser.add_argument("--test_lmdb", type=str)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--save_image_dir", type=str, default="expr")

    # Below argument is needed for local editing.
    parser.add_argument(
        "--local_editing_part",
        type=str,
        default=None,
        choices=[
            "nose",
            "hair",
            "background",
            "eye",
            "eyebrow",
            "lip",
            "neck",
            "cloth",
            "skin",
            "ear",
        ],
    )

    args = parser.parse_args()

    ckpt = torch.load(args.ckpt)
    train_args = ckpt["train_args"]

Variables in args and in train_args (debugger screenshots omitted).

  1. Copy every variable from the train_args namespace into the args namespace (unless args already defines it).
...
    for key in vars(train_args):
        if not (key in vars(args)):
            setattr(args, key, getattr(train_args, key))
    print(args)

The merged args namespace is then printed:

# printed output
Namespace(batch=1, batch_per_gpu=8, channel_multiplier=2, ckpt='expr/checkpoints/celeba_hq_256_8x8.pt', d_reg_every=16, dataset='celeba_hq', iter=1400000, lambda_adv_loss=1, lambda_d_loss=1, lambda_indomainGAN_D_loss=1, lambda_indomainGAN_E_loss=1, lambda_perceptual_loss=1, lambda_w_rec_loss=1, lambda_x_rec_loss=1, latent_channel_size=64, latent_spatial_size=8, local_editing_part='cloth', lr=0.002, lr_mul=0.01, mapping_layer_num=8, mapping_method='MLP', mixing_type='local_editing', n_sample=16, ngpus=2, normalize_mode='LayerNorm', num_workers=2, r1=10, remove_indomain=False, remove_w_rec=False, save_image_dir='expr', size=256, small_generator=False, start_iter=0, test_lmdb='data/celeba_hq/LMDB_test', train_lmdb='/data/celeba_hq_lmdb/train/LMDB_train', val_lmdb='/data/celeba_hq_lmdb/train/LMDB_val')


  1. Call Model().
    ...
    dataset_name = args.dataset
    args.save_image_dir = os.path.join(
        args.save_image_dir, args.mixing_type, dataset_name
    )

    model = Model().to(device)
  1. Enter Model.__init__ (defined in generate.py), which calls the constructor of the Generator class in training/model.py.
# generate.py
class Model(nn.Module):
    def __init__(self, device="cuda"):
        super(Model, self).__init__()
        self.g_ema = Generator(
            args.size,
            args.mapping_layer_num,
            args.latent_channel_size,
            args.latent_spatial_size,
            lr_mul=args.lr_mul,
            channel_multiplier=args.channel_multiplier,
            normalize_mode=args.normalize_mode,
            small_generator=args.small_generator,
        )
  1. Enter Generator.__init__ in training/model.py, which first calls the constructor of the PixelNorm class.
class Generator(nn.Module):
    def __init__(
        self,
        size,
        mapping_layer_num,
        style_dim,
        latent_spatial_size,
        lr_mul,
        channel_multiplier,
        normalize_mode,
        blur_kernel=[1, 3, 3, 1],
        small_generator=False,
    ):
        super().__init__()

        self.latent_spatial_size = latent_spatial_size
        self.style_dim = style_dim

        layers = [PixelNorm()]

Input arguments (screenshot omitted)

  1. Enter PixelNorm.__init__ in training/model.py.
class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
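PixelNorm rescales each latent vector to unit RMS across its channel dimension before the mapping layers. A minimal sketch of the same formula on a toy tensor (not repository code):

import torch

z = torch.randn(4, 64)  # a batch of latents, shaped like (N, style_dim)
z_norm = z * torch.rsqrt(torch.mean(z ** 2, dim=1, keepdim=True) + 1e-8)
print(z_norm.pow(2).mean(dim=1))  # each entry is ~1.0 (unit mean square per sample)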
  1. Return to Generator.__init__, which builds the mapping network by calling EqualLinear.__init__ in training/model.py for each layer.
class Generator(nn.Module):
    def __init__(
        self,
        size,
        mapping_layer_num,
        style_dim,
        latent_spatial_size,
        lr_mul,
        channel_multiplier,
        normalize_mode,
        blur_kernel=[1, 3, 3, 1],
        small_generator=False,
    ):
        ...
        for i in range(mapping_layer_num):
            if i != (mapping_layer_num - 1):
                in_channel = style_dim
                out_channel = style_dim
            else:
                in_channel = style_dim
                out_channel = style_dim * latent_spatial_size * latent_spatial_size

            layers.append(
                EqualLinear(
                    in_channel, out_channel, lr_mul=lr_mul, activation="fused_lrelu"
                )
            )
  1. Enter EqualLinear.__init__ in training/model.py.
class EqualLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul
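EqualLinear stores its weight divided by lr_mul and keeps scale = (1 / sqrt(in_dim)) * lr_mul, the StyleGAN-style equalized learning rate. The forward pass is not shown above, so the following is only a sketch of the intended effective weight scale (an assumption based on the usual StyleGAN2 convention of multiplying the stored weight by scale at runtime):

import math
import torch

in_dim, out_dim, lr_mul = 64, 64, 0.01
weight = torch.randn(out_dim, in_dim).div_(lr_mul)  # stored parameter, std ~ 1/lr_mul
scale = (1 / math.sqrt(in_dim)) * lr_mul
print((weight * scale).std().item())  # ~ 1/sqrt(in_dim) = 0.125, independent of lr_mul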

Background note on torch.nn.Parameter: it is a subclass of torch.Tensor whose main purpose is to act as a trainable parameter of an nn.Module. Unlike a plain torch.Tensor, an nn.Parameter assigned to a module is automatically registered as a trainable parameter, i.e. it shows up in the module's parameters() iterator, whereas ordinary tensors stored on a module do not.
Note that requires_grad defaults to True for nn.Parameter objects (they are trainable by default), which is the opposite of the default for torch.Tensor.
Inside nn.Module, PyTorch itself uses nn.Parameter to register each module's parameters.
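A toy example (not from the repository) of that registration behavior:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3))  # registered; requires_grad=True by default
        self.t = torch.randn(3)                # plain tensor attribute, not a parameter

print([name for name, _ in Toy().named_parameters()])  # ['w']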

  1. Return to Generator.__init__ in training/model.py, finish building the layers list for the mapping network, then call Decoder.__init__ in training/model.py.
class Generator(nn.Module):
    def __init__(
        self,
        size,
        mapping_layer_num,
        style_dim,
        latent_spatial_size,
        lr_mul,
        channel_multiplier,
        normalize_mode,
        blur_kernel=[1, 3, 3, 1],
        small_generator=False,
    ):
        ...
        self.mapping_z = nn.Sequential(*layers)

        self.decoder = Decoder(
            size,
            style_dim,
            latent_spatial_size,
            channel_multiplier=channel_multiplier,
            blur_kernel=blur_kernel,
            normalize_mode=normalize_mode,
            lr_mul=1,
            small_generator=small_generator,
        )  # always 1, always zero padding
  1. Enter Decoder.__init__ in training/model.py, which calls ConstantInput.__init__.
class Decoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel,
        normalize_mode,
        lr_mul,
        small_generator,
    ):
        super().__init__()

        self.size = size

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.log_size = int(math.log(size, 2))

        self.input = ConstantInput(
            channels[latent_spatial_size], size=latent_spatial_size
        )

Input arguments (screenshot omitted)

  1. Enter ConstantInput.__init__ in training/model.py.
class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, batch):
        out = self.input.repeat(batch, 1, 1, 1)

        return out
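Given the class above, the learned constant feature map is simply repeated along the batch dimension. A quick usage sketch, assuming the CelebA-HQ 256 configuration where latent_spatial_size=8 and channels[8]=512:

const = ConstantInput(channel=512, size=8)  # learned 1 x 512 x 8 x 8 tensor
out = const(batch=4)
print(out.shape)  # torch.Size([4, 512, 8, 8])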
  1. Return to Decoder.__init__, which calls StyledConv.__init__ in training/model.py.
class Decoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel,
        normalize_mode,
        lr_mul,
        small_generator,
    ):
        ...
        if small_generator:
            stylecode_dim = style_dim
        else:
            stylecode_dim = channels[latent_spatial_size]

        self.conv1 = StyledConv(
            channels[latent_spatial_size],
            channels[latent_spatial_size],
            3,
            stylecode_dim,
            blur_kernel=blur_kernel,
            normalize_mode=normalize_mode,
        )
  1. Enter StyledConv.__init__, which calls ModulatedConv2d.__init__ in training/model.py.
class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        blur_kernel,
        normalize_mode,
        upsample=False,
        activate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            normalize_mode=normalize_mode,
        )
  1. Enter ModulatedConv2d.__init__, which calls EqualConv2d.__init__ in training/model.py (first for self.gamma).
class ModulatedConv2d(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        normalize_mode,
        blur_kernel,
        upsample=False,
        downsample=False,
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.normalize_mode = normalize_mode
        if normalize_mode == "InstanceNorm2d":
            self.norm = nn.InstanceNorm2d(in_channel, affine=False)
        elif normalize_mode == "BatchNorm2d":
            self.norm = nn.BatchNorm2d(in_channel, affine=False)

        self.beta = None

        self.gamma = EqualConv2d(
            style_dim,
            in_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=True,
            bias_init=1,
        )
        self.beta = EqualConv2d(
            style_dim,
            in_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=True,
            bias_init=0,
        )

  1. Enter EqualConv2d.__init__ in training/model.py.
class EqualConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride=1,
        padding=0,
        lr_mul=1,
        bias=True,
        bias_init=0,
        conv_transpose2d=False,
        activation=False,
    ):
        super().__init__()

        self.out_channel = out_channel
        self.kernel_size = kernel_size

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size).div_(lr_mul)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) * lr_mul

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel).fill_(bias_init))

            self.lr_mul = lr_mul
        else:
            self.lr_mul = None

        self.conv_transpose2d = conv_transpose2d

        if activation:
            self.activation = ScaledLeakyReLU(0.2)
            # self.activation = FusedLeakyReLU(out_channel)
        else:
            self.activation = False

Input arguments (screenshot omitted)

  1. Return to ModulatedConv2d.__init__ and call EqualConv2d.__init__ once more (for self.beta). When that finishes, ModulatedConv2d.__init__ completes and control returns to StyledConv.__init__, which then calls FusedLeakyReLU.__init__ in training/op/fused_act.py.
class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        blur_kernel,
        normalize_mode,
        upsample=False,
        activate=True,
    ):
        ...
        if activate:
            self.activate = FusedLeakyReLU(out_channel)
        else:
            self.activate = None
  1. Enter FusedLeakyReLU.__init__ in training/op/fused_act.py.
class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale
  1. Return to StyledConv.__init__ and execute its last statement, then return to Decoder.__init__ in training/model.py, which calls ConvLayer.__init__.
class Decoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel,
        normalize_mode,
        lr_mul,
        small_generator,
    ):
        ...
        in_channel = channels[latent_spatial_size]

        self.start_index = int(math.log(latent_spatial_size, 2)) + 1  # if 4x4 -> 3
        self.convs = nn.ModuleList()
        self.convs_latent = nn.ModuleList()

        self.convs_latent.append(
            ConvLayer(
                style_dim, stylecode_dim, 3, bias=True, activate=True, lr_mul=lr_mul
            )
        )
        self.convs_latent.append(
            ConvLayer(
                stylecode_dim, stylecode_dim, 3, bias=True, activate=True, lr_mul=lr_mul
            )
        )
  1. Enter ConvLayer.__init__ in training/model.py.
class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
        lr_mul=1,
    ):
        assert not (upsample and downsample)
        layers = []

        if upsample:
            stride = 2
            self.padding = 0
            layers.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=self.padding,
                    stride=stride,
                    bias=bias and not activate,
                    conv_transpose2d=True,
                    lr_mul=lr_mul,
                )
            )

            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))

        else:

            if downsample:
                factor = 2
                p = (len(blur_kernel) - factor) + (kernel_size - 1)
                pad0 = (p + 1) // 2
                pad1 = p // 2

                layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

                stride = 2
                self.padding = 0

            else:
                stride = 1
                self.padding = kernel_size // 2

            layers.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=self.padding,
                    stride=stride,
                    bias=bias and not activate,
                )
            )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))

            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)

Input arguments (screenshot omitted)

  1. Return to Decoder.__init__, which calls StyledResBlock.__init__ in training/model.py.
class Decoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel,
        normalize_mode,
        lr_mul,
        small_generator,
    ):
        ...
        for i in range(self.start_index, self.log_size + 1):  # 8x8~ 128x128
            if small_generator:
                stylecode_dim_prev, stylecode_dim_next = style_dim, style_dim
            else:
                stylecode_dim_prev = channels[2 ** (i - 1)]
                stylecode_dim_next = channels[2 ** i]
            self.convs_latent.append(
                ConvLayer(
                    stylecode_dim_prev,
                    stylecode_dim_next,
                    3,
                    upsample=True,
                    bias=True,
                    activate=True,
                    lr_mul=lr_mul,
                )
            )
            self.convs_latent.append(
                ConvLayer(
                    stylecode_dim_next,
                    stylecode_dim_next,
                    3,
                    bias=True,
                    activate=True,
                    lr_mul=lr_mul,
                )
            )

        if small_generator:
            stylecode_dim = style_dim
        else:
            stylecode_dim = None

        for i in range(self.start_index, self.log_size + 1):
            out_channel = channels[2 ** i]
            self.convs.append(
                StyledResBlock(
                    in_channel,
                    out_channel,
                    stylecode_dim,
                    blur_kernel,
                    normalize_mode=normalize_mode,
                )
            )

            in_channel = out_channel
  1. Enter StyledResBlock.__init__ in training/model.py.
class StyledResBlock(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        style_dim,
        blur_kernel,
        normalize_mode,
        global_feature_channel=None,
    ):
        super().__init__()

        if style_dim is None:
            if global_feature_channel is not None:
                self.conv1 = StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    in_channel + global_feature_channel,
                    blur_kernel=blur_kernel,
                    upsample=True,
                    normalize_mode=normalize_mode,
                )
                self.conv2 = StyledConv(
                    out_channel,
                    out_channel,
                    3,
                    out_channel + global_feature_channel,
                    blur_kernel=blur_kernel,
                    normalize_mode=normalize_mode,
                )
            else:
                self.conv1 = StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    in_channel,
                    blur_kernel=blur_kernel,
                    upsample=True,
                    normalize_mode=normalize_mode,
                )
                self.conv2 = StyledConv(
                    out_channel,
                    out_channel,
                    3,
                    out_channel,
                    blur_kernel=blur_kernel,
                    normalize_mode=normalize_mode,
                )
        else:
            self.conv1 = StyledConv(
                in_channel,
                out_channel,
                3,
                style_dim,
                blur_kernel=blur_kernel,
                upsample=True,
                normalize_mode=normalize_mode,
            )
            self.conv2 = StyledConv(
                out_channel,
                out_channel,
                3,
                style_dim,
                blur_kernel=blur_kernel,
                normalize_mode=normalize_mode,
            )

        self.skip = ConvLayer(
            in_channel, out_channel, 1, upsample=True, activate=False, bias=False
        )

Input arguments (screenshot omitted)

  1. Return to Decoder.__init__ in training/model.py and run the rest of its body.
class Decoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel,
        normalize_mode,
        lr_mul,
        small_generator,
    ):
        ...
        if small_generator:
            stylecode_dim = style_dim
        else:
            stylecode_dim = channels[size]

        # add adain to to_rgb
        self.to_rgb = StyledConv(
            channels[size],
            3,
            1,
            stylecode_dim,
            blur_kernel=blur_kernel,
            normalize_mode=normalize_mode,
        )

        self.num_stylecodes = self.log_size * 2 - 2 * (
            self.start_index - 2
        )  # the number of AdaIN layer(stylecodes)
        assert len(self.convs) * 2 + 2 == self.num_stylecodes

        self.latent_spatial_size = latent_spatial_size
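For the celeba_hq_256_8x8 checkpoint these constants work out as follows; a quick sanity check of the assertion above (standalone arithmetic, not repository code):

import math

size, latent_spatial_size = 256, 8
log_size = int(math.log(size, 2))                        # 8
start_index = int(math.log(latent_spatial_size, 2)) + 1  # 4
num_convs = log_size + 1 - start_index                   # 5 StyledResBlocks (8x8 -> 256x256)
num_stylecodes = log_size * 2 - 2 * (start_index - 2)    # 12 stylecodes in total
assert num_convs * 2 + 2 == num_stylecodes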
  1. Return to Generator.__init__, then to Model.__init__, which calls Encoder.__init__ in training/model.py.
class Model(nn.Module):
    def __init__(self, device="cuda"):
        ...
        self.e_ema = Encoder(
            args.size,
            args.latent_channel_size,
            args.latent_spatial_size,
            channel_multiplier=args.channel_multiplier,
        )
  1. Enter Encoder.__init__ in training/model.py, which calls ResBlock.__init__.
class Encoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        channels = {
            1: 512,
            2: 512,
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.from_rgb = ConvLayer(3, channels[size], 1)
        self.convs = nn.ModuleList()

        log_size = int(math.log(size, 2))
        self.log_size = log_size

        in_channel = channels[size]
        end = int(math.log(latent_spatial_size, 2))

        for i in range(self.log_size, end, -1):
            out_channel = channels[2 ** (i - 1)]

            self.convs.append(
                ResBlock(in_channel, out_channel, blur_kernel, return_features=True)
            )

Input arguments (screenshot omitted)

  1. Enter ResBlock.__init__ in training/model.py.
class ResBlock(nn.Module):
    def __init__(
        self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], return_features=False
    ):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )
        self.return_features = return_features
  1. Return to Encoder.__init__ in training/model.py and run the rest of its body.
class Encoder(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        latent_spatial_size,
        channel_multiplier,
        blur_kernel=[1, 3, 3, 1],
    ):
        ...
            in_channel = out_channel

        self.final_conv = ConvLayer(in_channel, style_dim, 3)
  1. Return to Model.__init__ and then to the main program in generate.py. Execution continues, loading the checkpoint weights and eventually calling GTMaskDataset("data/celeba_hq", transform, args.size) from training/dataset.py.
    model.g_ema.load_state_dict(ckpt["g_ema"])
    model.e_ema.load_state_dict(ckpt["e_ema"])
    model.eval()

    batch = args.batch

    device = "cuda"
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    if args.mixing_type == "random_generation":
        os.makedirs(args.save_image_dir, exist_ok=True)
    elif args.mixing_type in [
        "w_interpolation",
        "reconstruction",
        "transplantation",
        "stylemixing",
    ]:
        os.makedirs(args.save_image_dir, exist_ok=True)
        dataset = MultiResolutionDataset(args.test_lmdb, transform, args.size)
    elif args.mixing_type == "local_editing":

        if dataset_name == "afhq":
            args.save_image_dir = os.path.join(args.save_image_dir)
            for kind in [
                "mask",
                "source_image",
                "source_reconstruction",
                "reference_image",
                "reference_reconstruction",
                "synthesized_image",
            ]:
                os.makedirs(os.path.join(args.save_image_dir, kind), exist_ok=True)
        else:  # celeba_hq
            args.save_image_dir = os.path.join(
                args.save_image_dir,
                args.local_editing_part,
            )
            for kind in [
                "mask",
                "mask_ref",
                "mask_src",
                "source_image",
                "source_reconstruction",
                "reference_image",
                "reference_reconstruction",
                "synthesized_image",
            ]:
                os.makedirs(os.path.join(args.save_image_dir, kind), exist_ok=True)
            mask_path_base = f"data/{dataset_name}/local_editing"

        # GT celeba_hq mask images
        if dataset_name == "celeba_hq":
            assert "celeba_hq" in args.test_lmdb

            dataset = GTMaskDataset("data/celeba_hq", transform, args.size)
  1. Enter GTMaskDataset.__init__ in training/dataset.py.
class GTMaskDataset(Dataset):
    def __init__(self, dataset_folder, transform, resolution=256):

        self.env = lmdb.open(
            f"{dataset_folder}/LMDB_test",
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError("Cannot open lmdb dataset", f"{dataset_folder}/LMDB_test")

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get("length".encode("utf-8")).decode("utf-8"))

        self.resolution = resolution
        self.transform = transform

        # convert filename to celeba_hq index
        CelebA_HQ_to_CelebA = (
            f"{dataset_folder}/local_editing/CelebA-HQ-to-CelebA-mapping.txt"
        )
        CelebA_to_CelebA_HQ_dict = {}

        original_test_path = f"{dataset_folder}/raw_images/test/images"
        mask_label_path = f"{dataset_folder}/local_editing/GT_labels"

        with open(CelebA_HQ_to_CelebA, "r") as fp:
            read_line = fp.readline()
            attrs = re.sub(" +", " ", read_line).strip().split(" ")
            while True:
                read_line = fp.readline()

                if not read_line:
                    break

                idx, orig_idx, orig_file = (
                    re.sub(" +", " ", read_line).strip().split(" ")
                )

                CelebA_to_CelebA_HQ_dict[orig_file] = idx

        self.mask = []

        for filename in os.listdir(original_test_path):
            CelebA_HQ_filename = CelebA_to_CelebA_HQ_dict[filename]
            CelebA_HQ_filename = CelebA_HQ_filename + ".png"
            self.mask.append(os.path.join(mask_label_path, CelebA_HQ_filename))
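The mapping file is parsed as whitespace-separated columns (one header line, then one row per image). A hypothetical excerpt with made-up numbers, only to show the shape of data the loop expects:

idx orig_idx orig_file
0   24177    024177.jpg
1   6033     006033.jpg

CelebA_to_CelebA_HQ_dict therefore maps an original CelebA filename (e.g. 024177.jpg) to the CelebA-HQ index, and the GT label for that test image is <index>.png under GT_labels.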


  1. Return to the main program in generate.py and continue; it calls data_sampler(dataset, shuffle=False).
            parts_index = {
                "background": [0],
                "skin": [1],
                "eyebrow": [6, 7],
                "eye": [3, 4, 5],
                "ear": [8, 9, 15],
                "nose": [2],
                "lip": [10, 11, 12],
                "neck": [16, 17],
                "cloth": [18],
                "hair": [13, 14],
            }

        # afhq, coarse(half-and-half) masks
        else:
            assert "afhq" in args.test_lmdb and "afhq" == dataset_name
            dataset = MultiResolutionDataset(args.test_lmdb, transform, args.size)

    if args.mixing_type in [
        "w_interpolation",
        "reconstruction",
        "stylemixing",
        "local_editing",
    ]:
        n_sample = len(dataset)
        sampler = data_sampler(dataset, shuffle=False)


  1. Enter the data_sampler(dataset, shuffle) function.
def data_sampler(dataset, shuffle):
    if shuffle:
        return data.RandomSampler(dataset)
    else:
        return data.SequentialSampler(dataset)
  1. Return to the main program.
        loader = data.DataLoader(
            dataset,
            batch,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )

        # generated images should match with n sample
        if n_sample % batch == 0:
            assert len(loader) == n_sample // batch
        else:
            assert len(loader) == n_sample // batch + 1

        total_latents = torch.Tensor().to(device)
        real_imgs = torch.Tensor().to(device)

        if args.mixing_type == "local_editing":
            if dataset_name == "afhq":
                masks = (
                    -2 * torch.ones(n_sample, args.size, args.size).to(device).float()
                )

                mix_type = list(range(n_sample))
                random.shuffle(mix_type)
                horizontal_mix = mix_type[: n_sample // 2]
                vertical_mix = mix_type[n_sample // 2 :]

                masks[horizontal_mix, :, args.size // 2 :] = 2
                masks[vertical_mix, args.size // 2 :, :] = 2
            else:
                masks = torch.Tensor().to(device).long()

    with torch.no_grad():
        if args.mixing_type == "random_generation":
            truncation = 0.7
            truncation_sample = 5000
            truncation_mean_latent = torch.Tensor().to(device)
            for _ in range(truncation_sample // batch):
                z = make_noise(batch, args.latent_channel_size, device)
                partial_mean_latent = model(z, mode="calculate_mean_stylemap")
                truncation_mean_latent = torch.cat(
                    [truncation_mean_latent, partial_mean_latent], dim=0
                )
            truncation_mean_latent = truncation_mean_latent.mean(0, keepdim=True)

            # refer to stylegan official repository: https://github.com/NVlabs/stylegan/blob/master/generate_figures.py
            cx, cy, cw, ch, rows, lods = 0, 0, 1024, 1024, 3, [0, 1, 2, 2, 3, 3]

            for seed in range(0, 4):
                torch.manual_seed(seed)
                png = f"{args.save_image_dir}/random_generation_{seed}.png"
                print(png)

                total_images_len = sum(rows * 2 ** lod for lod in lods)
                total_images = torch.Tensor()

                while total_images_len > 0:
                    num = batch if total_images_len > batch else total_images_len
                    z = make_noise(num, args.latent_channel_size, device)
                    total_images_len -= batch

                    images = model(
                        (z, truncation, truncation_mean_latent),
                        mode="random_generation",
                    )

                    images = images.permute(0, 2, 3, 1)
                    images = images.cpu()
                    total_images = torch.cat([total_images, images], dim=0)

                total_images = torch.clamp(total_images, min=-1.0, max=1.0)
                total_images = (total_images + 1) / 2 * 255
                total_images = total_images.numpy().astype(np.uint8)

                canvas = Image.new(
                    "RGB",
                    (sum(cw // 2 ** lod for lod in lods), ch * rows),
                    "white",
                )
                image_iter = iter(list(total_images))
                for col, lod in enumerate(lods):
                    for row in range(rows * 2 ** lod):
                        image = Image.fromarray(next(image_iter), "RGB")
                        # image = image.crop((cx, cy, cx + cw, cy + ch))
                        image = image.resize(
                            (cw // 2 ** lod, ch // 2 ** lod), Image.ANTIALIAS
                        )
                        canvas.paste(
                            image,
                            (
                                sum(cw // 2 ** lod for lod in lods[:col]),
                                row * ch // 2 ** lod,
                            ),
                        )
                canvas.save(png)
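                # Worked numbers (assuming the constants above): rows=3 and
                # lods=[0, 1, 2, 2, 3, 3] give sum(rows * 2**lod) = 81 images per grid,
                # and the canvas is sum(1024 // 2**lod) = 2304 px wide by 1024 * 3 = 3072 px tall.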

        elif args.mixing_type == "reconstruction":
            for i, real_img in enumerate(tqdm(loader, mininterval=1)):
                real_img = real_img.to(device)
                recon_image = model(real_img, "reconstruction")

                for i_b, (img_1, img_2) in enumerate(zip(real_img, recon_image)):
                    save_images(
                        [img_1, img_2],
                        [
                            f"{args.save_image_dir}/{i*batch+i_b}_real.png",
                            f"{args.save_image_dir}/{i*batch+i_b}_recon.png",
                        ],
                    )

        elif args.mixing_type == "transplantation":

            for kind in [
                "source_image",
                "source_reconstruction",
                "reference_image",
                "reference_reconstruction",
                "synthesized_image",
            ]:
                os.makedirs(os.path.join(args.save_image_dir, kind), exist_ok=True)

            # AFHQ
            transplantation_dataset = [
                (62, 271, [((4, 2), (3, 2), 2, 4), ((0, 1), (0, 1), 3, 2)])
            ]

            for index_src, index_ref, coordinates in transplantation_dataset:
                src_img = dataset[index_src].to(device)
                ref_img = dataset[index_ref].to(device)

                mixed_image, recon_img_src, recon_img_ref = model(
                    (src_img, ref_img, coordinates), mode="transplantation"
                )

                ratio = 256 // 8

                src_img = (src_img + 1) / 2
                ref_img = (ref_img + 1) / 2

                colors = [(0, 0, 255), (0, 255, 0), (0, 255, 0)]

                for color_i, (
                    (src_p_y, src_p_x),
                    (ref_p_y, ref_p_x),
                    height,
                    width,
                ) in enumerate(coordinates):
                    for i in range(2):
                        img = src_img if i == 0 else ref_img
                        img = img.cpu()
                        img = transforms.ToPILImage()(img)
                        img = np.asarray(img)
                        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                        if i == 0:
                            img = cv2.rectangle(
                                img,
                                (src_p_x * ratio, src_p_y * ratio),
                                (
                                    (src_p_x + width) * ratio,
                                    (src_p_y + height) * ratio,
                                ),
                                colors[color_i],
                                2,
                            )
                        else:
                            img = cv2.rectangle(
                                img,
                                (ref_p_x * ratio, ref_p_y * ratio),
                                (
                                    (ref_p_x + width) * ratio,
                                    (ref_p_y + height) * ratio,
                                ),
                                colors[color_i],
                                2,
                            )
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                        img = transforms.ToTensor()(img)

                        if i == 0:
                            src_img = img
                        else:
                            ref_img = img

                save_images(
                    [mixed_image[0], recon_img_src[0], recon_img_ref[0]],
                    [
                        f"{args.save_image_dir}/synthesized_image/{index_src}_{index_ref}.png",
                        f"{args.save_image_dir}/source_reconstruction/{index_src}_{index_ref}.png",
                        f"{args.save_image_dir}/reference_reconstruction/{index_src}_{index_ref}.png",
                    ],
                )

                save_images(
                    [src_img, ref_img],
                    [
                        f"{args.save_image_dir}/source_image/{index_src}_{index_ref}.png",
                        f"{args.save_image_dir}/reference_image/{index_src}_{index_ref}.png",
                    ],
                    range=(0, 1),
                )

        else:
            for i, real_img in enumerate(tqdm(loader, mininterval=1)):


  1. Enter GTMaskDataset.__getitem__ in training/dataset.py.
class GTMaskDataset(Dataset):
    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8")
            img_bytes = txn.get(key)

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)

        mask = Image.open(self.mask[index])

        mask = mask.resize((self.resolution, self.resolution), Image.NEAREST)
        mask = transforms.ToTensor()(mask)

        mask = mask.squeeze()
        mask *= 255
        mask = mask.long()

        assert mask.shape == (self.resolution, self.resolution)
        return img, mask
  1. Return to the main program, which calls the model (an instance of the Model class) on each batch.
                if (args.mixing_type == "local_editing") and (
                    dataset_name == "celeba_hq"
                ):
                    real_img, mask = real_img
                    mask = mask.to(device)
                    masks = torch.cat([masks, mask], dim=0)
                real_img = real_img.to(device)

                latents = model(real_img, "projection")

  1. Enter Model.forward.
class Model(nn.Module):
    def forward(self, input, mode):
        if mode == "projection":
            fake_stylecode = self.e_ema(input)

            return fake_stylecode
  1. Enter Encoder.forward.
class Encoder(nn.Module):
    def forward(self, input):
        out = self.from_rgb(input)
  1. Enter EqualConv2d.forward (self.from_rgb is a ConvLayer whose first module is an EqualConv2d).
class EqualConv2d(nn.Module):
    def forward(self, input):
        if self.lr_mul != None:
            bias = self.bias * self.lr_mul
        else:
            bias = None

        if self.conv_transpose2d:
            # group version for fast training
            batch, in_channel, height, width = input.shape
            input_temp = input.view(1, batch * in_channel, height, width)
            weight = self.weight.unsqueeze(0).repeat(batch, 1, 1, 1, 1)
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(
                input_temp,
                weight * self.scale,
                bias=bias,
                padding=self.padding,
                stride=2,
                groups=batch,
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            out = F.conv2d(
                input,
                self.weight * self.scale,
                bias=bias,
                stride=self.stride,
                padding=self.padding,
            )

        if self.activation:
            out = self.activation(out)

        return out
  1. Return to Encoder.forward.
class Encoder(nn.Module):
    def forward(self, input):
        ...
        for convs in self.convs:
            out, _, _ = convs(out)

        out = self.final_conv(out)

        return out  # spatial style code
  1. The loop over self.convs (its printed structure is shown below) enters ResBlock.forward, which calls conv1 and therefore EqualConv2d.forward again.
class ResBlock(nn.Module):
    def forward(self, input):
        out1 = self.conv1(input)
ModuleList(
  (0): ResBlock(
    (conv1): ConvLayer(
      (0): EqualConv2d(128, 128, 3, stride=1, padding=1)
      (1): FusedLeakyReLU()
    )
    (conv2): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(128, 256, 3, stride=2, padding=0)
      (2): FusedLeakyReLU()
    )
    (skip): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(128, 256, 1, stride=2, padding=0)
    )
  )
  (1): ResBlock(
    (conv1): ConvLayer(
      (0): EqualConv2d(256, 256, 3, stride=1, padding=1)
      (1): FusedLeakyReLU()
    )
    (conv2): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(256, 512, 3, stride=2, padding=0)
      (2): FusedLeakyReLU()
    )
    (skip): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(256, 512, 1, stride=2, padding=0)
    )
  )
  (2): ResBlock(
    (conv1): ConvLayer(
      (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
      (1): FusedLeakyReLU()
    )
    (conv2): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 3, stride=2, padding=0)
      (2): FusedLeakyReLU()
    )
    (skip): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 1, stride=2, padding=0)
    )
  )
  (3): ResBlock(
    (conv1): ConvLayer(
      (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
      (1): FusedLeakyReLU()
    )
    (conv2): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 3, stride=2, padding=0)
      (2): FusedLeakyReLU()
    )
    (skip): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 1, stride=2, padding=0)
    )
  )
  (4): ResBlock(
    (conv1): ConvLayer(
      (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
      (1): FusedLeakyReLU()
    )
    (conv2): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 3, stride=2, padding=0)
      (2): FusedLeakyReLU()
    )
    (skip): ConvLayer(
      (0): Blur()
      (1): EqualConv2d(512, 512, 1, stride=2, padding=0)
    )
  )
)
  1. Enter EqualConv2d.forward.

  2. Return to ResBlock.forward.

class ResBlock(nn.Module):
    def forward(self, input):
        ...
        out2 = self.conv2(out1)
  1. Enter Blur.forward, which calls upfirdn2d(input, self.kernel, pad=self.pad).
class Blur(nn.Module):
    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out
  1. Enter the upfirdn2d(input, kernel, pad=pad) function in training/op/upfirdn2d.py.
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    out = UpFirDn2d.apply(
        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
    )

    return out
  1. Enter UpFirDn2d.forward.
class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out
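A worked size check (assumed numbers, not repository code) for the downsampling Blur plus stride-2 EqualConv2d pair in ResBlock.conv2, with blur_kernel=[1, 3, 3, 1], kernel_size=3 and a 256x256 input:

in_h, kernel_size, blur_len, factor = 256, 3, 4, 2
p = (blur_len - factor) + (kernel_size - 1)             # 4
pad0, pad1 = (p + 1) // 2, p // 2                       # 2, 2
blur_h = (in_h * 1 + pad0 + pad1 - blur_len) // 1 + 1   # 257, using the out_h formula above
conv_h = (blur_h - kernel_size) // 2 + 1                # 128 after the stride-2 conv (padding=0)
print(blur_h, conv_h)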
  1. Return to Blur.forward, then enter EqualConv2d.forward.

  2. Return to ResBlock.forward.

class ResBlock(nn.Module):
    def forward(self, input):
        ...
        skip = self.skip(input)
        out = (out2 + skip) / math.sqrt(2)

        if self.return_features:
            return out, out1, out2
        else:
            return out
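Dividing the residual sum by sqrt(2) keeps the output on roughly the same scale as either branch when out2 and skip are independent; a toy numerical check (not repository code):

import math
import torch

a, b = torch.randn(100000), torch.randn(100000)
print(((a + b) / math.sqrt(2)).var().item())  # ~1.0, i.e. variance is preserved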
  1. Return to Encoder.forward.
class Encoder(nn.Module):
    def forward(self, input):
         ...
        for convs in self.convs:
            out, _, _ = convs(out)

        out = self.final_conv(out)

        return out  # spatial style code
  1. Return to Model.forward, then to the main program, and continue.
                total_latents = torch.cat([total_latents, latents], dim=0)
                real_imgs = torch.cat([real_imgs, real_img], dim=0)

            ...
            elif args.mixing_type == "local_editing":
                if dataset_name == "afhq":
                    # change it later
                    indices = list(range(len(total_latents)))
                    random.shuffle(indices)
                    indices1 = indices[: len(total_latents) // 2]
                    indices2 = indices[len(total_latents) // 2 :]

                else:
                    with open(
                        f"{mask_path_base}/celeba_hq_test_GT_sorted_pair.pkl",
                        "rb",
                    ) as f:
                        sorted_similarity = pickle.load(f)

                    indices1 = []
                    indices2 = []
                    for (i1, i2), _ in sorted_similarity[args.local_editing_part]:
                        indices1.append(i1)
                        indices2.append(i2)

  1. Continue in the main program and call model((total_latents[index1], total_latents[index2], mask), "local_editing").

            for loop_i, (index1, index2) in tqdm(
                enumerate(zip(indices1, indices2)), total=n_sample
            ):
                if args.mixing_type == "w_interpolation":
                    imgs = model(
                        (total_latents[index1], total_latents[index2]),
                        "w_interpolation",
                    )
                    assert len(imgs) == 1
                    save_image(
                        imgs[0],
                        f"{args.save_image_dir}/{loop_i}.png",
                    )
                elif args.mixing_type == "stylemixing":
                    n_rows = len(index2)
                    coarse_img, fine_img = model(
                        (
                            torch.stack([total_latents[index1] for _ in range(n_rows)]),
                            torch.stack([total_latents[i2] for i2 in index2]),
                        ),
                        "stylemixing",
                    )

                    save_images(
                        [coarse_img, fine_img],
                        [
                            f"{args.save_image_dir}/{index1}_coarse.png",
                            f"{args.save_image_dir}/{index1}_fine.png",
                        ],
                    )

                elif args.mixing_type == "local_editing":
                    src_img = real_imgs[index1]
                    ref_img = real_imgs[index2]

                    if dataset_name == "celeba_hq":
                        mask1_logit = masks[index1]
                        mask2_logit = masks[index2]

                        mask1 = -torch.ones(mask1_logit.shape).to(
                            device
                        )  # initialize with -1
                        mask2 = -torch.ones(mask2_logit.shape).to(
                            device
                        )  # initialize with -1

                        for label_i in parts_index[args.local_editing_part]:
                            mask1[(mask1_logit == label_i) == True] = 1
                            mask2[(mask2_logit == label_i) == True] = 1

                        mask = mask1 + mask2
                        mask = mask.float()
                    elif dataset_name == "afhq":
                        mask = masks[index1]

                    mixed_image, recon_img_src, recon_img_ref = model(
                        (total_latents[index1], total_latents[index2], mask),
                        "local_editing",
                    )
  1. Enter Model.forward, which calls Generator.forward via self.g_ema.
class Model(nn.Module):
    def forward(self, input, mode):
        ...
        elif mode == "local_editing":
            w1, w2, mask = input
            w1, w2, mask = w1.unsqueeze(0), w2.unsqueeze(0), mask.unsqueeze(0)

            if dataset_name == "celeba_hq":
                mixed_image = self.g_ema(
                    [w1, w2],
                    input_is_stylecode=True,
                    mix_space="w_plus",
                    mask=mask,
                )[0]
  1. Enter Generator.forward, which calls self.decoder(stylecode, mix_space=mix_space, mask=mask).
class Generator(nn.Module):
    def forward(
        self,
        input,
        return_stylecode=False,
        input_is_stylecode=False,
        mix_space=None,
        mask=None,
        calculate_mean_stylemap=False,
        truncation=None,
        truncation_mean_latent=None,
    ):
        if calculate_mean_stylemap:  # calculate mean_latent
            stylecode = self.mapping_z(input)
            return stylecode.mean(0, keepdim=True)
        else:
            if input_is_stylecode:
                stylecode = input
            else:
                stylecode = self.mapping_z(input)
                if truncation != None and truncation_mean_latent != None:
                    stylecode = truncation_mean_latent + truncation * (
                        stylecode - truncation_mean_latent
                    )
                N, C = stylecode.shape
                stylecode = stylecode.reshape(
                    N, -1, self.latent_spatial_size, self.latent_spatial_size
                )

            image = self.decoder(stylecode, mix_space=mix_space, mask=mask)

            if return_stylecode == True:
                return image, stylecode
            else:
                return image, None


  1. Enter Decoder.forward.
class Decoder(nn.Module):
    def forward(self, style_code, mix_space=None, mask=None):
        ...
        else:
            batch = style_code[0].shape[0]

        style_codes = []
        ...
        elif mix_space == "w_plus":  # mix stylemaps in W+ space
            style_code1 = style_code[0]
            style_code2 = style_code[1]
            style_codes1 = []
            style_codes2 = []

            for up_layer in self.convs_latent:
                style_code1 = up_layer(style_code1)
                style_code2 = up_layer(style_code2)
                style_codes1.append(style_code1)
                style_codes2.append(style_code2)

            for i in range(0, len(style_codes2)):
                _, C, H, W = style_codes2[i].shape
                ratio = self.size // H
                # print(mask)
                mask_for_latent = nn.MaxPool2d(kernel_size=ratio, stride=ratio)(mask)
                mask_for_latent = mask_for_latent.unsqueeze(1).repeat(1, C, 1, 1)
                style_codes2[i] = torch.where(
                    mask_for_latent > -1, style_codes2[i], style_codes1[i]
                )

            style_codes = style_codes2
        ...
        out = self.input(batch)
        out = self.conv1(out, style_codes[0])

        for i in range(len(self.convs)):
            out = self.convs[i](out, [style_codes[2 * i + 1], style_codes[2 * i + 2]])
        image = self.to_rgb(out, style_codes[-1])

        return image
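The w_plus branch above downsamples the pixel-level mask to each stylemap resolution with max pooling and keeps the reference stylemap wherever the pooled mask is greater than -1. A standalone sketch of just that masking step (toy shapes and values; in generate.py the mask is the sum of two -1/1 part masks, so its values lie in {-2, 0, 2}):

import torch
import torch.nn as nn

size, H, C = 256, 8, 64
mask = -2 * torch.ones(1, size, size)
mask[:, :, size // 2:] = 2                  # pretend the edited region is the right half
style1 = torch.randn(1, C, H, H)            # source stylemap
style2 = torch.randn(1, C, H, H)            # reference stylemap

ratio = size // H
mask_for_latent = nn.MaxPool2d(kernel_size=ratio, stride=ratio)(mask)  # (1, 8, 8)
mask_for_latent = mask_for_latent.unsqueeze(1).repeat(1, C, 1, 1)      # (1, 64, 8, 8)
mixed = torch.where(mask_for_latent > -1, style2, style1)              # reference inside, source outside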

Structure of self.convs_latent:

ModuleList(
  (0): ConvLayer(
    (0): EqualConv2d(64, 512, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (1): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (2): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=2, padding=0)
    (1): Blur()
    (2): FusedLeakyReLU()
  )
  (3): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (4): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=2, padding=0)
    (1): Blur()
    (2): FusedLeakyReLU()
  )
  (5): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (6): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=2, padding=0)
    (1): Blur()
    (2): FusedLeakyReLU()
  )
  (7): ConvLayer(
    (0): EqualConv2d(512, 512, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (8): ConvLayer(
    (0): EqualConv2d(512, 256, 3, stride=2, padding=0)
    (1): Blur()
    (2): FusedLeakyReLU()
  )
  (9): ConvLayer(
    (0): EqualConv2d(256, 256, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
  (10): ConvLayer(
    (0): EqualConv2d(256, 128, 3, stride=2, padding=0)
    (1): Blur()
    (2): FusedLeakyReLU()
  )
  (11): ConvLayer(
    (0): EqualConv2d(128, 128, 3, stride=1, padding=1)
    (1): FusedLeakyReLU()
  )
)
  1. Return to Generator.forward, then to Model.forward, and execute the remaining statements.
class Model(nn.Module):
    def forward(self, input, mode):
        ...
        
            recon_img_src, _ = self.g_ema(w1, input_is_stylecode=True)
            recon_img_ref, _ = self.g_ema(w2, input_is_stylecode=True)

            return mixed_image, recon_img_src, recon_img_ref
  1. Return to the main program.
                    mixed_image, recon_img_src, recon_img_ref = model(
                        (total_latents[index1], total_latents[index2], mask),
                        "local_editing",
                    )

                    save_images(
                        [
                            mixed_image[0],
                            recon_img_src[0],
                            src_img,
                            ref_img,
                            recon_img_ref[0],
                        ],
                        [
                            f"{args.save_image_dir}/synthesized_image/{index1}.png",
                            f"{args.save_image_dir}/source_reconstruction/{index1}.png",
                            f"{args.save_image_dir}/source_image/{index1}.png",
                            f"{args.save_image_dir}/reference_image/{index1}.png",
                            f"{args.save_image_dir}/reference_reconstruction/{index1}.png",
                        ],
                    )

                    mask[mask < -1] = -1
                    mask[mask > -1] = 1

                    save_image(
                        mask,
                        f"{args.save_image_dir}/mask/{index1}.png",
                    )

                    if dataset_name == "celeba_hq":
                        save_images(
                            [mask1, mask2],
                            [
                                f"{args.save_image_dir}/mask_src/{index1}.png",
                                f"{args.save_image_dir}/mask_ref/{index1}.png",
                            ],
                        )

Stepping through pair_masks.py

Using the CelebA-HQ face-parsing labels, this script computes and saves the intersection-over-union (IoU) of each facial part between pairs of test images.

  1. The main program runs and enters the group_pair_GT() function.
if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--batch", type=int, default=1)
    parser.add_argument(
        "--save_dir", type=str, default="../data/celeba_hq/local_editing"
    )

    args = parser.parse_args()
    args.dataset_name = "celeba_hq"
    os.makedirs(args.save_dir, exist_ok=True)
    args.path = f"../data/{args.dataset_name}"
    args.mask_origin = 'GT_test'
    with torch.no_grad():
        # our CelebA-HQ test dataset contains 500 images
        # change this value if you have the different number of GT_labels
        args.n_sample = 5
        group_pair_GT()
  1. Enter group_pair_GT(), which constructs GTMaskDataset(args.path, transform, images_size).
@torch.no_grad()
def group_pair_GT():
    device = "cuda"
    args.n_sample = 5

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    images_size = 256  # you can use other resolution for calculating if your LMDB(args.path) has different resolution.
    dataset = GTMaskDataset(args.path, transform, images_size)

  1. Enter GTMaskDataset.__init__.
# dataset.py
class GTMaskDataset(Dataset):
    def __init__(self, dataset_folder, transform, resolution=256):

        self.env = lmdb.open(
            f"{dataset_folder}/LMDB_test",
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError("Cannot open lmdb dataset", f"{dataset_folder}/LMDB_test")

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get("length".encode("utf-8")).decode("utf-8"))

        self.resolution = resolution
        self.transform = transform

        # convert filename to celeba_hq index
        CelebA_HQ_to_CelebA = (
            f"{dataset_folder}/local_editing/CelebA-HQ-to-CelebA-mapping.txt"
        )
        CelebA_to_CelebA_HQ_dict = {}

        original_test_path = f"{dataset_folder}/raw_images/test/images"
        mask_label_path = f"{dataset_folder}/local_editing/GT_labels"

        with open(CelebA_HQ_to_CelebA, "r") as fp:
            read_line = fp.readline()
            attrs = re.sub(" +", " ", read_line).strip().split(" ")
            while True:
                read_line = fp.readline()

                if not read_line:
                    break

                idx, orig_idx, orig_file = (
                    re.sub(" +", " ", read_line).strip().split(" ")
                )

                CelebA_to_CelebA_HQ_dict[orig_file] = idx
        self.mask = []

        for filename in os.listdir(original_test_path):
            CelebA_HQ_filename = CelebA_to_CelebA_HQ_dict[filename]
            CelebA_HQ_filename = CelebA_HQ_filename + ".png"
            self.mask.append(os.path.join(mask_label_path, CelebA_HQ_filename))


  4. Return to group_pair_GT() and continue execution.
@torch.no_grad()
def group_pair_GT():
    ...
    parts_index = {
        "all": None,
        "background": [0],
        "skin": [1],
        "eyebrow": [6, 7],
        "eye": [3, 4, 5],
        "ear": [8, 9, 15],
        "nose": [2],
        "lip": [10, 11, 12],
        "neck": [16, 17],
        "cloth": [18],
        "hair": [13, 14],
    }

    indexes = range(args.n_sample)

    similarity_dict = {}
    parts = parts_index.keys()

    for part in parts:
        similarity_dict[part] = {}

    for src, ref in tqdm(
        itertools.combinations(indexes, 2),
        total=sum(1 for _ in itertools.combinations(indexes, 2)),
    ):
        _, mask1 = dataset[src]
        _, mask2 = dataset[ref]
        mask1 = mask1.to(device)
        mask2 = mask2.to(device)
        for part in parts:
            if part == "all":
                similarity = torch.sum(mask1 == mask2).item() / (images_size ** 2)
                similarity_dict["all"][src, ref] = similarity
            else:
                part1 = torch.zeros(
                    [images_size, images_size], dtype=torch.bool, device=device
                )
                part2 = torch.zeros(
                    [images_size, images_size], dtype=torch.bool, device=device
                )

                for p in parts_index[part]:
                    part1 = part1 | (mask1 == p)
                    part2 = part2 | (mask2 == p)

                intersection = (part1 & part2).sum().float().item()
                union = (part1 | part2).sum().float().item()
                if union == 0:
                    similarity_dict[part][src, ref] = 0.0
                else:
                    sim = intersection / union
                    similarity_dict[part][src, ref] = sim

    sorted_similarity = {}

    for part, similarities in similarity_dict.items():
        all_indexes = set(range(args.n_sample))
        sorted_similarity[part] = []

        sorted_list = sorted(similarities.items(), key=(lambda x: x[1]), reverse=True)

        for (i1, i2), prob in sorted_list:
            if (i1 in all_indexes) and (i2 in all_indexes):
                all_indexes -= {i1, i2}
                sorted_similarity[part].append(((i1, i2), prob))
            elif len(all_indexes) == 0:
                break

        assert len(sorted_similarity[part]) == args.n_sample // 2

    with open(
        f"{args.save_dir}/{args.dataset_name}_test_{args.mask_origin}_sorted_pair.pkl",
        "wb",
    ) as handle:
        pickle.dump(sorted_similarity, handle)
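The pickle written above maps every part name to a list of ((src, ref), similarity) pairs, picked greedily from the most similar remaining pair downwards. A minimal sketch for inspecting it afterwards (the path simply mirrors the default arguments shown earlier):

import pickle

pkl_path = "../data/celeba_hq/local_editing/celeba_hq_test_GT_test_sorted_pair.pkl"
with open(pkl_path, "rb") as handle:
    sorted_similarity = pickle.load(handle)

# e.g. the image pairs selected for the "hair" part and their IoU scores
for (src, ref), iou in sorted_similarity["hair"]:
    print(src, ref, round(iou, 4))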


Debugging prepare_data.py

This script resizes each image to the requested resolutions, JPEG-encodes it, and stores the results together with the total image count in an LMDB database.

  1. The main program runs and calls prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample).
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", type=str)
    parser.add_argument("--size", type=str, default="128,256,512,1024")
    parser.add_argument("--n_worker", type=int, default=1)
    parser.add_argument("--resample", type=str, default="bilinear")
    parser.add_argument("path", type=str)

    args = parser.parse_args()

    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
    resample = resample_map[args.resample]

    sizes = [int(s.strip()) for s in args.size.split(",")]
    print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    imgset = datasets.ImageFolder(args.path)
    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)

ImageFolder is a torchvision class, and the path passed to it needs attention: if the images are xxx1.jpg … xxxn.jpg under /data/test/raw_test/images/, the path given to ImageFolder should be /data/test/raw_test rather than /data/test/raw_test/images, because ImageFolder treats each subdirectory of the given root as a class folder.
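A minimal sketch of the intended usage (the /data/test/raw_test path is just the example from the paragraph above):

from torchvision import datasets

# assumed layout:
# /data/test/raw_test/
# └── images/
#     ├── xxx1.jpg
#     └── ...
imgset = datasets.ImageFolder("/data/test/raw_test")  # pass the parent, not .../images
print(len(imgset))       # number of images found in the class subfolder(s)
print(imgset.imgs[:3])   # list of (file_path, class_index) tuples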

  2. Enter the prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS) function; it calls the resize_worker function (wrapped with functools.partial).
def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
  3. Enter the resize_worker function, which calls resize_multiple(img, sizes=sizes, resample=resample).
def resize_worker(img_file, sizes, resample):
    i, file = img_file
    img = Image.open(file)
    img = img.convert("RGB")
    out = resize_multiple(img, sizes=sizes, resample=resample)

    return i, out
  4. Enter the resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100) function, which calls resize_and_convert(img, size, resample, quality) for each size.
def resize_multiple(
    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
    imgs = []

    for size in sizes:
        imgs.append(resize_and_convert(img, size, resample, quality))

    return imgs
  5. Enter the resize_and_convert(img, size, resample, quality=100) function.
def resize_and_convert(img, size, resample, quality=100):
    img = trans_fn.resize(img, (size, size), resample)
    # img = trans_fn.center_crop(img, size)
    buffer = BytesIO()
    img.save(buffer, format="jpeg", quality=quality)
    val = buffer.getvalue()

    return val
  6. Return from resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100), then from resize_worker, and finally back to prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS), which continues execution.
def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    ...
    files = sorted(dataset.imgs, key=lambda x: x[0])
    files = [(i, file) for i, (file, label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
            for size, img in zip(sizes, imgs):
                key = f"{size}-{str(i).zfill(5)}".encode("utf-8")

                with env.begin(write=True) as txn:
                    txn.put(key, img)

            total += 1

        with env.begin(write=True) as txn:
            txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
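For reference, the stored values can be read back by rebuilding the same key string and decoding the JPEG bytes. A minimal sketch (the LMDB path "LMDB_test", the size 256, and the index 0 are placeholders):

import lmdb
from io import BytesIO
from PIL import Image

env = lmdb.open("LMDB_test", readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    length = int(txn.get("length".encode("utf-8")).decode("utf-8"))
    key = f"256-{str(0).zfill(5)}".encode("utf-8")  # same key format as written above
    img_bytes = txn.get(key)

img = Image.open(BytesIO(img_bytes))  # decode the stored JPEG bytes
print(length, img.size)               # image count and the decoded resolution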

Understanding torch.nn.Parameter
A detailed introduction to the Python progress bar tqdm
The combinations method of Python's itertools module
