PyTorch
导入需要的库
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
数据增强函数
# Per-channel normalization applied as the last preprocessing step.
# NOTE(review): 0.4 / 0.2 look like placeholder statistics; consider computing
# the real per-channel mean/std of your dataset (or using ImageNet's).
normalize = T.Normalize(mean=[0.4] * 3, std=[0.2] * 3)

# Training-time preprocessing: random augmentation, tensor conversion, normalization.
transform1 = T.Compose(
    [
        T.RandomResizedCrop(224),  # random crop, resized to 224x224
        T.RandomHorizontalFlip(),  # random left-right flip (default p=0.5)
        T.ToTensor(),  # PIL image -> C x H x W float tensor
        normalize,  # (x - mean) / std, channel by channel
    ]
)
导入数据并划分数据集
# Load the images and split them 80/20 into training and test sets.
path = '自己的路径'  # TODO: set to the root directory of your dataset (one sub-folder per class)
dataset = ImageFolder(path, transform=transform1)
train_size = int(0.8 * len(dataset))  # number of training samples
test_size = len(dataset) - train_size  # the remainder becomes the test set
# A seeded generator makes the split reproducible. Re-running the script (e.g. to
# evaluate a saved model) then cannot move training images into the test set.
split_generator = torch.Generator().manual_seed(42)
train_data, test_data = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)
# random_split returns Subsets of the same `dataset`. Without this step, test
# images would also go through the random crop/flip in transform1. Re-index the
# same test samples through a second ImageFolder with a deterministic transform
# (ImageFolder lists files in sorted order, so the indices line up).
test_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    normalize,
])
test_data = torch.utils.data.Subset(
    ImageFolder(path, transform=test_transform), test_data.indices
)
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps results reproducible.
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False)
TensorFlow
导入需要的库
import tensorflow as tf
导入数据并划分数据集
# Load the images and split them 80/20 into training and validation sets.
base_dir = '自己的路径'  # TODO: set to the root directory of your dataset (one sub-folder per class)
IMAGE_SIZE = 224
BATCH_SIZE = 16
# Pixel values are multiplied by 1/255; no augmentation is applied.
# validation_split=0.2 reserves 20% of the images for the 'validation' subset.
# NOTE(review): ImageDataGenerator is deprecated in recent TensorFlow releases;
# tf.keras.utils.image_dataset_from_directory is the suggested replacement.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
)
train_generator = datagen.flow_from_directory(
    base_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    subset='training',
)
# flow_from_directory shuffles by default. Turn that off for validation so the
# batch order matches val_generator.classes / val_generator.filenames when
# predictions are compared against labels.
val_generator = datagen.flow_from_directory(
    base_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    subset='validation',
    shuffle=False,
)