# -*- coding:utf-8 -*-
import torch
import torchvision.datasets as dset
# Load every image under the dataset root directory and write a random
# 80% / 10% / 10% train / test / val split of the image paths to
# train.txt, test.txt and val.txt (one path per line).
dataset = dset.ImageFolder('/home/rane/data/dataset')  # dataset root directory

# dataset.imgs holds (image_path, class_index) samples; copy them into a plain
# list so random_split can index it.
data = list(dataset.imgs)
print(len(data))

# random_split requires the integer lengths to add up to len(data) exactly.
# Truncating three independent fractions with int() can leave up to two
# samples unaccounted for (e.g. len(data) == 9 -> 7 + 0 + 0), which made
# random_split raise ValueError. The validation split absorbs the remainder.
train_size = int(0.8 * len(data))
test_size = int(0.1 * len(data))
val_size = len(data) - train_size - test_size
print("train_size:{}".format(train_size))
print("test_size:{}".format(test_size))
print("val_size:{}".format(val_size))

train_dataset, test_dataset, val_dataset = torch.utils.data.random_split(
    data, [train_size, test_size, val_size])


def _write_image_list(path, subset):
    """Write the image path (element 0) of every sample in *subset* to *path*.

    One path is written per line; the file is overwritten if it exists.

    Args:
        path: output text file name.
        subset: iterable of samples whose first element is the image path.

    Returns:
        The list of image paths that were written, in file order.
    """
    paths = [sample[0] for sample in subset]
    with open(path, 'wt', encoding='utf-8') as f:
        f.writelines(str(p) + '\n' for p in paths)
    return paths


train_imgList_line = _write_image_list('train.txt', train_dataset)
test_imgList_line = _write_image_list('test.txt', test_dataset)
val_imgList_line = _write_image_list('val.txt', val_dataset)