# Export the ground-truth sample images stored in a training checkpoint as
# PNG files, one file per image in each sample batch.
import os

import torch
from imageio import imwrite

path = '/1T/ysh/sg2im-master/models/checkpoint_with_model.pt'
checkpoint = torch.load(path)
# Original author's shape note: list 100 4 32*3*64*64 -- TODO confirm.
train_samples = checkpoint['train_samples']
val_samples = checkpoint['val_samples']
# tdimen = np.array(train_samples).shape
# print('train dim', tdimen)

# Save the generated images
output_dir = '/1T/ysh/sg2im-master/outputs/samples/train'
gt_img_dir = os.path.join(output_dir, 'gt_img')
gbm_dir = os.path.join(output_dir, 'gt_box_gt_mask')
# imwrite does not create missing directories, so make sure they exist.
os.makedirs(gt_img_dir, exist_ok=True)
os.makedirs(gbm_dir, exist_ok=True)

# Only the first sample entry is exported; the author's alternative was
# `for i in range(train_samples.shape[0])` (noted as 100 entries).
for i in range(1):
    gt_img = train_samples[i]['gt_img']
    gt_box_gt_mask = train_samples[i]['gt_box_gt_mask']

    for j in range(gt_img.shape[0]):  # batch size 32 (author's note)
        # imwrite needs the (H, W, C) layout, the tensors are (C, H, W).
        gt_imgs = gt_img[j].numpy().transpose(1, 2, 0)
        img_path1 = os.path.join(gt_img_dir, 'img%d_%d.png' % (i, j))
        imwrite(img_path1, gt_imgs)

        gt_box_gt_masks = gt_box_gt_mask[j].numpy().transpose(1, 2, 0)
        img_path2 = os.path.join(gbm_dir, 'img%d_%d.png' % (i, j))
        imwrite(img_path2, gt_box_gt_masks)
# --- Stitching (second script) ---
import argparse
import os, sys
from os import path
import time
import copy
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import random
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torch.utils.data
import torch.utils.data.distributed
# from torchsummary import summaryimport shutil
import scipy.io as sio
class Generator(nn.Module):
    """ResNet-style image generator whose forward pass runs only its tail.

    The constructor builds the full stack (embedding, fc, resnet_0_* ..
    resnet_5_*, conv_img). ``forward`` uses only resnet_3_* .. resnet_5_*
    and conv_img; the remaining submodules are never called here --
    presumably they are kept so the parameter names line up with a
    pretrained checkpoint (confirm against the loading code).

    Args:
        z_dim: Size of the latent vector (stored, used to size ``fc``).
        nlabels: Number of class labels for the embedding table.
        size: Output image side length; ``size // 32`` is stored as ``s0``.
        embed_size: Width of the label embedding.
        nfilter: Base channel count ``nf``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64,
                 **kwargs):
        super().__init__()
        s0 = self.s0 = size // 32
        nf = self.nf = nfilter
        self.z_dim = z_dim

        # Submodules
        self.embedding = nn.Embedding(nlabels, embed_size)
        self.fc = nn.Linear(z_dim + embed_size, 16 * nf * s0 * s0)

        self.resnet_0_0 = ResnetBlock(16 * nf, 16 * nf)
        self.resnet_0_1 = ResnetBlock(16 * nf, 16 * nf)

        self.resnet_1_0 = ResnetBlock(16 * nf, 16 * nf)
        self.resnet_1_1 = ResnetBlock(16 * nf, 16 * nf)

        self.resnet_2_0 = ResnetBlock(16 * nf, 8 * nf)
        self.resnet_2_1 = ResnetBlock(8 * nf, 8 * nf)

        self.resnet_3_0 = ResnetBlock(8 * nf, 4 * nf)
        self.resnet_3_1 = ResnetBlock(4 * nf, 4 * nf)

        self.resnet_4_0 = ResnetBlock(4 * nf, 2 * nf)
        self.resnet_4_1 = ResnetBlock(2 * nf, 2 * nf)

        self.resnet_5_0 = ResnetBlock(2 * nf, 1 * nf)
        self.resnet_5_1 = ResnetBlock(1 * nf, 1 * nf)

        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)

    def forward(self, x, layer=0):
        """Run ``x`` through the last ``layer`` upsampling stages.

        Args:
            x: Feature map whose channel count must match the first stage
                that is executed: ``8 * nf`` for ``layer >= 3``, ``4 * nf``
                for ``layer == 2``, ``2 * nf`` for ``layer == 1`` and
                ``nf`` for ``layer == 0``.
            layer: How many of the final stages (resnet_3, resnet_4,
                resnet_5) to apply before the output convolution.

        Returns:
            A 3-channel tensor squashed into (-1, 1) by ``tanh``.
        """
        out = x

        # Each enabled stage doubles the spatial size, then applies two
        # residual blocks. The prints are the author's progress traces.
        if layer >= 3:
            print(3)
            out = F.interpolate(out, scale_factor=2)
            out = self.resnet_3_0(out)
            out = self.resnet_3_1(out)

        if layer >= 2:
            print(2)
            out = F.interpolate(out, scale_factor=2)
            out = self.resnet_4_0(out)
            out = self.resnet_4_1(out)

        if layer >= 1:
            print(1)
            out = F.interpolate(out, scale_factor=2)
            out = self.resnet_5_0(out)
            out = self.resnet_5_1(out)

        print(out.shape)
        # Author's shape note for the conv input: b*64*128*128.
        out = self.conv_img(actvn(out))
        out = torch.tanh(out)
        return out
class ResnetBlock(nn.Module):
    """Pre-activation residual block with a damped residual branch.

    The block computes ``shortcut(x) + 0.1 * conv_1(act(conv_0(act(x))))``.

    Args:
        fin: Number of input channels.
        fout: Number of output channels.
        fhidden: Channels between the two convolutions; defaults to
            ``min(fin, fout)`` when ``None``.
        is_bias: Whether ``conv_1`` has a bias term (``conv_0`` always
            uses the ``nn.Conv2d`` default).
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        # Attributes
        self.is_bias = is_bias
        # A 1x1 projection is only needed when the channel count changes.
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        if fhidden is None:
            self.fhidden = min(fin, fout)
        else:
            self.fhidden = fhidden

        # Submodules
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3,
                                stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3,
                                stride=1, padding=1, bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout, 1,
                                    stride=1, padding=0, bias=False)

    def forward(self, x):
        """Return the shortcut of ``x`` plus 0.1 times the residual branch."""
        x_s = self._shortcut(x)
        dx = self.conv_0(actvn(x))
        dx = self.conv_1(actvn(dx))
        out = x_s + 0.1 * dx
        return out

    def _shortcut(self, x):
        """Project ``x`` with ``conv_s`` if channels change, else return it."""
        if self.learned_shortcut:
            x_s = self.conv_s(x)
        else:
            x_s = x
        return x_s
def actvn(x):
    """Apply a leaky ReLU with negative slope 0.2 to ``x``."""
    out = F.leaky_relu(x, 2e-1)
    return out
def seed_torch(seed=1029):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this only
    # influences processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Fix every RNG before any model or data code runs.
seed_torch(999)

from gan_training import utils
from gan_training.train import Trainer, update_average
from gan_training.toggle_ImageNet import (
    toggle_grad_G, toggle_grad_D, model_equal_part_G, model_equal_part_D,
)
from gan_training.logger import Logger
from gan_training.checkpoints import CheckpointIO
from gan_training.inputs import get_dataset
from gan_training.distributions import get_ydist, get_zdist
from gan_training.eval import Evaluator
from gan_training.config import (
    load_config, build_models, build_optimizers, build_lr_scheduler,
    build_models_PRE,
)

''' ===================--- Set the traning mode ---==========================
DATA: going to train
DATA_FIX: used as a fixed pre-trained model
G_Layer_FIX, D_Layer_FIX: number of layers to fix
============================================================================='''
DATA = 'CELEBA'
DATA_FIX = 'ImageNet'

config_path = './configs/celebA.yaml'  # you need to change it
main_path = './'  # you need to change it
image_path = "./data/CelebA/"  # you need to change it
image_test = "./data/CelebA/"  # you need to change it
out_path = main_path + '/results_celebA_GmPn/'  # you need to change it
load_dir = './pretrained_model/'  # put the model pretrained on ImageNet here
# Push one image through the last `layer` stages of the pretrained generator
# and save the result for layer = 0..3.

# Read the image; force 3 channels so RGBA / greyscale files still match the
# 3-channel convolution below.
img_PIL = Image.open('img000000.png').convert('RGB')
# Convert the image to an np.ndarray.
img = np.array(img_PIL)
img = img / 255  # normalise pixel values to [0, 1]
# Convert to a float32 tensor (avoids the torch.tensor(tensor) copy warning).
img = torch.from_numpy(img).to(torch.float32)
print(img.shape)  # 128*128*3 for the author's 128x128 input
img = torch.unsqueeze(img, 0)  # 1*128*128*3
print(img.shape)
# NHWC -> NCHW, same result as transpose(1, 3) followed by transpose(2, 3).
img = img.permute(0, 3, 1, 2)  # 1*3*128*128
print(img.shape)

model = Generator(256, 1000, 128)  # the new model: z_dim, nlabels, size
model_dict = model.state_dict()
for k, v in model_dict.items():
    print('key:', k, 'value:', v)

# The model stays on the CPU, so map the checkpoint there as well; this also
# lets a checkpoint saved on a GPU load on a CPU-only machine.
pretrained_dict = torch.load(load_dir + DATA_FIX + 'Pre_generator',
                             map_location='cpu')
# pretrained_dict = torch.load('ImageNetPre_generator')
for k, v in pretrained_dict.items():
    print('key:', k, 'value:', v)

# Keep only the pretrained entries whose keys also exist in the new model.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Overwrite the matching parameters of the new model with pretrained ones.
model_dict.update(pretrained_dict)
# Load the merged state dict into the new model.
model.load_state_dict(model_dict)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for layer in range(4):
        # Author's notes:
        # layer=0 uses only the conv layer; =1 uses groups 1 and 2;
        # =2 uses groups 1, 2, 3; =3 uses groups 1, 2, 3, 4.
        # =3, input shape = [b, 8*64, 16, 16]
        # =2, input shape = [b, 4*64, 32, 32]
        # =1, input shape = [b, 2*64, 64, 64]
        # =0, input shape = [b, 1*64, 128, 128]
        print('layer:', layer)
        # Randomly initialised conv lifting RGB to the channel count that
        # the chosen entry stage of the generator expects.
        conv = nn.Conv2d(3, 2 ** layer * 64, 3, padding=1)
        img0 = conv(img)
        img0 = F.interpolate(img0, [128 // 2 ** layer, 128 // 2 ** layer])
        print(img0.shape)
        out = model(img0, layer)
        # print(out.shape)
        # The generator ends in tanh, i.e. values in [-1, 1]. save_image
        # multiplies by 255 itself and expects [0, 1]; the former
        # `out * 255` saturated almost every pixel after its clamp.
        out = (out + 1) / 2
        out = out.squeeze(0)
        # print(out.shape)
        save_image(out, 'pj_test' + str(layer) + '.png')