错误一
# Fails on Windows: the second path component starts with '{:', which the
# Windows path module (ntpath) parses as a drive spec like 'C:', so
# args.checkpoint_dir is dropped and os.path.dirname() of the result is '{:'.
# (ospj is presumably os.path.join; its import is not shown here.)
CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), data_parallel=True, **self.nets)
报错信息为:FileNotFoundError: [WinError 3] 系统找不到指定的路径。: '{:'
报错代码
# Failing version: each template starts with '{:', which os.path.join on
# Windows (ntpath) treats as a drive spec. The checkpoint directory is then
# discarded, and CheckpointIO ends up calling os.makedirs('{:').
# (ospj is presumably os.path.join; its import is not shown here.)
self.ckptios = [
CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), data_parallel=True, **self.nets),
CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), data_parallel=True, **self.nets_ema),
CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'), **self.optims)]
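查看源代码后发现,ospj(即 os.path.join)的正常输出应为 args.checkpoint_dir/{:06d}_nets.ckpt,但实际输出为 {:06d}_nets.ckpt。原因是在 Windows 上,os.path.join(即 ntpath.join)会把第二个参数开头的 "{:" 当作盘符(类似 "C:"),于是丢弃了前面的 args.checkpoint_dir。这样 os.path.dirname(fname_template) 得到的是 "{:",导致 CheckpointIO 中的 os.makedirs(os.path.dirname(fname_template), exist_ok=True) 运行出错,无法创建指定文件夹。这也是报错信息末尾出现 '{:' 的原因。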
class CheckpointIO(object):
    """Hold a checkpoint file-name template together with the objects to save.

    Args:
        fname_template: path template for checkpoint files; presumably
            formatted with a step number (e.g. '{:06d}') by code not shown
            in this snippet.
        data_parallel: flag stored on the instance for later use.
        **kwargs: named objects to checkpoint, kept in ``module_dict``.
    """

    def __init__(self, fname_template, data_parallel=False, **kwargs):
        # Create the checkpoint directory up front. A template without a
        # directory part has dirname '' and os.makedirs('') raises, so only
        # create a directory when there actually is one.
        # NOTE: on Windows a template built as os.path.join(dir, '{:06d}...')
        # has dirname '{:' (ntpath reads '{:' as a drive), which still fails
        # here; the caller must not start a joined component with '{:'.
        ckpt_dir = os.path.dirname(fname_template)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        self.fname_template = fname_template
        self.module_dict = kwargs
        self.data_parallel = data_parallel
改正一
# Keep the upstream checkpoint file names ('{:06d}_nets.ckpt', ...) so that
# checkpoints already saved under that naming scheme can still be loaded.
# The template is appended to the directory instead of being passed to ospj
# as its own component: on Windows, ntpath parses a leading '{:' as a drive
# and would silently drop args.checkpoint_dir. ospj(dir, '') just adds a
# trailing separator when one is missing.
ckpt_dir = ospj(args.checkpoint_dir, '')
self.ckptios = [
    CheckpointIO(ckpt_dir + '{:06d}_nets.ckpt', data_parallel=True, **self.nets),
    CheckpointIO(ckpt_dir + '{:06d}_nets_ema.ckpt', data_parallel=True, **self.nets_ema),
    CheckpointIO(ckpt_dir + '{:06d}_optims.ckpt', **self.optims)]
错误解决
报错二
报错信息为x, y = next(self.iter),AttributeError: 'InputFetcher' object has no attribute 'iter'
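表面上看,这是因为没有在 __init__ 中初始化 self.iter。但如果原代码的 _fetch_inputs 与下面的代码一样用 except (AttributeError, StopIteration) 捕获了这个异常,它本身并不会让程序中断。这条信息通常出现在 "During handling of the above exception, another exception occurred" 之前,真正导致崩溃的往往是随后 iter(self.loader) 抛出的异常,例如 Windows 下 DataLoader 使用多进程(num_workers > 0)时的报错三(无法 pickle lambda)。因此遇到此报错时,请查看完整的异常链,并同时参考下文的报错三。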
改正二
在data_loader.py中的InputFetcher类修改为
class InputFetcher:
    """Endless batch fetcher that restarts the wrapped loaders when exhausted.

    Args:
        loader: iterable yielding ``(x, y)`` pairs.
        loader_ref: iterable yielding ``(x, x2, y)`` triples; only read in
            ``'train'`` mode.
        latent_dim: width of the random latent codes drawn in ``'train'`` mode.
        mode: ``'train'``, ``'val'`` or ``'test'``; any other value makes
            ``__next__`` raise ``NotImplementedError``.
    """

    def __init__(self, loader, loader_ref=None, latent_dim=16, mode=''):
        self.loader = loader
        self.loader_ref = loader_ref
        self.latent_dim = latent_dim
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.mode = mode
        # Create the main iterator eagerly so the attribute always exists.
        self.iter = iter(self.loader)
        # The reference iterator is created lazily on first use, because
        # loader_ref may be None outside 'train' mode.
        self.iter_ref = None

    def _fetch_inputs(self):
        """Return the next ``(x, y)`` batch, restarting ``loader`` when exhausted."""
        try:
            x, y = next(self.iter)
        except StopIteration:
            # Only exhaustion triggers a restart; an AttributeError raised by
            # the loader itself now propagates instead of being swallowed.
            self.iter = iter(self.loader)
            x, y = next(self.iter)
        return x, y

    def _fetch_refs(self):
        """Return the next ``(x, x2, y)`` reference batch, restarting when exhausted."""
        if self.iter_ref is None:
            self.iter_ref = iter(self.loader_ref)
        try:
            x, x2, y = next(self.iter_ref)
        except StopIteration:
            self.iter_ref = iter(self.loader_ref)
            x, x2, y = next(self.iter_ref)
        return x, x2, y

    def __next__(self):
        """Fetch one batch, package it according to ``mode``, move it to ``device``.

        Assumes every packaged value supports ``.to(device)`` (e.g. tensors).
        """
        x, y = self._fetch_inputs()
        if self.mode == 'train':
            x_ref, x_ref2, y_ref = self._fetch_refs()
            z_trg = torch.randn(x.size(0), self.latent_dim)
            z_trg2 = torch.randn(x.size(0), self.latent_dim)
            inputs = Munch(x_src=x, y_src=y, y_ref=y_ref,
                           x_ref=x_ref, x_ref2=x_ref2,
                           z_trg=z_trg, z_trg2=z_trg2)
        elif self.mode == 'val':
            # The reference batch is simply a second batch from the same loader.
            x_ref, y_ref = self._fetch_inputs()
            inputs = Munch(x_src=x, y_src=y,
                           x_ref=x_ref, y_ref=y_ref)
        elif self.mode == 'test':
            inputs = Munch(x=x, y=y)
        else:
            raise NotImplementedError
        return Munch({k: v.to(self.device)
                      for k, v in inputs.items()})

    def __iter__(self):
        # Makes the fetcher usable in for-in loops. __next__ restarts the
        # loaders instead of raising StopIteration, so such a loop normally
        # does not end on its own.
        return self
问题解决
报错三
AttributeError: Can't pickle local object 'get_train_loader.<locals>.<lambda>'
参考如下博文
https://blog.csdn.net/genous110/article/details/115474244
改正三
在 data_loader.py 中 get_train_loader() 函数的外部(模块顶层)添加如下代码。注意不要把它们定义在 get_train_loader() 内部,否则它们同样是"局部对象",仍然无法被 pickle:
# Module-level replacement for the lambda: a plain top-level function can be
# pickled by reference, unlike a lambda defined inside get_train_loader().
def random_crop_transform(x, crop_fn, probability):
    """Return ``crop_fn(x)`` with the given probability, otherwise ``x`` unchanged."""
    should_crop = random.random() < probability
    return crop_fn(x) if should_crop else x
# Callable object wrapping random_crop_transform so it can be used as a
# transform in place of the original lambda.
class RandomCropTransform:
    """Apply ``crop_fn`` to its input with probability ``probability``."""

    def __init__(self, crop_fn, probability):
        self.crop_fn = crop_fn
        self.probability = probability

    def __call__(self, x):
        return random_crop_transform(
            x,
            crop_fn=self.crop_fn,
            probability=self.probability,
        )
然后将get_train_loader()函数中的rand_crop变量修改为
# Use a picklable callable instead of the local lambda, which failed with
# "Can't pickle local object" (presumably when DataLoader worker processes
# are started). crop and prob come from the enclosing get_train_loader().
rand_crop = RandomCropTransform(crop, prob)
问题解决