读取tif文件打印格式不对torch.Size([1, 256, 3, 256]),
torch.Size([1, 256, 3, 256]) torch.Size([1, 256, 3, 256]) tensor(0.) tensor(1.) tensor(0.) tensor(1.)
正常的应该是torch.Size([1, 3, 256, 256]) ,需要进行一下转换,转换为RGB格式:
注:原始文件是256×256
if image_fp.endswith('.tif') or image_fp.endswith('.tiff'):
X, Y = rio.open(image_fp).read(), rio.open(label_fp).read()
# 转换为RGB格式
X = Image.fromarray(X.transpose(1, 2, 0)) # 转置通道顺序
X = X.convert('RGB')
# 转换为L模式
Y = Image.fromarray(Y[0]) # 使用第一个通道
Y = Y.convert('L')
X, Y = np.array(X) / 255.0, np.array(Y) / 255.0 # 因为to_tensor接受类型为float所以为255.0
flag = 'remote'
修改结果
torch.Size([1, 3, 256, 256]) torch.Size([1, 1, 256, 256]) tensor(0.) tensor(1.) tensor(0.) tensor(1.)
torch.Size([1, 3, 256, 256]) torch.Size([1, 1, 256, 256]) tensor(0.) tensor(1.) tensor(0.) tensor(1.)
torch.Size([1, 3, 256, 256]) torch.Size([1, 1, 256, 256]) tensor(0.) tensor(1.) tensor(0.) tensor(1.)