由于后续训练和预测用到的标签格式需要是一维数据,而DermaMNIST类读取到的DermaMNIST标签数据是二维数据,所以需要采取措施让标签数据变为一维的。
train_dataset = DermaMNIST(split="train", transform=transform_train, download=True, size=224, root = '../../dermamnist')
val_dataset = DermaMNIST(split="val", transform=transform_val, download=True, size=224, root = '../../dermamnist')
# 将标签转换为一维数组
train_dataset.labels = train_dataset.labels.reshape(-1)
val_dataset.labels = val_dataset.labels.reshape(-1)