自建 ubyte 格式数据集的读取与使用

本文介绍如何读取自建的 ubyte(IDX)格式数据集,并通过 PyTorch 的 Dataset 与 DataLoader 加载使用。

Dataset

class MNIST_IMG(Dataset):
    """Dataset that reads a pair of MNIST-style IDX (ubyte) files into memory.

    The image file has a 16-byte big-endian header (magic number, image
    count, rows, cols) followed by one unsigned byte per pixel. The label
    file has an 8-byte big-endian header (magic number, label count)
    followed by one unsigned byte per label.

    Args:
        root: Directory containing the ``*-images.idx3-ubyte`` and
            ``*-labels.idx1-ubyte`` files.
        train: If truthy, read the ``train-*`` files, otherwise ``test-*``.
        transform: Optional callable applied to each image on access.
        target_transform: Optional callable applied to each label on access.

    Raises:
        ValueError: If the image and label files declare different counts.
    """

    # Human-readable class names, indexed by integer label.
    TEXT_LABELS = (
        'BENIGN', 'Bot', 'DDoS', 'DoS GoldenEye', 'DoS Hulk',
        'DoS Slowhttptest', 'DoS slowloris', 'FTP-Patator', 'Heartbleed',
        'Infiltration', 'PortScan', 'SSH-Patator', 'Brute Force',
        'Sql Injection', 'XSS',
    )

    def __init__(self, root, train=True, transform=None, target_transform=None):
        import os  # local import so the class does not rely on a file-level one

        super().__init__()
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        # Join with the platform separator instead of a hard-coded backslash.
        prefix = 'train' if self.train else 'test'
        images_file = os.path.join(root, prefix + '-images.idx3-ubyte')
        labels_file = os.path.join(root, prefix + '-labels.idx1-ubyte')

        # Read the raw bytes; the context managers close the files promptly.
        with open(images_file, 'rb') as fp:
            img_bytes = fp.read()
        with open(labels_file, 'rb') as fp:
            label_bytes = fp.read()

        # Headers: magic number, item count and (images only) rows and cols.
        img_header, label_header = '>IIII', '>II'
        _, num_img, rows, cols = struct.unpack_from(img_header, img_bytes, 0)
        _, num_label = struct.unpack_from(label_header, label_bytes, 0)

        # Validate before allocating anything; unlike ``assert`` this check
        # is not removed when Python runs with -O.
        if num_img != num_label:
            raise ValueError(
                'image count (%d) does not match label count (%d)'
                % (num_img, num_label)
            )

        # Decode every pixel/label in one pass. The float64 cast keeps the
        # dtype the arrays had when they were filled element by element and
        # also yields a writable copy (frombuffer views are read-only).
        self.images = np.frombuffer(
            img_bytes,
            dtype=np.uint8,
            count=num_img * rows * cols,
            offset=struct.calcsize(img_header),
        ).reshape(num_img, rows, cols).astype(np.float64)
        self.labels = np.frombuffer(
            label_bytes,
            dtype=np.uint8,
            count=num_label,
            offset=struct.calcsize(label_header),
        ).astype(np.float64)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair at ``item``, transforms applied."""
        img = self.images[item]
        label = self.labels[item]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Return the number of images held by the dataset."""
        return len(self.images)

    def get_label(self, n):
        """Return the text label for the numeric label ``n``."""
        return self.TEXT_LABELS[int(n)]

    def get_labels(self, labels):  # @save
        """Return the text labels for an iterable of numeric labels.

        For example ``labels = [1, 3, 5, 3, 6, 2]`` yields a list of six
        class names in the same order.
        """
        return [self.TEXT_LABELS[int(i)] for i in labels]

加载器

	train_dataset = MNIST_IMG(path, train=True, transform=trans)
    test_dataset = MNIST_IMG(path, train=False, transform=trans)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)

使用

    # Run the model over every batch of the test loader.
    for batch in test_loader:
        inputs, target = batch
        target = target.to(device)
        # Move to the target device and cast to float32 in a single call.
        inputs = inputs.to(device, dtype=torch.float32)
        outputs = model(inputs)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值