自建 ubyte(idx)格式数据集的读取与使用
定义 Dataset
class MNIST_IMG(Dataset):
    """Read an MNIST-style (idx ubyte) dataset from disk for use with a DataLoader.

    ``root`` must contain ``train-images.idx3-ubyte`` / ``train-labels.idx1-ubyte``
    when ``train`` is true, or ``test-images.idx3-ubyte`` /
    ``test-labels.idx1-ubyte`` otherwise.

    Each item is ``(image, label)``:
      * ``image`` is a ``(rows, cols)`` float64 NumPy array, passed through
        ``transform`` when one is given;
      * ``label`` is a float64 scalar, passed through ``target_transform``
        when one is given.

    Raises:
        OSError: if either file cannot be opened.
        ValueError: if the image and label counts in the headers differ, or a
            file is shorter than its header claims.
    """

    # Text name for each integer class id (index == id).
    # NOTE(review): these look like network-intrusion categories (e.g. the
    # CIC-IDS2017 labels), not MNIST digits. Confirm against whatever built the
    # ubyte files.
    TEXT_LABELS = (
        'BENIGN', 'Bot', 'DDoS', 'DoS GoldenEye', 'DoS Hulk',
        'DoS Slowhttptest', 'DoS slowloris', 'FTP-Patator', 'Heartbleed', 'Infiltration',
        'PortScan', 'SSH-Patator', 'Brute Force', 'Sql Injection', 'XSS',
    )

    def __init__(self, root, train=True, transform=None, target_transform=None):
        import os  # local import: this snippet has no visible import block

        super().__init__()
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        split = 'train' if train else 'test'
        # os.path.join instead of a hard-coded '\\' so this also works on POSIX.
        images_file = os.path.join(root, split + '-images.idx3-ubyte')
        labels_file = os.path.join(root, split + '-labels.idx1-ubyte')

        # Read both files fully, closing the handles right away.
        with open(images_file, 'rb') as fh:
            img_buf = fh.read()
        with open(labels_file, 'rb') as fh:
            label_buf = fh.read()

        # Headers are big-endian uint32 fields: magic, count[, rows, cols].
        img_header = struct.Struct('>IIII')
        label_header = struct.Struct('>II')
        _img_magic, num_img, rows, cols = img_header.unpack_from(img_buf, 0)
        _label_magic, num_label = label_header.unpack_from(label_buf, 0)
        if num_img != num_label:
            raise ValueError(
                f'image/label count mismatch: {num_img} images vs {num_label} labels'
            )

        # Payloads are unsigned bytes packed immediately after each header.
        # Decode them in one vectorized pass rather than one struct call per item.
        # Stored as float64 to match the original np.empty()-based buffers.
        # astype() also yields a writable copy of the read-only frombuffer view.
        self.images = (
            np.frombuffer(img_buf, dtype=np.uint8,
                          count=num_img * rows * cols, offset=img_header.size)
            .reshape(num_img, rows, cols)
            .astype(np.float64)
        )
        self.labels = np.frombuffer(
            label_buf, dtype=np.uint8, count=num_label, offset=label_header.size
        ).astype(np.float64)

    def __getitem__(self, item):
        """Return ``(image, label)`` at index ``item``, with transforms applied."""
        img = self.images[item]
        label = self.labels[item]
        if self.transform is not None:
            img = self.transform(img)
        # target_transform used to be stored but never applied.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def get_label(self, n):
        """Return the text name for numeric class id ``n`` (any int-convertible value)."""
        return self.TEXT_LABELS[int(n)]

    def get_labels(self, labels):
        """Return a list of text names for an iterable of numeric class ids.

        Example: ``labels = [1, 3, 5, 3, 6, 2]`` gives one name per element.
        The result is a list, not a lazy iterator.
        """
        return [self.TEXT_LABELS[int(i)] for i in labels]
创建 DataLoader 加载器
# Build the train/test datasets and wrap them in DataLoaders.
# NOTE(review): `path`, `trans`, `batch_size` and `DataLoader` are defined or
# imported outside this snippet. `path` must be the directory holding the
# *-images.idx3-ubyte / *-labels.idx1-ubyte files.
train_dataset = MNIST_IMG(path, train=True, transform=trans)
test_dataset = MNIST_IMG(path, train=False, transform=trans)
# Shuffle only the training data. The evaluation set is read in a fixed order
# so results are reproducible and batch positions map back to test_dataset
# indices.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
遍历加载器使用数据
# Run the model over every batch of the test loader.
# NOTE(review): `model` and `device` come from outside this snippet. If this loop
# is pure evaluation, consider model.eval() and wrapping it in torch.no_grad().
for inputs, target in test_loader:
    # Put the batch on the compute device and feed the model float32 inputs.
    inputs = inputs.to(device).to(torch.float32)
    target = target.to(device)
    outputs = model(inputs)