Transfer Learning: An Introduction to PACS, a Common Dataset for Multi-Domain Adaptation

The PACS dataset

Paper: Self-supervised Domain Adaptation for Computer Vision Tasks

GitHub: https://github.com/robertofranceschi/Domain-adaptation-on-PACS-dataset

Dataset download: https://github.com/MachineLearning2020/Homework3-PACS/tree/master/PACS

  • PACS contains 9,991 images in total; each image is 3x227x227 (channels x height x width)
  • 7 classes: Dog, Elephant, Giraffe, Guitar, Horse, House, Person
  • 4 domains: Art Painting, Cartoon, Photo, and Sketch
  • Images per domain: Photo (1,670), Art Painting (2,048), Cartoon (2,344), Sketch (3,929)
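
The per-domain counts above can be checked against the stated total with a few lines of Python. The sketch below also lists the source and target sizes when one domain is held out as the target and the other three serve as sources. That protocol is an assumption made for illustration (it is a common way to use four-domain datasets, but it is not described in this post), and `leave_one_domain_out` is a hypothetical helper name.

```python
# Image counts per domain, as listed above.
DOMAIN_SIZES = {
    "photo": 1670,
    "art_painting": 2048,
    "cartoon": 2344,
    "sketch": 3929,
}


def leave_one_domain_out(sizes):
    """Yield (target, sources, n_source_images, n_target_images) for each target domain."""
    total = sum(sizes.values())
    for target in sizes:
        sources = [d for d in sizes if d != target]
        yield target, sources, total - sizes[target], sizes[target]


total_images = sum(DOMAIN_SIZES.values())
print(f"total images: {total_images}")

# Assumed protocol: three source domains, one held-out target domain.
for target, sources, n_src, n_tgt in leave_one_domain_out(DOMAIN_SIZES):
    print(f"target={target:<12} source images={n_src} target images={n_tgt}")
```

Sketch is by far the largest domain, so holding it out leaves the smallest pool of source images.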


Loading the PACS dataset with PyTorch

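Directory structure of the raw PACS dataset: one folder per domain (art_painting, cartoon, photo, sketch), each containing one subfolder per class (dog, elephant, giraffe, guitar, horse, house, person) that holds the images.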
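
Before loading anything, it can help to confirm that the copy on disk matches this layout. The following is a minimal sketch that assumes a root/domain/class/image folder hierarchy; `count_images` is a hypothetical helper name, and the path in the usage lines is a placeholder.

```python
import os


def count_images(root_path):
    """Return {domain: {class_name: file_count}} for a root/domain/class/image layout."""
    counts = {}
    for domain in sorted(os.listdir(root_path)):
        domain_dir = os.path.join(root_path, domain)
        if not os.path.isdir(domain_dir):
            continue  # skip stray files in the root folder
        per_class = {}
        for class_name in sorted(os.listdir(domain_dir)):
            class_dir = os.path.join(domain_dir, class_name)
            if os.path.isdir(class_dir):
                per_class[class_name] = len(os.listdir(class_dir))
        if per_class:
            # Folders without class subfolders (e.g. a cache directory) are ignored.
            counts[domain] = per_class
    return counts


root = "data/raw_data/PACS"  # placeholder path; point this at your own copy
if os.path.isdir(root):
    for domain, per_class in count_images(root).items():
        print(domain, sum(per_class.values()), per_class)
```

If the per-domain totals printed here differ from the counts listed earlier, the download is probably incomplete or the folders contain extra files.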

import os

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

class PACS(Dataset):
    # Maps each class folder name to an integer label.
    LABEL_NAME_TO_INDEX = {
        'dog': 0,
        'elephant': 1,
        'giraffe': 2,
        'guitar': 3,
        'horse': 4,
        'house': 5,
        'person': 6,
    }

    def __init__(self, root_path, domain, train=True, transform=None, target_transform=None):
        self.root = f"{root_path}/{domain}"
        self.train = train  # stored for API compatibility; the class itself does not split the data
        self.transform = transform
        self.target_transform = target_transform

        # Preprocessed tensors are cached so later runs skip image decoding.
        cache_dir = f"{root_path}/processed"
        data_file = f"{cache_dir}/{domain}_data.pt"
        label_file = f"{cache_dir}/{domain}_label.pt"
        os.makedirs(cache_dir, exist_ok=True)

        if os.path.exists(data_file) and os.path.exists(label_file):
            # The cache holds whatever transform was used when it was written;
            # delete the cache files after changing the transform.
            print(f"Load {domain} data and label from cache.")
            self.data = torch.load(data_file)
            self.label = torch.load(label_file)
        else:
            print(f"Getting {domain} datasets")
            data, label = [], []
            for label_name in sorted(os.listdir(self.root)):
                for img_name in sorted(os.listdir(f"{self.root}/{label_name}")):
                    img = Image.open(f"{self.root}/{label_name}/{img_name}").convert('RGB')
                    img = np.array(img)
                    if self.transform is not None:
                        # Apply the transform once, here, so the cached tensors are ready to use.
                        img = self.transform(img)
                    else:
                        # Without a transform, still convert HWC uint8 to a CHW tensor so stacking works.
                        img = torch.from_numpy(img).permute(2, 0, 1)
                    data.append(img)
                    label.append(self.LABEL_NAME_TO_INDEX[label_name])
            self.data = torch.stack(data)
            self.label = torch.tensor(label, dtype=torch.long)
            torch.save(self.data, data_file)
            torch.save(self.label, label_file)

    def __getitem__(self, index):
        # self.data is already transformed in __init__; transforming again here
        # would normalize every image twice.
        img, target = self.data[index], self.label[index]

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)


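# Root folder for raw data, resolved relative to this file; adjust it to your project layout.
DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/raw_data"))


def get_pacs_domain(root_path=f"{DATA_PATH}/PACS", domain='art_painting', verbose=False):
    """Load one PACS domain and return x_train, y_train, x_test, y_test as NumPy arrays."""
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),  # HWC uint8 in [0, 255] -> CHW float in [0, 1]
        # transforms.Resize((224, 224)),
        # (x - 0.5) / 0.5 maps each channel from [0, 1] to [-1, 1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    all_data = PACS(root_path, domain, transform=transform)
    # train:test = 8:2, with a fixed seed so the split is reproducible
    x_train, x_test, y_train, y_test = train_test_split(all_data.data.numpy(), all_data.label.numpy(),
                                                        test_size=0.20, random_state=42)

    return x_train, y_train, x_test, y_test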
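
get_pacs_domain returns NumPy arrays, so they need to be wrapped again before they can be fed to a training loop in batches. The sketch below shows one way to do that with TensorDataset and DataLoader. It uses small random stand-in arrays with the same layout (N x 3 x 227 x 227 images, integer labels in 0..6) so that it runs without the dataset; `make_loader` is a hypothetical helper, not part of the original code.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(x, y, batch_size=32, shuffle=True):
    """Wrap image arrays (N, C, H, W) and label arrays (N,) in a DataLoader."""
    dataset = TensorDataset(torch.from_numpy(x).float(), torch.from_numpy(y).long())
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# Stand-in data; in practice these would come from get_pacs_domain(...).
rng = np.random.default_rng(0)
x_train = rng.standard_normal((8, 3, 227, 227)).astype(np.float32)
y_train = rng.integers(0, 7, size=8)

train_loader = make_loader(x_train, y_train, batch_size=4)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)
```

For the test split, passing shuffle=False keeps the evaluation order fixed.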

An alternative implementation: https://github.com/ValerioDiEugenio/DomainAdaptation-PACSDataset/blob/main/DomainAdaptation.ipynb


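Finally, taking the dog class as an example, the Python code below displays images in each of the four styles.

Python code for visualizing the image dataset

Set dir_path to the folder you want to view, for example dir_path = f"{DATA_PATH}/PACS/art_painting/dog"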

import os
import matplotlib.pyplot as plt
import random
from PIL import Image


def plotPics(data, h=3, w=3, filename="out.jpg"):
    """Plot an h x w grid of images drawn at random (with replacement) from data, then save it."""
    fig, ax_array = plt.subplots(h, w, figsize=(15, 15))

    axes = ax_array.flatten()

    for ax in axes:
        # cmap only affects single-channel images; RGB images are shown in color
        ax.imshow(random.choice(data), cmap=plt.cm.gray)

    plt.setp(axes, xticks=[], yticks=[], frame_on=False)
    fig.tight_layout()
    fig.savefig(filename)
    plt.show()


DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/raw_data"))

dir_path = f"{DATA_PATH}/PACS/art_painting/dog"

data = []
for pic in os.listdir(dir_path):
    data.append(Image.open(f"{dir_path}/{pic}"))

plotPics(data, h=5, w=5)
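
To produce a figure for every style without editing dir_path by hand, the four folders can be generated in a loop. This is a small sketch; `class_dirs` is a hypothetical helper, the folder names assume the domain/class layout used above, and "../data/raw_data" is a placeholder for your own data path.

```python
import os

DOMAINS = ["art_painting", "sketch", "cartoon", "photo"]


def class_dirs(data_path, class_name="dog", domains=DOMAINS):
    """Map each domain name to the folder holding images of class_name."""
    return {d: os.path.join(data_path, "PACS", d, class_name) for d in domains}


for domain, path in class_dirs("../data/raw_data").items():
    print(domain, "->", path)
```

Each path can then be used as dir_path, with a per-domain output name such as filename=f"{domain}.jpg" so the saved figures do not overwrite each other.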

[Figure: dog images, art_painting style]

[Figure: dog images, sketch style]

[Figure: dog images, cartoon style]

[Figure: dog images, photo style]
