Reproducing the Digit-Dataset Experiments of Open Set Domain Adaptation by Backpropagation (OSBP)

This document describes how to reproduce the experiments from the paper Open Set Domain Adaptation by Backpropagation (OSBP), specifically on the digit datasets (MNIST, USPS, SVHN). We first prepare the required datasets, then describe the model architecture, train the model on a subset of SVHN, and evaluate it on MNIST. Finally, experimental results and a link to the code are provided.


1. Preparing the Datasets

MNIST: 28×28 images, 70,000 in total, 10 digit classes

USPS: 16×16 images, 9,298 in total (7,438 train + 1,860 test in the pickle used below), 10 digit classes

SVHN: 32×32 images, 73,257 in total, 10 digit classes
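Since the three datasets have different native resolutions, the data loaders below resize everything to a common 32×32 before feeding the model. A minimal sketch (not part of the original code; the blank images are synthetic stand-ins) illustrates the effect:

```python
from PIL import Image
import numpy as np

# Synthetic stand-ins for the three native resolutions:
# MNIST 28x28, USPS 16x16, SVHN 32x32.
# Resizing everything to 32x32 (as transforms.Resize(32) does in the
# loaders below) puts the datasets on a common footing.
resized = {}
for size in (28, 16, 32):
    img = Image.fromarray(np.zeros((size, size), dtype=np.uint8))
    resized[size] = img.resize((32, 32))

print({k: v.size for k, v in resized.items()})
```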

Since torchvision.datasets does not ship a USPS dataset, we define a custom Dataset class for it:

"""Dataset setting and data loader for USPS.
Modified from
https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
"""

import gzip
import os
import pickle
import urllib.request
from PIL import Image

import numpy as np
import torch
import torch.utils.data as data


class USPS(data.Dataset):
    """USPS Dataset.
    Args:
        root (string): Root directory of dataset where dataset file exist.
        train (bool, optional): If True, resample from dataset randomly.
        download (bool, optional): If true, downloads the dataset
            from the internet and puts it in root directory.
            If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in
            an PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
    """

    url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"

    def __init__(self, root, train=True, transform=None, download=False):
        """Init USPS dataset."""
        # init params
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Num of Train = 7438, Num of Test = 1860
        self.transform = transform
        self.dataset_size = None

        # download dataset.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        self.train_data, self.train_labels = self.load_samples()
        if self.train:
            total_num_samples = self.train_labels.shape[0]
            indices = np.arange(total_num_samples)
            self.train_data = self.train_data[indices[0:self.dataset_size], ::]
            self.train_labels = self.train_labels[indices[0:self.dataset_size]]
        self.train_data *= 255.0
        self.train_data = np.squeeze(self.train_data).astype(np.uint8)

    def __getitem__(self, index):
        """Get images and target for data loader.
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, label = self.train_data[index], self.train_labels[index]
        img = Image.fromarray(img, mode='L')
        img = img.copy()
        if self.transform is not None:
            img = self.transform(img)
        return img, label.astype("int64")

    def __len__(self):
        """Return size of dataset."""
        return len(self.train_data)

    def _check_exists(self):
        """Check if dataset is download and in right place."""
        return os.path.exists(os.path.join(self.root, self.filename))

    def download(self):
        """Download dataset."""
        filename = os.path.join(self.root, self.filename)
        dirname = os.path.dirname(filename)
        if not os.path.isdir(dirname):
            os.makedirs(dirname)
        if os.path.isfile(filename):
            return
        print("Download %s to %s" % (self.url, os.path.abspath(filename)))
        urllib.request.urlretrieve(self.url, filename)
        print("[DONE]")
        return

    def load_samples(self):
        """Load sample images from dataset."""
        filename = os.path.join(self.root, self.filename)
        f = gzip.open(filename, "rb")
        data_set = pickle.load(f, encoding="bytes")
        f.close()
        if self.train:
            images = data_set[0][0]
            labels = data_set[0][1]
            self.dataset_size = labels.shape[0]
        else:
            images = data_set[1][0]
            labels = data_set[1][1]
            self.dataset_size = labels.shape[0]
        return images, labels
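The loader above assumes `usps_28x28.pkl` is a gzip-compressed pickle of the form `((train_images, train_labels), (test_images, test_labels))`. A minimal sketch with synthetic arrays (shapes here are illustrative assumptions, not the real dataset) shows the round trip that `load_samples` relies on:

```python
import gzip
import os
import pickle
import tempfile

import numpy as np

# Synthetic stand-in for usps_28x28.pkl:
# ((train_images, train_labels), (test_images, test_labels))
train_part = (np.random.rand(5, 1, 28, 28).astype("float32"), np.arange(5))
test_part = (np.random.rand(2, 1, 28, 28).astype("float32"), np.arange(2))

path = os.path.join(tempfile.mkdtemp(), "usps_28x28.pkl")
with gzip.open(path, "wb") as f:
    pickle.dump((train_part, test_part), f)

# Mirror of load_samples(train=True): index 0 holds the train split
with gzip.open(path, "rb") as f:
    data_set = pickle.load(f, encoding="bytes")
images, labels = data_set[0][0], data_set[0][1]
print(images.shape, labels.shape)
```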

from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.datasets import SVHN

from .mnist import *
from .svhn import *
from .usps import *

def get_dataset(task):
    if task == 's2m':
        # Note: the dataset path for SVHN is written differently from MNIST/USPS.
        train_dataset = SVHN('datasets/SVHN', split='train', download=False,  # split='train': use the SVHN train split
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))
        
        test_dataset = MNIST('datasets', train=True, download=False,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.Lambda(lambda x: x.convert("RGB")),  # SVHN images have 3 color channels, so convert MNIST to RGB for consistency
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # the first (0.5, 0.5, 0.5) are the per-channel (R, G, B) means, the second the standard deviations
                ]))
        return train_dataset, test_dataset
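The `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` step applies `(x - mean) / std` per channel, mapping the `[0, 1]` range produced by `ToTensor` to `[-1, 1]`. A minimal numeric sketch of that mapping:

```python
import numpy as np

# ToTensor yields values in [0, 1]; Normalize with mean=0.5, std=0.5
# applies (x - 0.5) / 0.5, which maps the range onto [-1, 1].
x = np.linspace(0.0, 1.0, 5)
normalized = (x - 0.5) / 0.5
print(normalized)
```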