Extremely low test accuracy when running image-classification experiments on the TinyImageNet dataset

This post describes a problem encountered with the TinyImageNet dataset: loading the val split directly with `datasets.ImageFolder` produces extremely low accuracy, and the test split ships without labels. The fix is a custom `TinyImageNet_load` class that handles the dataset correctly: it reads the train and val directories, builds class-index dictionaries, and provides the corresponding data-loading methods.


Cause of the error:

The TinyImageNet val split cannot be loaded directly with datasets.ImageFolder; doing so collapses accuracy to a fraction of a percent. In addition, the test split carries no annotations at all.

Incorrect example:

trainset = datasets.ImageFolder(root=os.path.join(data_dir, data, 'tiny-imagenet-200/train'), transform=transform_train)
testset = datasets.ImageFolder(root=os.path.join(data_dir, data, 'tiny-imagenet-200/test'), transform=transform_test)
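
The failure mode is easy to see: val/ contains only a single subdirectory named images (plus val_annotations.txt), so ImageFolder treats it as one pseudo-class and assigns label 0 to every image. A quick check, assuming the standard tiny-imagenet-200 layout:

from torchvision import datasets

valset = datasets.ImageFolder(root='tiny-imagenet-200/val')
print(valset.classes)       # ['images'] -- a single pseudo-class
print(set(valset.targets))  # {0} -- every image gets the same label

With every prediction compared against a constant label, the reported accuracy is meaningless and drops to near-chance levels.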

Solution:

The following pre-existing code is used:

from torch.utils.data import Dataset, DataLoader
from torchvision import models, utils, datasets, transforms
import numpy as np
import sys
import os
from PIL import Image


class TinyImageNet_load(Dataset):
    """Tiny ImageNet loader. For the val split, labels are read from
    val_annotations.txt instead of being inferred from the directory layout."""

    def __init__(self, root, train=True, transform=None):
        self.Train = train
        self.root_dir = root
        self.transform = transform
        self.train_dir = os.path.join(self.root_dir, "train")
        self.val_dir = os.path.join(self.root_dir, "val")

        # Build the class <-> index dictionaries for the requested split
        if self.Train:
            self._create_class_idx_dict_train()
        else:
            self._create_class_idx_dict_val()

        self._make_dataset(self.Train)

        words_file = os.path.join(self.root_dir, "words.txt")
        wnids_file = os.path.join(self.root_dir, "wnids.txt")

        # wnids.txt lists the 200 WordNet IDs used by Tiny ImageNet
        self.set_nids = set()

        with open(wnids_file, 'r') as fo:
            data = fo.readlines()
            for entry in data:
                self.set_nids.add(entry.strip("\n"))

        # words.txt maps each WordNet ID to human-readable names;
        # keep only the first name for each ID that Tiny ImageNet uses
        self.class_to_label = {}
        with open(words_file, 'r') as fo:
            data = fo.readlines()
            for entry in data:
                words = entry.split("\t")
                if words[0] in self.set_nids:
                    self.class_to_label[words[0]] = (words[1].strip("\n").split(","))[0]

    def _create_class_idx_dict_train(self):
        # Each subdirectory of train/ is one WordNet class ID
        if sys.version_info >= (3, 5):
            classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(self.train_dir) if os.path.isdir(os.path.join(self.train_dir, d))]
        classes = sorted(classes)
        num_images = 0
        for root, dirs, files in os.walk(self.train_dir):
            for f in files:
                if f.endswith(".JPEG"):
                    num_images = num_images + 1

        self.len_dataset = num_images

        self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))}
        self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))}

    def _create_class_idx_dict_val(self):
        val_image_dir = os.path.join(self.val_dir, "images")
        if sys.version_info >= (3, 5):
            images = [d.name for d in os.scandir(val_image_dir) if d.is_file()]
        else:
            images = [d for d in os.listdir(val_image_dir) if os.path.isfile(os.path.join(val_image_dir, d))]
        val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
        # Each line of val_annotations.txt: <filename>\t<wnid>\t<x>\t<y>\t<w>\t<h>
        self.val_img_to_class = {}
        set_of_classes = set()
        with open(val_annotations_file, 'r') as fo:
            entry = fo.readlines()
            for data in entry:
                words = data.split("\t")
                self.val_img_to_class[words[0]] = words[1]
                set_of_classes.add(words[1])

        self.len_dataset = len(list(self.val_img_to_class.keys()))
        classes = sorted(list(set_of_classes))
        self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))}
        self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))}

    def _make_dataset(self, Train=True):
        # Build the list of (image_path, class_index) samples for the split
        self.images = []
        if Train:
            img_root_dir = self.train_dir
            list_of_dirs = [target for target in self.class_to_tgt_idx.keys()]
        else:
            # All val images live in a single "images" directory
            img_root_dir = self.val_dir
            list_of_dirs = ["images"]

        for tgt in list_of_dirs:
            dirs = os.path.join(img_root_dir, tgt)
            if not os.path.isdir(dirs):
                continue

            for root, _, files in sorted(os.walk(dirs)):
                for fname in sorted(files):
                    if fname.endswith(".JPEG"):
                        path = os.path.join(root, fname)
                        if Train:
                            item = (path, self.class_to_tgt_idx[tgt])
                        else:
                            # For val, look the label up via val_annotations.txt
                            item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]])
                        self.images.append(item)

    def return_label(self, idx):
        # Map a batch of target indices back to human-readable labels
        return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx]

    def __len__(self):
        return self.len_dataset

    def __getitem__(self, idx):
        img_path, tgt = self.images[idx]
        with open(img_path, 'rb') as f:
            sample = Image.open(f).convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, tgt

Usage example:

# train/ follows the standard ImageFolder layout, so ImageFolder works there
trainset = datasets.ImageFolder(root=os.path.join(data_dir, data, 'tiny-imagenet-200/train'), transform=transform_train)

# val/ needs the custom loader (imported here as the TinyImageNet module)
testset = TinyImageNet.TinyImageNet_load('./TINY/tiny-imagenet-200/', train=False, transform=transform_test)
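
Wrapped in a DataLoader, the correctly labeled val split can then stand in for the unlabeled test split when measuring accuracy. Below is a minimal evaluation sketch; `model`, the batch size, and the evaluation loop are assumptions for illustration, not part of the original post:

import torch
from torch.utils.data import DataLoader

testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

correct, total = 0, 0
model.eval()  # `model` is assumed to be an already trained classifier
with torch.no_grad():
    for images, targets in testloader:
        outputs = model(images)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
print('val accuracy: %.2f%%' % (100.0 * correct / total))

With the labels read from val_annotations.txt, the reported accuracy now reflects the model's real performance rather than the near-chance numbers produced by the ImageFolder approach.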
Tiny ImageNet dataset overview:

Tiny ImageNet is an image-classification dataset provided by Stanford University. It contains 200 classes, each with 500 training images, 50 validation images, and 50 test images. It is intended as a smaller version of the large-scale ImageNet dataset, suitable for research and study under limited resources.

Characteristics of the train, val, and test splits:

- Train: follows the standard ImageFolder layout, with each class's images in its own subdirectory.
- Val: all images sit in a single directory, and a label file named val_annotations.txt assigns each image to its class. Because this differs from the usual one-folder-per-class layout, applying PyTorch's datasets.ImageFolder directly yields wrong label information and distorts any evaluation results.
- Test: ships without labels; in practice it is only used for submitting predictions to the official evaluation platform. During local development the val split can take its place for hyperparameter tuning and similar workflow steps.

To handle these differences between the splits correctly, the following Python snippet shows one way to read and preprocess the val data:

import glob
import os

import pandas as pd
from torchvision.datasets.folder import default_loader


class CustomDataset(object):
    def __init__(self, root_dir='path/to/dataset', transform=None):
        self.root = root_dir
        self.transform = transform

        # val_annotations.txt has six tab-separated columns; only the
        # filename and the WordNet ID (label) are needed here
        val_anno_file = os.path.join(self.root, 'val', 'val_annotations.txt')
        df_val = pd.read_csv(val_anno_file, sep='\t', header=None,
                             names=['filename', 'label', 'x', 'y', 'w', 'h'])

        # Build a filename -> label dictionary for fast lookup
        label_map = dict(zip(df_val['filename'], df_val['label']))

        images_paths = []
        labels = []
        # Tiny ImageNet images use the .JPEG extension
        for f in glob.glob(os.path.join(root_dir, 'val', 'images', '*.JPEG')):
            fname = os.path.basename(f)
            if fname in label_map:
                images_paths.append(f)
                labels.append(label_map[fname])

        self.samples = list(zip(images_paths, labels))

    def __getitem__(self, index):
        path, label = self.samples[index]
        img = default_loader(path)
        if self.transform is not None:
            img = self.transform(img)
        # Note: label is the raw WordNet ID string, not an integer index
        return img, label

    def __len__(self):
        return len(self.samples)


dataset = CustomDataset()
print('Total samples:', len(dataset))

This script defines a custom data loader, CustomDataset, which parses val_annotations.txt and assigns the correct label to every validation image. The other splits (such as train) can simply be loaded in the default way.
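
One caveat: CustomDataset returns WordNet ID strings as labels, so for training or evaluation with, say, CrossEntropyLoss they still have to be converted to integer targets. A minimal sketch of such a mapping follows; sorting alphabetically is an assumption that matches ImageFolder's ordering, so train and val indices stay consistent, and the path and example ID are hypothetical:

import os

# Rebuild the alphabetical class -> index mapping that ImageFolder uses
# for the train split, so val targets line up with train targets
train_dir = './TINY/tiny-imagenet-200/train'  # hypothetical path
classes = sorted(d for d in os.listdir(train_dir)
                 if os.path.isdir(os.path.join(train_dir, d)))
class_to_idx = {c: i for i, c in enumerate(classes)}

wnid = 'n01443537'  # example WordNet ID; must be one of the 200 classes
target = class_to_idx[wnid]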