1.首先按照github上的issue,将model.py修改一下。
Line15:
nn.Conv2d(in_channels=args.img_channel, out_channels=64, kernel_size=(3, 3), stride=2, padding=1),
Line30:
nn.ConvTranspose2d(in_channels=64, out_channels=args.img_channel, kernel_size=(3, 3), stride=2, padding=1, output_padding=1))
Line100:
nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
(其实就是把输入输出的通道数修改一下)
2.修改之后会报通道数不匹配问题,于是修改读文件夹的py文件
一开始修改如下
但是仍然报通道数不匹配:
(报的类似错误,只是expected to have 3 channels, but got 1 channels instead,但是我非常确定我放进去的图像是三通道的,所以继续修改代码)
检查之后发现应该将imread读到的数组修改一下,修改代码如下:
Print出来的frame的shape是(3,128,128)
(这一步其实是有问题的,后面会介绍)
3.不会继续报通道不匹配问题了,但是新的问题出现
(忘记保存问题截图,大致意思是说channel数要么是1,要么是3,要么是4,这是最难的问题)
解决问题的关键就是数据格式,我的数据是[N, T, C, H, W](作者把数组扩到了五维,实际每一帧就是后面三个维度组成),但是一张正常的图片应该是[H, W, C],于是把test.py中需要输出的数据格式进行transpose,
short_data = short_data.transpose(dim0=2, dim1=3) # 将(C,H,W)转为(H,W,C)
short_data = short_data.transpose(dim0=3, dim1=4)
out_pred = out_pred.transpose(dim0=2, dim1=3)
out_pred = out_pred.transpose(dim0=3, dim1=4)
(其实两行代码就可以替代,但是最开始没想到)
Imwrite函数也进行修改
从
cv2.imwrite(args.test_result_dir + '/video_' + str(video_i) + '_' + str(frame_start) + '/pred_' +
str(frame_i + args.short_len).zfill(5) + '.jpg',out_pred[batch_i, frame_i, 0, :, :].cpu().numpy() * 255)
修改为:
cv2.imwrite(args.test_result_dir + '/video_' + str(video_i) + '_' + str(frame_start) + '/pred_' +
str(frame_i + args.short_len).zfill(5) + '.jpg',out_pred[batch_i, frame_i, :, :, :].cpu().numpy() * 255)
修改完之后数据格式变为了[N, T, H, W, C],这样就是正常的格式,但是输出的图片很奇怪,虽然成功变为了三通道24位深度图片,但是得到的图片是乱的,得到的图如下:
4.猜测是dataloder的问题,重新定位到读数据集的代码
也就是(2)中的地方,猜测是reshape的问题,查询之后发现果然是reshape,reshape会把数据打乱再reshape,所以得到的图片是错误的,于是不用reshape,改为用transpose
frame = frame.transpose(2,0,1)
也是进行[C,H,W]转换为[H,W,C],但是transpose不会打乱,而是直接进行维度上的抓换。
再进行测试,发现输出的图是正常的。
The end.