PyTorch MNIST Dataset
1. MNIST Dataset
https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html
torchvision.datasets.MNIST(root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
Parameters:
- root (str or pathlib.Path) - Root directory of the dataset, where MNIST/raw/train-images-idx3-ubyte and MNIST/raw/t10k-images-idx3-ubyte exist.
- train (bool, optional) - If True, creates the dataset from train-images-idx3-ubyte, otherwise from t10k-images-idx3-ubyte.
- download (bool, optional) - If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) - A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
- target_transform (callable, optional) - A function/transform that takes in the target and transforms it.
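As an illustration of the target_transform parameter, here is a minimal sketch of a callable that turns the integer label into a one-hot list. The name one_hot and the constant NUM_CLASSES are made up for this example and are not part of torchvision; the function only assumes that the target it receives is an int in the range 0-9.

```python
from typing import List

NUM_CLASSES = 10  # MNIST has the digits 0-9 (assumption used by this sketch)


def one_hot(target: int, num_classes: int = NUM_CLASSES) -> List[float]:
    """Turn an integer class index into a one-hot list of floats."""
    if not 0 <= target < num_classes:
        raise ValueError(f"target {target} is outside [0, {num_classes})")
    encoded = [0.0] * num_classes
    encoded[target] = 1.0
    return encoded


# Hypothetical usage (needs torchvision and a downloaded dataset):
#   dataset = torchvision.datasets.MNIST(root="./data", target_transform=one_hot)
#   image, target = dataset[0]  # target would then be a 10-element list
```

Any callable with the same shape (one argument in, transformed target out) could be passed the same way.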
Special-members:
__getitem__(index: int) -> Tuple[Any, Any]
- Parameters: index (int) - Index
- Returns: (image, target), where target is the index of the target class.
- Return type: tuple
2. Source code for torchvision.datasets.mnist
https://pytorch.org/vision/main/_modules/torchvision/datasets/mnist.html
mirrors = [
"http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/",
]
resources = [
("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
]
training_file = "training.pt"
test_file = "test.pt"
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]
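The resources list above pairs each archive with an MD5 checksum. The following is a minimal sketch of how a downloaded file could be checked against such a value using only the standard library. The helper names md5_of_file and check_md5 are hypothetical, and this is not claimed to be the way torchvision performs its own check.

```python
import hashlib

# Checksums copied from the resources list in the torchvision source above.
RESOURCES = {
    "train-images-idx3-ubyte.gz": "f68b3c2dcbeaaa9fbdd348bbdeb94873",
    "train-labels-idx1-ubyte.gz": "d53e105ee54ea40749a09fcbcd1e9432",
    "t10k-images-idx3-ubyte.gz": "9fb629c4189551a2d022fa330f9573f3",
    "t10k-labels-idx1-ubyte.gz": "ec29112dd5afa0611ce80d1b7f02629c",
}


def md5_of_file(path, chunk_size=1024 * 1024):
    """Return the hex MD5 digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def check_md5(path, expected):
    """Return True if the file's MD5 digest equals the expected hex string."""
    return md5_of_file(path) == expected


# Hypothetical usage, assuming the archive is already on disk:
#   name = "train-images-idx3-ubyte.gz"
#   ok = check_md5("./data/MNIST/raw/" + name, RESOURCES[name])
```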
The files can also be downloaded in a browser from the links below and then copied into the data/MNIST/raw/ directory.
https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
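After a manual download it can be useful to confirm that the archives are readable. The sketch below reads the IDX header and payload directly from a .gz file with the standard library. It relies on the IDX layout (big-endian 32-bit integers, magic number 2051 for image files and 2049 for label files); the function names read_idx_labels and read_idx_images are hypothetical helpers, not torchvision APIs.

```python
import gzip
import struct

IMAGE_MAGIC = 2051  # 0x00000803: unsigned bytes, 3 dimensions
LABEL_MAGIC = 2049  # 0x00000801: unsigned bytes, 1 dimension


def read_idx_labels(path):
    """Read a gzip-compressed IDX label file and return the labels as a list of ints."""
    with gzip.open(path, "rb") as f:
        magic, count = struct.unpack(">II", f.read(8))
        if magic != LABEL_MAGIC:
            raise ValueError(f"unexpected magic number {magic} in {path}")
        labels = f.read(count)
    if len(labels) != count:
        raise ValueError(f"expected {count} labels, found {len(labels)}")
    return list(labels)


def read_idx_images(path):
    """Read a gzip-compressed IDX image file; return (count, rows, cols, raw pixel bytes)."""
    with gzip.open(path, "rb") as f:
        magic, count, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != IMAGE_MAGIC:
            raise ValueError(f"unexpected magic number {magic} in {path}")
        pixels = f.read(count * rows * cols)
    if len(pixels) != count * rows * cols:
        raise ValueError("image file is truncated")
    return count, rows, cols, pixels


# Hypothetical usage, assuming the files were copied to data/MNIST/raw/:
#   labels = read_idx_labels("data/MNIST/raw/train-labels-idx1-ubyte.gz")
#   count, rows, cols, pixels = read_idx_images("data/MNIST/raw/train-images-idx3-ubyte.gz")
```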
2.1. mnist-dataset.py
/home/yongqiang/llm_work/ggml_25_02_15/ggml/examples/mnist/mnist-dataset.py
import torch
import torchvision
import torchvision.datasets
import torchvision.transforms
import numpy as np
import matplotlib.pyplot as plt
print(torch.__version__)
train_data = torchvision.datasets.MNIST(root='./data', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.MNIST(root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)
assert len(train_data) == 60000
assert len(test_data) == 10000
print("len(train_data) =", len(train_data))
print("len(test_data) =", len(test_data))
print("type(train_data[0]):", type(train_data[0]))
print("train_data[0].shape:", train_data[0][0].shape)
classes = train_data.classes
print("train_data.classes:", classes)
print("train_data.class_to_idx: ", train_data.class_to_idx)
def ImShow(sample_element, shape=(28, 28)):
    plt.imshow(sample_element[0].numpy().reshape(shape), cmap='gray')
    plt.title('Label = ' + str(sample_element[1]))
    plt.show()
ImShow(train_data[0])
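The script prints both train_data.classes and train_data.class_to_idx. As a sketch of how the two relate, the mapping can be rebuilt from the class list with plain Python, and an integer target can be turned back into its class name by indexing the list. The helper label_name is made up for this example; the class strings are copied from the torchvision source quoted earlier.

```python
classes = [
    "0 - zero",
    "1 - one",
    "2 - two",
    "3 - three",
    "4 - four",
    "5 - five",
    "6 - six",
    "7 - seven",
    "8 - eight",
    "9 - nine",
]

# Each class name maps to its position in the list.
class_to_idx = {name: index for index, name in enumerate(classes)}


def label_name(target: int) -> str:
    """Return the human-readable class name for an integer target."""
    return classes[target]


print(class_to_idx["7 - seven"])  # 7
print(label_name(3))              # 3 - three
```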
(base) yongqiang@yongqiang:~/llm_work/ggml_25_02_15/ggml/examples/mnist$ python mnist-dataset.py
2.5.1
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9.91M/9.91M [00:03<00:00, 3.20MB/s]
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 28.9k/28.9k [00:00<00:00, 117kB/s]
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1.65M/1.65M [00:02<00:00, 656kB/s]
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 4.54k/4.54k [00:00<00:00, 3.05MB/s]
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
len(train_data) = 60000
len(test_data) = 10000
type(train_data[0]): <class 'tuple'>
train_data[0].shape: torch.Size([1, 28, 28])
train_data.classes: ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
train_data.class_to_idx: {'0 - zero': 0, '1 - one': 1, '2 - two': 2, '3 - three': 3, '4 - four': 4, '5 - five': 5, '6 - six': 6, '7 - seven': 7, '8 - eight': 8, '9 - nine': 9}
(base) yongqiang@yongqiang:~/llm_work/ggml_25_02_15/ggml/examples/mnist$
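The log above shows each file being requested from the first mirror, failing with HTTP 404, and then being fetched from the second mirror. The following is a minimal sketch of that try-next-mirror pattern, not torchvision's actual implementation. The names fetch_url and download_with_fallback are hypothetical, and the fetch parameter exists only so the logic can be exercised without network access.

```python
import urllib.request

# Mirror list copied from the torchvision source quoted earlier.
MIRRORS = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]


def fetch_url(url):
    """Download a URL and return the response body as bytes."""
    with urllib.request.urlopen(url) as response:
        return response.read()


def download_with_fallback(filename, mirrors=MIRRORS, fetch=fetch_url):
    """Try each mirror in order; return (url, data) from the first that succeeds."""
    errors = []
    for mirror in mirrors:
        url = mirror + filename
        try:
            return url, fetch(url)
        except OSError as error:
            # urllib.error.URLError and HTTPError are subclasses of OSError.
            errors.append(f"{url}: {error}")
    raise RuntimeError("all mirrors failed:\n" + "\n".join(errors))


# Hypothetical usage (performs real network requests):
#   url, data = download_with_fallback("train-labels-idx1-ubyte.gz")
```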
3. SSL: CERTIFICATE_VERIFY_FAILED
https://github.com/pytorch/pytorch/issues/33288
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1007)>
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1129)>
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1123)>
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:852)>
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:833)>
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:581)>
Disable certificate verification:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
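The assignment above changes the default HTTPS context for the whole process. As one possible variation, the sketch below wraps the same assignment in a context manager so that verification is switched off only while the download runs and restored afterwards. The name unverified_https is hypothetical, and the sketch relies on the same private ssl attributes as the two lines above.

```python
import contextlib
import ssl


@contextlib.contextmanager
def unverified_https():
    """Temporarily make the default HTTPS context skip certificate verification."""
    original = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        yield
    finally:
        # Restore verification even if the download raises.
        ssl._create_default_https_context = original


# Hypothetical usage (needs torchvision and network access):
#   with unverified_https():
#       train_data = torchvision.datasets.MNIST(root="./data", train=True, download=True)
```

Code outside the with block keeps the normal certificate checks.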
References
[1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/