Open Set Domain Adaptation by Backpropagation(OSBP)论文数字数据集复现
文章目录
1.准备数据集
MNIST数据集:28*28,共70000张图片,10类数字
USPS数据集:原始图片大小为16*16,共9298张图片,10类数字(本文使用的usps_28x28.pkl已将图片缩放为28*28,其中训练集7438张、测试集1860张)
SVHN数据集:32*32,训练集共73257张图片(另有测试集26032张),10类数字
由于较早版本的torchvision.datasets中没有自带USPS数据集,所以这里自定义一个Dataset类来加载USPS数据集
"""Dataset setting and data loader for USPS.
Modified from
https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
"""
import gzip
import os
import pickle
import urllib
from PIL import Image
import numpy as np
import torch
import torch.utils.data as data
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision import datasets, transforms
class USPS(data.Dataset):
    """USPS digit dataset loaded from the gzip-compressed ``usps_28x28.pkl`` file.

    The pickle is expected to hold two ``(images, labels)`` pairs: index 0 is
    the training split and index 1 is the test split. Images are multiplied
    by 255 and stored as ``uint8`` so that ``__getitem__`` can wrap them in a
    mode-``L`` PIL image (this presumes the stored pixel values lie in
    [0, 1] -- TODO confirm against the upstream file).

    Args:
        root (string): Directory that contains (or will receive) the
            dataset file. ``~`` is expanded.
        train (bool, optional): If True, use the training split; otherwise
            use the test split.
        transform (callable, optional): A function/transform that takes in
            a PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        download (bool, optional): If True, download the dataset file into
            ``root`` unless it is already present.

    Raises:
        RuntimeError: If the dataset file is not found under ``root``.
    """

    url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"

    def __init__(self, root, train=True, transform=None, download=False):
        """Locate (optionally download) the dataset file and load one split."""
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Upstream note: num of train = 7438, num of test = 1860.
        self.transform = transform
        # Filled in by load_samples() with the number of samples in the split.
        self.dataset_size = None

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        # NOTE: the attributes are named ``train_*`` for both splits; the
        # names are kept because external code may read them.
        self.train_data, self.train_labels = self.load_samples()
        if self.train:
            # Keep the first ``dataset_size`` training samples (currently
            # the whole split, since load_samples() sets dataset_size to
            # the number of labels).
            self.train_data = self.train_data[:self.dataset_size]
            self.train_labels = self.train_labels[:self.dataset_size]

        # Applied to BOTH splits: __getitem__ builds a mode-"L" image, which
        # needs uint8 pixels. The multiplication is not done in place so the
        # freshly unpickled array is never mutated.
        self.train_data = self.train_data * 255.0
        self.train_data = np.squeeze(self.train_data).astype(np.uint8)

    def __getitem__(self, index):
        """Get image and target for the data loader.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: (image, target) where target is the class label cast to
            ``int64`` and image is the (optionally transformed) PIL image.
        """
        img, label = self.train_data[index], self.train_labels[index]
        img = Image.fromarray(img, mode='L')
        img = img.copy()
        if self.transform is not None:
            img = self.transform(img)
        return img, label.astype("int64")

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.train_data)

    def _check_exists(self):
        """Return True if the dataset file is present under ``root``."""
        return os.path.exists(os.path.join(self.root, self.filename))

    def download(self):
        """Download the dataset file into ``root`` unless it already exists."""
        # ``import urllib`` alone does not guarantee that the ``request``
        # submodule is loaded, so import it explicitly where it is used.
        import urllib.request

        filename = os.path.join(self.root, self.filename)
        dirname = os.path.dirname(filename)
        if not os.path.isdir(dirname):
            os.makedirs(dirname, exist_ok=True)
        if os.path.isfile(filename):
            return
        print("Download %s to %s" % (self.url, os.path.abspath(filename)))
        urllib.request.urlretrieve(self.url, filename)
        print("[DONE]")
        return

    def load_samples(self):
        """Load the images and labels of the selected split.

        Also records the number of samples of that split in
        ``self.dataset_size``.

        Returns:
            tuple: (images, labels) as stored in the pickle file.
        """
        filename = os.path.join(self.root, self.filename)
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # obtained from a source you trust.
        with gzip.open(filename, "rb") as f:
            data_set = pickle.load(f, encoding="bytes")

        # Index 0 holds the training split, index 1 the test split.
        split = data_set[0] if self.train else data_set[1]
        images, labels = split[0], split[1]
        self.dataset_size = labels.shape[0]
        return images, labels
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.datasets import SVHN
from .mnist import *
from .svhn import *
from .usps import *
def get_dataset(task):
if task == 's2m':
#注意这里的SVHN与MNIST、USP数据路径的写法不一样,
train_dataset = SVHN('datasets/SVHN', split='train', download=False,#split='train':选择使用SVHN的train数据集
transform=transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))
test_dataset = MNIST('datasets', train=True, download=False,
transform=transforms.Compose([
transforms.Resize(32),
transforms.Lambda(lambda x: x.convert("RGB")),#因为SVHN数据集中的数据是通道为3的彩色图片,为了一致将MNIST数据集也转换为彩色
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))#前面的(0.5,0.5,0.5) 是 R G B 三个通道上的均

本文档详细介绍了如何复现Open Set Domain Adaptation by Backpropagation (OSBP) 论文中的实验,特别是在数字数据集(如MNIST、USPS、SVHN)上的应用。首先准备所需的数据集,接着描述模型结构,然后使用SVHN数据集中的部分样本训练模型,并在MNIST上进行验证。最后提供实验结果及代码链接。
79025209,5720679,JavaScript计时器与清理,"['JavaScript', '前端开发', '定时任务']
最低0.47元/天 解锁文章
1281

被折叠的 条评论
为什么被折叠?



