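When training a neural network on a dataset, how many num_workers (the worker subprocesses the DataLoader uses to load data, not threads) is the best choice? Without further ado, here is a quick test: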
import time
import torch
import torchvision
import torchvision.transforms as transforms

# Measure which num_workers value loads data fastest
def measure_dataloader_time(num_workers, batch_size=64, num_batches=100):
    # Transforms for the CIFAR-10 dataset
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    # Download (if needed) and load the CIFAR-10 training set
    trainset = torchvision.datasets.CIFAR10(root="/home/yc/vs_code/model_neural_network/dataset_CIFAR-10",
                                            train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    # Start timing
    start_time = time.time()
    # Iterate over the loader
    for i, data in enumerate(trainloader, 0):
        # print(i)
        if i + 1 >= num_batches:  # stop once num_batches batches have been loaded
            break
    # Stop timing
    end_time = time.time()
    # Compute the total time
    total_time = end_time - start_time
    print(f"Total time with {num_workers} workers is: {total_time:.2f} seconds")

# The guard is required when worker processes are started with "spawn" (e.g. Windows, macOS)
if __name__ == "__main__":
    # Try different num_workers settings
    for num in range(0, 9):  # e.g. test num_workers from 0 to 8; this machine has 8 cores, so the test stops at 8
        measure_dataloader_time(num_workers=num)
Output:
Total time with 1 workers is: 0.50 seconds
Files already downloaded and verified
Total time with 2 workers is: 0.35 seconds
Files already downloaded and verified
Total time with 3 workers is: 0.30 seconds
Files already downloaded and verified
Total time with 4 workers is: 0.26 seconds
Files already downloaded and verified
Total time with 5 workers is: 0.24 seconds
Files already downloaded and verified
Total time with 6 workers is: 0.26 seconds
Files already downloaded and verified
Total time with 7 workers is: 0.31 seconds
Files already downloaded and verified
Total time with 8 workers is: 0.27 seconds
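So num_workers=5 is the best choice on this machine, although 4, 5, and 6 workers are within about 0.02 seconds of each other in this single run. The best value depends on the computer's configuration, so it should be measured on each machine.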
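Since the timings above come from one run per setting and several of them differ by only a few hundredths of a second, it may help to repeat each measurement and compare medians. The following is a minimal sketch of that idea, not part of the original script. The helper names (`candidate_worker_counts`, `time_loader`, `compare_worker_counts`) and the `make_loader` factory are hypothetical, and the defaults (`repeats=5`, using the logical CPU count as the upper bound) are assumptions chosen for illustration.

```python
import os
import statistics
import time


def candidate_worker_counts(limit=None):
    """Return [0, 1, ..., limit]; limit defaults to the logical CPU count (assumption)."""
    if limit is None:
        limit = os.cpu_count() or 1  # os.cpu_count() can return None
    return list(range(limit + 1))


def time_loader(make_loader, num_workers, num_batches=100, repeats=5):
    """Median seconds needed to read num_batches batches over several runs.

    make_loader is a hypothetical factory: make_loader(num_workers) must
    return a fresh iterable of batches, for example a DataLoader.
    """
    samples = []
    for _ in range(repeats):
        loader = make_loader(num_workers)
        start = time.perf_counter()
        for i, _batch in enumerate(loader):
            if i + 1 >= num_batches:  # stop after num_batches batches
                break
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def compare_worker_counts(make_loader, worker_counts, **kwargs):
    """Time every candidate and return (results dict, fastest candidate)."""
    results = {n: time_loader(make_loader, n, **kwargs) for n in worker_counts}
    best = min(results, key=results.get)
    return results, best
```

With the `trainset` built in the script above, the factory could be `lambda n: torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=n)`, passed as `compare_worker_counts(factory, candidate_worker_counts())`. Each run still includes worker start-up time, as in the original measurement, so the medians remain a rough guide rather than a precise benchmark.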