Sg2im Splicing

Generating images from the .pt checkpoint file

import os

import torch
from imageio import imwrite

path = '/1T/ysh/sg2im-master/models/checkpoint_with_model.pt'
checkpoint = torch.load(path)

train_samples = checkpoint['train_samples']  # list of 100 dicts, each tensor 32*3*64*64
val_samples = checkpoint['val_samples']
val_samples = checkpoint['val_samples']

# tdimen = np.array(train_samples).shape
# print('train dim',tdimen)

# Save the generated images
output_dir = '/1T/ysh/sg2im-master/outputs/samples/train'
gt_img_dir = os.path.join(output_dir, 'gt_img')
gbm_dir = os.path.join(output_dir, 'gt_box_gt_mask')
os.makedirs(gt_img_dir, exist_ok=True)
os.makedirs(gbm_dir, exist_ok=True)
# for i in range(len(train_samples)):  # all 100 batches
for i in range(1):  # only the first of 100 batches
    gt_img = train_samples[i]['gt_img']
    gt_box_gt_mask = train_samples[i]['gt_box_gt_mask']
    for j in range(gt_img.shape[0]):  # batch size 32
        gt_imgs = gt_img[j].numpy().transpose(1, 2, 0)
        # imwrite expects (H, W, C)
        img_path1 = os.path.join(gt_img_dir, 'img%d_%d.png' % (i, j))
        imwrite(img_path1, gt_imgs)

        gt_box_gt_masks = gt_box_gt_mask[j].numpy().transpose(1, 2, 0)
        # imwrite expects (H, W, C)
        img_path2 = os.path.join(gbm_dir, 'img%d_%d.png' % (i, j))
        imwrite(img_path2, gt_box_gt_masks)
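
If the stored tensors are floats rather than uint8, imageio will warn about lossy conversion when writing PNGs. A minimal sketch of a converter, assuming values in [0, 1] (to_uint8 is a hypothetical helper, not part of the original script):

import numpy as np

def to_uint8(arr):
    # hypothetical helper: scale a float (H, W, C) array in [0, 1] to uint8
    if arr.dtype != np.uint8:
        arr = (np.clip(arr, 0.0, 1.0) * 255).round().astype(np.uint8)
    return arr

# usage: imwrite(img_path1, to_uint8(gt_imgs))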

Splicing

The script below loads a generator pretrained on ImageNet, transfers its weights into a freshly built Generator, then injects a real 128x128 image (passed through a new conv layer) into the decoder at four different depths and saves the result for each depth.

import argparse
import os, sys
from os import path
import time
import copy
from torchvision.utils import save_image
from PIL import Image

import numpy as np
import random
import torch
from torch import nn
from torch.nn import functional as F

from torch.autograd import Variable
import torch.utils.data
import torch.utils.data.distributed
# from torchsummary import summary
import shutil
import scipy.io as sio


class Generator(nn.Module):
    def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs):
        super().__init__()
        s0 = self.s0 = size // 32
        nf = self.nf = nfilter
        self.z_dim = z_dim

        # Submodules
        self.embedding = nn.Embedding(nlabels, embed_size)
        self.fc = nn.Linear(z_dim + embed_size, 16 * nf * s0 * s0)

        self.resnet_0_0 = ResnetBlock(16 * nf, 16 * nf)
        self.resnet_0_1 = ResnetBlock(16 * nf, 16 * nf)

        self.resnet_1_0 = ResnetBlock(16 * nf, 16 * nf)
        self.resnet_1_1 = ResnetBlock(16 * nf, 16 * nf)

        self.resnet_2_0 = ResnetBlock(16 * nf, 8 * nf)
        self.resnet_2_1 = ResnetBlock(8 * nf, 8 * nf)

        self.resnet_3_0 = ResnetBlock(8 * nf, 4 * nf)
        self.resnet_3_1 = ResnetBlock(4 * nf, 4 * nf)

        self.resnet_4_0 = ResnetBlock(4 * nf, 2 * nf)
        self.resnet_4_1 = ResnetBlock(2 * nf, 2 * nf)

        self.resnet_5_0 = ResnetBlock(2 * nf, 1 * nf)
        self.resnet_5_1 = ResnetBlock(1 * nf, 1 * nf)

        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)
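
        # note: self.embedding, self.fc, and resnet groups 0-2 are defined but
        # never called by the spliced forward below; presumably they are kept
        # so the keys line up with the pretrained generator's state dict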

    def forward(self, x, layer=0):
        out=x
        if layer >= 3:
            print(3)
            out = F.interpolate(out, scale_factor=2)
            out = self.resnet_3_0(out)
            out = self.resnet_3_1(out)
        if layer >= 2:
            print(2)
            out = F.interpolate(out, scale_factor=2)
            out = self.resnet_4_0(out)
            out = self.resnet_4_1(out)
        if layer >= 1:
            print(1)
            out = F.interpolate(out, scale_factor=2)
            out = self.resnet_5_0(out)
            out = self.resnet_5_1(out)
        print(out.shape)
        out = self.conv_img(actvn(out))  # b*64*128*128 -> b*3*128*128
        out = torch.tanh(out)

        return out


class ResnetBlock(nn.Module):
    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        # Attributes
        self.is_bias = is_bias
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        if fhidden is None:
            self.fhidden = min(fin, fout)
        else:
            self.fhidden = fhidden

        # Submodules
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False)

    def forward(self, x):
        x_s = self._shortcut(x)
        dx = self.conv_0(actvn(x))
        dx = self.conv_1(actvn(dx))
        out = x_s + 0.1 * dx

        return out

    def _shortcut(self, x):
        if self.learned_shortcut:
            x_s = self.conv_s(x)
        else:
            x_s = x
        return x_s


def actvn(x):
    out = F.leaky_relu(x, 2e-1)
    return out
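

# Hypothetical sanity check (a sketch, assuming the nf=64, size=128 config
# used below): each value of `layer` expects a feature map with a specific
# channel count and resolution, and the output is always [B, 3, 128, 128].
def check_shape_contract():
    g = Generator(z_dim=256, nlabels=1000, size=128)
    expected = {0: (64, 128), 1: (128, 64), 2: (256, 32), 3: (512, 16)}
    for layer, (ch, res) in expected.items():
        x = torch.randn(1, ch, res, res)
        assert g(x, layer).shape == (1, 3, 128, 128)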


def seed_torch(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_torch(999)

from gan_training import utils
from gan_training.train import Trainer, update_average
from gan_training.toggle_ImageNet import toggle_grad_G, toggle_grad_D, model_equal_part_G, model_equal_part_D
from gan_training.logger import Logger
from gan_training.checkpoints import CheckpointIO
from gan_training.inputs import get_dataset
from gan_training.distributions import get_ydist, get_zdist
from gan_training.eval import Evaluator
from gan_training.config import (
    load_config, build_models, build_optimizers, build_lr_scheduler, build_models_PRE,
)

''' ===================--- Set the training mode ---==========================
DATA: the dataset to train on
DATA_FIX: the dataset of the fixed pre-trained model
G_Layer_FIX, D_Layer_FIX: number of layers to keep fixed
============================================================================='''
DATA = 'CELEBA'
DATA_FIX = 'ImageNet'

config_path = './configs/celebA.yaml'  # you need to change it
main_path = './'  # you need to change it
image_path = "./data/CelebA/"  # you need to change it
image_test = "./data/CelebA/"  # you need to change it
out_path = main_path + '/results_celebA_GmPn/'  # you need to change it
load_dir = './pretrained_model/'  # put the model pretrained on ImageNet here

img_PIL = Image.open('img000000.png')  # load the input image

# convert the image to an np.ndarray
img = np.array(img_PIL)
img = img / 255  # normalize to [0, 1]
img = torch.from_numpy(img).float()  # cast to float32
print(img.shape)  # 128*128*3
img = torch.unsqueeze(img, 0)  # 1*128*128*3
print(img.shape)
img = img.permute(0, 3, 1, 2)  # NHWC -> NCHW: 1*3*128*128
print(img.shape)

model = Generator(256, 1000, 128)  # the newly built model
# (z_dim, nlabels, size)
model_dict = model.state_dict()

for k, v in model_dict.items():
    print('key:', k, 'shape:', v.shape)

pretrained_dict = torch.load(load_dir + DATA_FIX + 'Pre_generator')
# pretrained_dict = torch.load('ImageNetPre_generator')

for k, v in pretrained_dict.items():
    print('key:', k, 'shape:', v.shape)

# keep only the pretrained entries whose keys also exist in the new model
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

model_dict.update(pretrained_dict)  # overwrite those entries with the pretrained weights
model.load_state_dict(model_dict)   # load the merged state dict into the new model
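
# Optional sanity check (hypothetical, for illustration): report how many
# tensors were actually transferred from the pretrained generator
print('transferred %d / %d tensors' % (len(pretrained_dict), len(model_dict)))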

for layer in range(4):
    # layer=0 uses only the conv layer; =1 uses groups 1 and 2; =2 uses groups 1, 2, 3; =3 uses groups 1, 2, 3, 4
    # layer=3: input shape [b, 8*64, 16, 16]
    # layer=2: input shape [b, 4*64, 32, 32]
    # layer=1: input shape [b, 2*64, 64, 64]
    # layer=0: input shape [b, 1*64, 128, 128]
    print('layer:', layer)
    conv = nn.Conv2d(3, 2 ** layer * 64, 3, padding=1)

    img0 = conv(img)
    img0 = F.interpolate(img0, [128 // 2 ** layer, 128 // 2 ** layer])
    print(img0.shape)
    out = model(img0,layer)
    # print(out.shape)

    out = (out + 1) / 2  # map tanh output from [-1, 1] to [0, 1] for save_image
    out = out.squeeze(0)
    # print(out.shape)

    save_image(out, 'pj_test' + str(layer) + '.png')
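
The loop writes pj_test0.png through pj_test3.png, one image per injection depth: the larger the layer value, the more pretrained ResNet groups the feature map passes through before the final conv_img layer.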

