代码属于简单易上手代码,环境安装好即可直接使用。
整个代码如下:
01demo是直接训练使用的,
02demo是用于qt界面的,然后调用01训练得到的模型进行可视化。
运行也是有说明文档的。
01train_cyclegan.py的前几行是配置参数,我们可以自行修改
import torch
import sys
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
import random, torch, os, numpy as np
import torch.nn as nn
import copy
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np
#配置参数
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "data/train"
VAL_DIR = "data/val"
BATCH_SIZE = 1
LEARNING_RATE = 1e-5