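# 可视化-谢韦尔钢铁公司带钢缺陷数据集的标注(标注转图像保存)
# Visualize the Severstal steel-strip defect dataset annotations
# (convert the RLE labels into mask images and save them).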
from pathlib import Path

import numpy as np
import pandas as pd
from torchvision import transforms
def rle2mask(rle, imgshape):
    """Decode a run-length-encoded string into a binary ``uint8`` mask.

    The RLE string is a space-separated list of ``start length`` pairs.
    Pixel positions are 1-based and numbered top-to-bottom, then
    left-to-right (column-major), as in the Severstal ``train.csv``.

    Args:
        rle: RLE string. ``None``, NaN (a missing CSV cell) or an empty /
            whitespace-only string produces an all-zero mask.
        imgshape: ``(rows, cols)`` of the output mask, e.g. ``(256, 1600)``.

    Returns:
        ``np.uint8`` array of shape ``(imgshape[0], imgshape[1])`` holding 1
        on encoded pixels and 0 elsewhere.

    Raises:
        ValueError: if the RLE does not consist of start/length pairs or
            contains a start position below 1.
    """
    rows, cols = imgshape[0], imgshape[1]
    # Flat buffer laid out in the same column-major order as the RLE.
    flat = np.zeros(rows * cols, dtype=np.uint8)

    if rle is None or (isinstance(rle, float) and np.isnan(rle)):
        return flat.reshape(cols, rows).T

    values = np.asarray([int(tok) for tok in rle.split()], dtype=np.int64)
    if values.size % 2:
        raise ValueError(
            f"RLE must contain start/length pairs, got {values.size} values"
        )

    starts = values[0::2] - 1  # RLE pixel positions are 1-based
    lengths = values[1::2]
    if (starts < 0).any():
        raise ValueError("RLE start positions must be >= 1")

    for start, length in zip(starts, lengths):
        flat[start:start + length] = 1

    # Column-major flat index i maps to (row = i % rows, col = i // rows):
    # reshape to (cols, rows) and transpose to get a (rows, cols) mask.
    return flat.reshape(cols, rows).T
# Location of the Severstal training annotations; adjust for your machine.
TRAIN_CSV = r"D:\BaiduPan\钢铁类相关缺陷数据集\谢韦尔钢铁公司带钢缺陷数据集\train.csv"

tr = pd.read_csv(TRAIN_CSV)
print(len(tr))
# A bare ``tr.head()`` only renders inside a notebook; print it for scripts.
print(tr.head())
# Keep only rows whose EncodedPixels cell is present (i.e. has an RLE to decode).
df_train = tr[tr["EncodedPixels"].notna()].reset_index(drop=True)
print(len(df_train))
print(df_train.head())
# Directory that receives one PNG mask per ImageId_ClassId row.
MASK_DIR = Path("masks")
# Opening a viewer window for every row is impractical for thousands of
# masks; switch on only when inspecting a few rows interactively.
SHOW_MASKS = False

MASK_DIR.mkdir(parents=True, exist_ok=True)
to_pil = transforms.ToPILImage()
for image_class_id, encoded in zip(df_train["ImageId_ClassId"], df_train["EncodedPixels"]):
    # Keys are "<image id>_<class id>"; split on the last "_" only.
    image_id, class_id = image_class_id.rsplit("_", 1)
    mask = rle2mask(encoded, (256, 1600))
    # Scale {0, 1} to {0, 255} so the defect region is visible.
    mask_img = to_pil((mask > 0).astype(np.uint8) * 255)
    if SHOW_MASKS:
        mask_img.show()
    # PNG keeps the binary mask exact (JPEG would smear it), and including the
    # class id stops masks for different classes of one image overwriting
    # each other.
    mask_img.save(MASK_DIR / f"{Path(image_id).stem}_{class_id}.png")