1.数据集
自建数据集:1300张左右,训练集1100张左右,测试集200张。
数据处理
由于数据集大小不同,有正方形的有长方形,小的几百k,大的10+M。
所以对数据集进行resize统一处理为512x512,256x256的细节丢失太多,生成效果也差。
我倾向于保留保证的图像,不裁剪不变形,所以选择了填充法。
(以下参考自知乎https://www.zhihu.com/question/360010590/answer/1697449851)
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
由于我的数据集是带透明背景的四通道图像,生成时会会自动转为黑色边缘处理的不好,我更希望是将透明背景改为白色背景。所以在以上代码中加入了RGBA转为RGB的代码。
# -*- coding = utf-8 -*-
# @Author: KKeria
# @Time: 2023/9/27 15:11
# @File: alpha2rgb.py
# @Software: PyCharm
from PIL import Image, ImageChops
import os
import numpy as np
import cv2
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
# 设置输入和输出文件夹
input_folder = '/Users/源码/pytorch-CycleGAN-and-pix2pix-master/datasets/pattern2RE/trainB'
output_folder = '/Users/源码/pytorch-CycleGAN-and-pix2pix-master/datasets/pattern2RE/trainB1'
# 确保输出文件夹存在
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# 定义目标大小
target_size = (512, 512)
filename = sorted([file for file in os.listdir(input_folder) if file.endswith('.png')])
# 批量处理图像
for filename in os.listdir(input_folder):
if filename.endswith('.png') and not filename.startswith('._'):
input_path = os.path.join(input_folder, filename)
output_path = os.path.join(output_folder, filename)
# 打开图像
with Image.open(input_path) as img:
# 将透明像素替换为白色像素
# 创建一个白色背景图像
background = Image.new('RGBA', img.size, (255, 255, 255, 255))
# 使用ImageChops模块中的函数将图像放置在白色背景上
img = ImageChops.composite(img, background, img)
# 将图像转换为RGB模式
expand2square(img, (255, 255, 255)).resize(target_size, Image.LANCZOS).save(output_path)
未更改尺寸与背景前的图像
更改尺寸与背景后的图像
补充:两个文件夹图像同时重命名
import os
class BatchRename():
def __init__(self):
# 我的图片文件夹路径
self.path1 = '/Users/源码/pytorch-CycleGAN-and-pix2pix-master/datasets/pattern2RE/trainA的副本'
self.path2 = '/Users/源码/pytorch-CycleGAN-and-pix2pix-master/datasets/pattern2RE/trainB的副本'
def rename(self):
filelist = os.listdir(self.path1)
total_num = len(filelist)
print(total_num)
i = 0 # 图片编号从多少开始
for item in filelist:
src1 = os.path.join(os.path.abspath(self.path1), item)
dst1 = os.path.join(os.path.abspath(self.path1), 'p_' + str(i) + '.png')
src2 = os.path.join(os.path.abspath(self.path2), item)
dst2 = os.path.join(os.path.abspath(self.path2), 'p_' + str(i) + '.png')
try:
os.rename(src1, dst1)
os.rename(src2, dst2)
print('converting %s to %s ...' % (src1, dst1))
i = i + 1
except:
continue
print("total %d to rename & converted %d pngs" % (total_num, i))
if __name__ == '__main__':
demo = BatchRename()
demo.rename()
最后将数据集分为trainA、trainB、testA、testB放在项目文件的datasets文件夹中
2.训练自己的数据集
云服务器
可以用矩池云或者misgpu
如果要跑512x512的图像,数量级差不多的,建议用矩池云的3090,要24G显存起,16G的显存不够用。
镜像选择pytorch1.11后直接下单
租用后打开JupyterLab
双击terminal后打开mnt文件夹
cd到mnt
克隆项目:
git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git
克隆成功后左边会生成对应的文件夹
cd到项目文件夹
安装包
pip install dominate visdom wandb
更改尺寸参数
找到options文件夹中的base_options.py文件,修改以下三个参数为512
开始训练数据集
python train.py --dataroot ./datasets/数据集的文件夹名字 --name 预训练模型或放预训练数据文件夹的名字 --model cycle_gan --display_id -1
示例:
python train.py --dataroot ./datasets/Pattern2RE --name p2re --model cycle_gan --display_id -1
(display_id是指不启用visdom的可视化功能,本地跑的不加这一项)
训练时间: 一共200个epoch,batchsize=1,总耗时接近24小时。
3.测试
dataroot换为自己testA的路径,name与train的name参数保持一致
python test.py --dataroot datasets/pattern2RE/testA --name p2re --model test --no_dropout
如果提示找不到latest_net_G,则还需要把checkpoint中上面的name参数指向的文件夹里的latest_net_G_A.pth重命名为lastest_net_G.pth。
完美,训练结果在checkpoints里看,测试结果在results里查看。