# Convert the CIFAR-10 dataset into a label-noise dataset (symmetric and asymmetric noise).

import argparse
import os
import pickle
import random
import time

import numpy as np
import torchvision as tv
from IPython import embed
from PIL import Image
from torch.utils.data import Dataset

def parse_args(argv=None):
    """Parse command-line options for the noisy-CIFAR training run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing callers that invoke
            ``parse_args()`` with no arguments are unaffected.

    Returns:
        argparse.Namespace holding every option defined below.
    """

    def _str2bool(value):
        # ``type=bool`` is a classic argparse pitfall: bool("False") is True,
        # so any explicit value would enable the flag. Parse common truthy
        # spellings instead; everything else maps to False.
        return str(value).lower() in ('1', 'true', 'yes', 'y', 't')

    parser = argparse.ArgumentParser(description='command for the first train')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=128, help='#images in each mini-batch')
    parser.add_argument('--test_batch_size', type=int, default=100, help='#images in each mini-batch')
    parser.add_argument('--cuda_dev', type=int, default=0, help='GPU to select')
    parser.add_argument('--epoch', type=int, default=200, help='training epoches')
    parser.add_argument('--num_classes', type=int, default=10, help='Number of in-distribution classes')
    parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--noise_type', default='asymmetric', help='symmetric or asymmetric')
    parser.add_argument('--train_root', default='./data', help='root for train data')
    parser.add_argument('--noise_ratio', type=float, default=0.4, help='percent of noise')
    parser.add_argument('--out', type=str, default='./data/model_data', help='Directory of the output')
    parser.add_argument('--alpha', type=float, default=1.0, help='Beta distribution parameter for mixup')
    parser.add_argument('--download', type=_str2bool, default=False, help='download dataset')
    parser.add_argument('--network', type=str, default='PR18', help='Network architecture')
    parser.add_argument('--seed_initialization', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--seed_dataset', type=int, default=42, help='random seed (default: 1)')
    parser.add_argument('--M', action='append', type=int, default=[], help="Milestones for the LR sheduler")
    parser.add_argument('--experiment_name', type=str, default = 'Proof',help='name of the experiment (for the output files)')
    parser.add_argument('--method', type=str, default='MOIT', help='MOIT')
    parser.add_argument('--dataset', type=str, default='CIFAR-10', help='CIFAR-10, CIFAR-100')
    parser.add_argument('--initial_epoch', type=int, default=1, help="Star training at initial_epoch")
    parser.add_argument('--DA', type=str, default="complex", help='Choose simple or complex data augmentation')
    parser.add_argument('--low_dim', type=int, default=128, help='Size of contrastive learning embedding')
    parser.add_argument('--mix_labels', type=int, default=0, help='1: Interpolate two input images and "interpolate" labels')
    parser.add_argument('--batch_t', default=0.1, type=float, help='Contrastive learning temperature')
    parser.add_argument('--aprox', type=int, default=1, help='Warm-up epochs')
    parser.add_argument('--headType', type=str, default="Linear", help='Linear, NonLinear')
    parser.add_argument('--xbm_use', type=int, default=1, help='1: Use xbm')
    parser.add_argument('--xbm_begin', type=int, default=1, help='Epoch to begin using memory')
    parser.add_argument('--xbm_per_class', type=int, default=20, help='Num of samples per class to store in the memory. Memory size = xbm_per_class*num_classes')
    parser.add_argument('--startLabelCorrection', type=int, default=9999, help='Epoch to start label correction')
    parser.add_argument('--k_val', type=int, default=5, help='k for k-nn correction')
    parser.add_argument('--use_cleanLabels', type=int, default=0, help='Train the classifier with clean labels')
    parser.add_argument('--PredictiveCorrection', type=int, default=0, help='Enable predictive label correction')
    parser.add_argument('--balance_crit', type=str, default="none", help='None, max, min. median')
    parser.add_argument('--discrepancy_corrected', type=int, default=1, help='Use corrected label for discrepancy measure')
    parser.add_argument('--validation_exp', type=int, default=0, help='Using clean train subset for validation')
    parser.add_argument('--val_samples', type=int, default=5000, help='Number of samples used for validation')

    args = parser.parse_args(argv)
    return args

def get_dataset(args, transform_train, transform_test):
    """Build the (possibly label-noise-corrupted) CIFAR-10 train/test pair.

    Returns:
        (train_set, test_set, clean_labels, noisy_labels, noisy_indexes,
        labelsNoisyOriginal) where the last four describe the injected noise.
    """

    def _corrupt(dataset):
        # Inject label noise in place according to args.noise_type, then
        # snapshot the noisy labels before any later correction touches them.
        if args.noise_type == 'symmetric':
            dataset.random_in_noise()
        elif args.noise_type == 'asymmetric':
            dataset.real_in_noise()
        else:
            print('No noise')
        dataset.labelsNoisyOriginal = dataset.targets.copy()

    if args.validation_exp == 1:
        # Validation experiment: carve a clean validation subset out of the
        # official train split; noise is only injected into the train part.
        temp_dataset = Cifar10Train(args, train=True, transform=transform_train,
                                    target_transform=transform_test, download=args.download)
        train_indexes, val_indexes = train_val_split(args, temp_dataset.targets)
        cifar_train = Cifar10Train(args, train=True, transform=transform_train,
                                   target_transform=transform_test, sample_indexes=train_indexes)
        _corrupt(cifar_train)
        # NOTE: the "test" set here is the held-out clean train subset
        # (train=True on purpose), evaluated with the test-time transform.
        testset = Cifar10Train(args, train=True, transform=transform_test,
                               sample_indexes=val_indexes)
    else:
        # Standard experiment: corrupt the full train split and evaluate on
        # the official CIFAR-10 test split.
        cifar_train = Cifar10Train(args, train=True, transform=transform_train,
                                   target_transform=transform_test, download=args.download)
        _corrupt(cifar_train)
        testset = tv.datasets.CIFAR10(root=args.train_root, train=False,
                                      download=False, transform=transform_test)

    return (cifar_train, testset, cifar_train.clean_labels, cifar_train.noisy_labels,
            cifar_train.noisy_indexes, cifar_train.labelsNoisyOriginal)

def train_val_split(args, train_val):
    """Split training labels into class-balanced train/validation index sets.

    Args:
        args: Namespace with ``seed_dataset`` (RNG seed), ``val_samples``
            (total validation size) and ``num_classes``.
        train_val: Sequence of integer class labels for the full train set.

    Returns:
        (train_indexes, val_indexes): shuffled lists of dataset indices; each
        class contributes exactly ``val_samples // num_classes`` validation
        samples, and the two lists partition all indices.
    """
    np.random.seed(args.seed_dataset)
    labels = np.array(train_val)
    train_indexes = []
    val_indexes = []
    # Per-class validation quota keeps the validation set class-balanced.
    val_num = int(args.val_samples / args.num_classes)

    for class_id in range(args.num_classes):  # renamed from `id` (shadowed builtin)
        indexes = np.where(labels == class_id)[0]
        np.random.shuffle(indexes)
        val_indexes.extend(indexes[:val_num])
        train_indexes.extend(indexes[val_num:])
    np.random.shuffle(train_indexes)
    np.random.shuffle(val_indexes)

    return train_indexes, val_indexes

class Cifar10Train(tv.datasets.CIFAR10):
    """CIFAR-10 train-set wrapper that supports synthetic label noise.

    Adds symmetric (uniform random) and asymmetric (class-confusion) label
    noise on top of torchvision's CIFAR10, and tracks clean/noisy labels,
    soft labels and the indices of corrupted samples so downstream training
    code can evaluate noise-robust methods.

    Fixes vs. the original: ``np.long`` (removed in NumPy >= 1.24) replaced
    with ``np.int64``; ``real_in_noise`` now stores a copy of the targets in
    ``noisy_labels`` (as ``random_in_noise`` already did) instead of an alias.
    """

    def __init__(self, args, train=True, transform=None, target_transform=None, sample_indexes=None, download=False):
        super(Cifar10Train, self).__init__(args.train_root, train=train, transform=transform,
                                            target_transform=target_transform, download=download)
        self.root = os.path.expanduser(args.train_root)
        self.transform = transform
        # NOTE: target_transform is (ab)used as the no-augmentation image
        # transform in __getitem__, not as a label transform.
        self.target_transform = target_transform
        self.train = train  # Training set or validation set

        self.args = args
        if sample_indexes is not None:
            # Restrict the dataset to a subset (e.g. train/val split indices).
            self.data = self.data[sample_indexes]
            self.targets = list(np.asarray(self.targets)[sample_indexes])

        self.num_classes = self.args.num_classes
        # Bookkeeping filled in by the noise-injection methods below.
        self.in_index = []
        self.out_index = []
        self.noisy_indexes = []   # indices whose label was flipped
        self.clean_indexes = []   # indices whose label survived untouched
        self.clean_labels = []    # original (ground-truth) labels
        self.noisy_labels = []    # labels after noise injection
        self.out_data = []
        self.out_labels = []
        self.soft_labels = []
        self.labelsNoisyOriginal = []
        self._num = []
        self._count = 1
        self.prediction = []
        self.confusion_matrix_in = np.array([])
        self.confusion_matrix_out = np.array([])
        self.labeled_idx = []
        self.unlabeled_idx = []

        # One-hot soft labels, and the number of samples to corrupt.
        self.soft_labels = np.zeros((len(self.targets), self.num_classes), dtype=np.float32)
        self._num = int(len(self.targets) * self.args.noise_ratio)

    ################# Random in-distribution noise #########################
    def random_in_noise(self):
        """Inject symmetric noise: flip ``_num`` labels uniformly at random
        to any *other* class. Populates clean/noisy labels and indices and a
        symmetric confusion matrix."""
        np.random.seed(self.args.seed_dataset)
        idxes = np.random.permutation(len(self.targets))
        clean_labels = np.copy(self.targets)
        noisy_indexes = idxes[0:self._num]
        clean_indexes = idxes[self._num:]
        for i in range(len(idxes)):
            if i < self._num:
                self.soft_labels[idxes[i]][self.targets[idxes[i]]] = 0  # Remove soft-label created during label mapping
                # Re-draw until the new label differs from the original one.
                label_sym = np.random.randint(self.num_classes, dtype=np.int32)
                while label_sym == self.targets[idxes[i]]:  # To exclude the original label
                    label_sym = np.random.randint(self.num_classes, dtype=np.int32)
                self.targets[idxes[i]] = label_sym
            self.soft_labels[idxes[i]][self.targets[idxes[i]]] = 1

        # np.long was removed in NumPy 1.24; np.int64 is the equivalent dtype.
        self.targets = np.asarray(self.targets, dtype=np.int64)
        self.noisy_labels = np.copy(self.targets)
        self.noisy_indexes = noisy_indexes
        self.clean_labels = clean_labels
        self.clean_indexes = clean_indexes
        # Uniform off-diagonal transition matrix with 1-noise_ratio on the diagonal.
        self.confusion_matrix_in = (np.ones((self.args.num_classes, self.args.num_classes)) - np.identity(self.args.num_classes))\
                                    *(self.args.noise_ratio/(self.num_classes -1)) + \
                                    np.identity(self.args.num_classes)*(1 - self.args.noise_ratio)

    ##########################################################################

    ################# Real in-distribution noise #########################

    def real_in_noise(self):
        """Inject asymmetric noise: flip labels along semantically plausible
        CIFAR-10 confusions (truck->automobile, bird->airplane, cat<->dog,
        deer->horse), each with probability ``noise_ratio``."""
        np.random.seed(self.args.seed_dataset)

        ##### Create the confusion matrix #####
        self.confusion_matrix_in = np.identity(self.args.num_classes)

        # truck -> automobile
        self.confusion_matrix_in[9, 9] = 1 - self.args.noise_ratio
        self.confusion_matrix_in[9, 1] = self.args.noise_ratio

        # bird -> airplane
        self.confusion_matrix_in[2, 2] = 1 - self.args.noise_ratio
        self.confusion_matrix_in[2, 0] = self.args.noise_ratio

        # cat -> dog
        self.confusion_matrix_in[3, 3] = 1 - self.args.noise_ratio
        self.confusion_matrix_in[3, 5] = self.args.noise_ratio

        # dog -> cat
        self.confusion_matrix_in[5, 5] = 1 - self.args.noise_ratio
        self.confusion_matrix_in[5, 3] = self.args.noise_ratio

        # deer -> horse
        self.confusion_matrix_in[4, 4] = 1 - self.args.noise_ratio
        self.confusion_matrix_in[4, 7] = self.args.noise_ratio

        idxes = np.random.permutation(len(self.targets))
        clean_labels = np.copy(self.targets)

        for i in range(len(idxes)):
            self.soft_labels[idxes[i]][self.targets[idxes[i]]] = 0  # Remove soft-label created during label mapping
            current_label = self.targets[idxes[i]]
            if self._num > 0:
                # Sample the (possibly new) label from this class's
                # transition-probability row.
                conf_vec = self.confusion_matrix_in[current_label, :]
                label_sym = np.random.choice(np.arange(0, self.num_classes), p=conf_vec.transpose())
                self.targets[idxes[i]] = label_sym
            else:
                label_sym = current_label

            self.soft_labels[idxes[i]][self.targets[idxes[i]]] = 1

            if label_sym == current_label:
                self.clean_indexes.append(idxes[i])
            else:
                self.noisy_indexes.append(idxes[i])

        # np.long was removed in NumPy 1.24; np.int64 is the equivalent dtype.
        self.targets = np.asarray(self.targets, dtype=np.int64)
        self.clean_indexes = np.asarray(self.clean_indexes, dtype=np.int64)
        self.noisy_indexes = np.asarray(self.noisy_indexes, dtype=np.int64)
        # Copy (not alias) for consistency with random_in_noise.
        self.noisy_labels = np.copy(self.targets)
        self.clean_labels = clean_labels

    def __getitem__(self, index):
        # Heuristic: a full (noise-injected) train set is larger than
        # val_samples and returns the rich tuple; the small held-out
        # validation subset returns a plain (img, label) pair.
        # NOTE(review): presumably labelsNoisyOriginal/clean_labels are set
        # before the first fetch on the large path — confirm against callers.
        if len(self.data) > self.args.val_samples:
            img, labels, soft_labels, noisy_labels, clean_labels = self.data[index], self.targets[index], self.soft_labels[
                index], self.labelsNoisyOriginal[index], self.clean_labels[index]
            # doing this so that it is consistent with all other datasets.
            img = Image.fromarray(img)

            # target_transform doubles as the non-augmented view.
            img_noDA = self.target_transform(img)

            img1 = self.transform(img)

            if self.args.method == "MOIT":
                # MOIT needs two augmented views for contrastive learning
                # at train time; a dummy 0 stands in at eval time.
                if self.train is True:
                    img2 = self.transform(img)
                else:
                    img2 = 0

                return img1, img2, img_noDA, labels, soft_labels, index, noisy_labels, clean_labels

            else:
                return img1, img_noDA, labels, soft_labels, index, noisy_labels, clean_labels

        else:
            img, labels = self.data[index], self.targets[index]
            # doing this so that it is consistent with all other datasets.
            img = Image.fromarray(img)

            img = self.transform(img)

            return img, labels

This file provides a utility class that creates noisy labels (symmetric and asymmetric) for CIFAR-10.

It also customizes the Dataset so that `__getitem__` returns exactly what training needs:

return img1, img2, img_noDA, labels, soft_labels, index, noisy_labels, clean_labels
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
首先,我们需要导入必要的库和数据集: ```python import torch import torchvision import torchvision.transforms as transforms # CIFAR10数据集 transform = transforms.Compose( [transforms.ToTensor()]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) ``` 接下来,我们可以添一些椒盐噪声到图像中: ```python import numpy as np def add_salt_and_pepper(img, noise_ratio=0.05): """ 添椒盐噪声 """ img = np.array(img) h, w, c = img.shape mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[noise_ratio, noise_ratio, 1-2*noise_ratio]) mask = np.repeat(mask, c, axis=2) img[mask == 0] = 0 img[mask == 1] = 255 return img ``` 然后,我们可以使用 PyTorch 中的卷积神经网络(CNN)来实现图像去噪: ```python import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x net = Net() ``` 接下来,我们需要定义损失函数和优化器: ```python import torch.optim as optim criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) ``` 最后,我们可以训练模型并测试: ```python for epoch in range(2): # 进行两轮训练 running_loss = 0.0 for i, data in enumerate(trainloader, 0): # 获取输入 inputs, labels = data # 添椒盐噪声 inputs = [add_salt_and_pepper(img) for img in inputs] inputs = torch.from_numpy(np.array(inputs)).float() labels = torch.from_numpy(np.array(labels)).long() # 梯度清零 optimizer.zero_grad() # 
正向传播,反向传播,优化 outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 打印统计信息 running_loss += loss.item() if i % 2000 == 1999: # 每2000个小批量数据打印一次 print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0 print('Finished Training') # 测试 correct = 0 total = 0 with torch.no_grad(): for data in testloader: images, labels = data # 添椒盐噪声 inputs = [add_salt_and_pepper(img) for img in images] inputs = torch.from_numpy(np.array(inputs)).float() labels = torch.from_numpy(np.array(labels)).long() outputs = net(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the 10000 test images: %d %%' % ( 100 * correct / total)) ``` 这样,我们就可以使用 CIFAR10 数据集实现椒盐噪声的图像去噪了。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值