迁移学习之Multi-Domain Adaptation多领域自适应常用数据集PACS介绍

PACS数据集

Paper:Self-supervised Domain Adaptation for Computer Vision Tasks

GitHub:https://github.com/robertofranceschi/Domain-adaptation-on-PACS-dataset

数据集下载:https://github.com/MachineLearning2020/Homework3-PACS/tree/master/PACS

  • PACS数据集总共9991张图片,每张图片3x227x227
  • 7 classes:Dog, Elephant, Giraffe, Guitar, Horse, House, Person
  • 4 domains: Art painting, Cartoon, Photo and Sketch.
  • Photo (1,670 images), Art Painting (2,048 images), Cartoon (2,344 images) and Sketch (3,929 images)

在这里插入图片描述

用Pytorch加载PACS数据集

PACS原始数据集目录结果:
在这里插入图片描述

from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import os
import random
from PIL import Image
import torch
import numpy as np
from torchvision.transforms import transforms
from sklearn.model_selection import train_test_split
import os
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class PACS(Dataset):
    def __init__(self, root_path, domain, train=True, transform=None, target_transform=None):
        self.root = f"{root_path}/{domain}"
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        label_name_list = os.listdir(self.root)
        self.label = []
        self.data = []

        if not os.path.exists(f"{root_path}/precessed"):
            os.makedirs(f"{root_path}/precessed")
        if os.path.exists(f"{root_path}/precessed/{domain}_data.pt") and os.path.exists(
                f"{root_path}/precessed/{domain}_label.pt"):
            print(f"Load {domain} data and label from cache.")
            self.data = torch.load(f"{root_path}/precessed/{domain}_data.pt")
            self.label = torch.load(f"{root_path}/precessed/{domain}_label.pt")
        else:
            print(f"Getting {domain} datasets")
            for index, label_name in enumerate(label_name_list):
                label_name_2_index = {
                    'dog': 0,
                    'elephant': 1,
                    'giraffe': 2,
                    'guitar': 3,
                    'horse': 4,
                    'house': 5,
                    'person': 6,
                }
                images_list = os.listdir(f"{self.root}/{label_name}")
                for img_name in images_list:
                    img = Image.open(f"{self.root}/{label_name}/{img_name}").convert('RGB')
                    img = np.array(img)
                    self.label.append(label_name_2_index[label_name])
                    if self.transform is not None:
                        img = self.transform(img)
                    self.data.append(img)
            self.data = torch.stack(self.data)
            self.label = torch.tensor(self.label, dtype=torch.long)
            torch.save(self.data, f"{root_path}/precessed/{domain}_data.pt")
            torch.save(self.label, f"{root_path}/precessed/{domain}_label.pt")

    def __getitem__(self, index):
        img, target = self.data[index], self.label[index]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)


def get_pacs_domain(root_path=f"{DATA_PATH}/PACS", domain='art_painting', verbose=False):
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        # transforms.Resize((224, 224)),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    all_data = PACS(root_path, domain, transform=transform)
    # train:test=8:2
    x_train, x_test, y_train, y_test = train_test_split(all_data.data.numpy(), all_data.label.numpy(),
                                                        test_size=0.20, random_state=42)

    return x_train, y_train, x_test, y_test

其他写法:https://github.com/ValerioDiEugenio/DomainAdaptation-PACSDataset/blob/main/DomainAdaptation.ipynb


最后以dog类别为例,用Python代码展示四种不同风格的图片

Python可视化图片数据集的代码

修改dir_path为对应的文件夹,dir_path = f"{DATA_PATH}/PACS/art_painting/dog"

import os
import matplotlib.pyplot as plt
import random
from PIL import Image


def plotPics(data, h=3, w=3, filename="out.jpg"):
    fig, ax_array = plt.subplots(h, w, figsize=(15, 15))

    axes = ax_array.flatten()

    for i, ax in enumerate(axes):
        ri = random.randint(0, len(data) - 1)
        ax.imshow(data[ri], cmap=plt.cm.gray)

    plt.setp(axes, xticks=[], yticks=[], frame_on=False)
    fig.tight_layout()
    fig.savefig(filename)
    plt.show()


DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/raw_data"))

dir_path = f"{DATA_PATH}/PACS/art_painting/dog"

data = []
for pic in os.listdir(dir_path):
    data.append(Image.open(f"{dir_path}/{pic}"))

plotPics(data, h=5, w=5)

art_painting风格

在这里插入图片描述

sketch风格

在这里插入图片描述

cartoon风格

在这里插入图片描述

photo风格

在这里插入图片描述

  • 7
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
在PyTorch中,可以使用以下步骤实现从源域数据集提取样本到目标域并进行领域适应: 1. 首先,需要准备源域数据集和目标域数据集,并使用PyTorch的DataLoader对数据集进行加载。 2. 接着,可以使用预训练模型或Fine-tuning等方法对源域数据集进行训练,例如在ImageNet上预训练的ResNet模型。 3. 在将模型应用于目标域数据集之前,需要进行领域适应。其中一种方法是通过对目标域数据集进行一些预处理,例如数据增强和标准化,以便与源域数据集更加相似。 4. 另一种方法是使用领域适应算法来调整模型,以便更好地适应目标域数据集。例如,可以使用PyTorch中的DANN(Domain-Adversarial Neural Network)和ADDA(Adversarial Discriminative Domain Adaptation)等算法。 以下是一个简单的示例代码,展示如何使用PyTorch实现领域适应: ```python import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision.datasets import MNIST from torchvision import transforms from models import Net from utils import train, test from domain_adaptation import DANN # 加载源域数据集 source_dataset = MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()) source_loader = DataLoader(source_dataset, batch_size=64, shuffle=True, num_workers=4) # 加载目标域数据集 target_dataset = MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor()) target_loader = DataLoader(target_dataset, batch_size=64, shuffle=False, num_workers=4) # 定义模型 model = Net() # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # 在源域数据集上进行训练 for epoch in range(10): train(model, source_loader, criterion, optimizer, epoch) # 使用DANN算法进行领域适应 dann = DANN() dann.train(source_loader, target_loader, model, criterion, optimizer) # 在目标域数据集上进行测试 test(model, target_loader, criterion) ``` 其中,models.py和utils.py分别定义了模型和训练/测试函数,domain_adaptation.py定义了DANN算法。通过以上代码,可以实现从源域数据集提取样本到目标域并进行领域适应的过程。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

捡起一束光

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值