"""
Augment training data with semi-supervised pseudo-labeling during CNN training.

Input:  a CNN model, a dataset with unreliable labels, and a filtering threshold.
Output: a pseudo-labeled dataset containing only samples that pass the threshold.

Approach:
- Collect the indices of samples that satisfy the filtering condition, and use
  them to build the filtered dataset together with its pseudo labels.
- Define a custom Dataset class for the result.
"""
import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset, DataLoader, Subset,Dataset
'''
Define a custom Dataset class
implementing the __init__() and __getitem__() methods.
'''
# inherit Dataset
class pseudo_dataset(Dataset):
    """Dataset pairing a confidence-filtered subset of samples with pseudo labels.

    Args:
        unlabeled_set: source dataset; items are (image, label) pairs where the
            original label is unreliable and ignored.
        indices: positions in ``unlabeled_set`` that passed the confidence filter.
        pseudo_labels: predicted labels for the FULL unlabeled set; the entries
            belonging to the filtered subset are selected via ``indices``.
    """

    def __init__(self, unlabeled_set, indices, pseudo_labels):
        # Keep only the samples that passed the filter.
        self.data = Subset(unlabeled_set, indices)
        # Align the pseudo labels with the filtered subset.
        self.target = torch.LongTensor(pseudo_labels)[indices]

    def __len__(self):
        # Required by DataLoader; also used by the bounds checks in __getitem__.
        # (The original class omitted this, so every len(self) call raised.)
        return len(self.data)

    def __getitem__(self, index):
        # Support negative indexing like a regular sequence.
        if index < 0:
            index += len(self)
        # Lower bound too: otherwise a very negative index would silently
        # wrap around a second time via Python's negative indexing.
        if index < 0 or index >= len(self):
            raise IndexError("index %d is out of bounds for axis 0 with size %d" % (index, len(self)))
        x = self.data[index][0]        # drop the unreliable original label
        y = self.target[index].item()  # plain int, matching the usual (x, y) pair
        return x, y
def get_pseudo_labels(dataset, model, threshold=0.65, batch_size=64):
    """Generate a pseudo-labeled dataset from ``dataset`` using ``model``.

    Runs the model over the whole (unreliably labeled) dataset, and keeps only
    the samples whose maximum softmax confidence exceeds ``threshold``.
    You are NOT allowed to use any models trained on external data for
    pseudo-labeling.

    Args:
        dataset: the unlabeled/unreliable dataset to pseudo-label.
        model: the CNN classifier used to predict labels.
        threshold: minimum softmax confidence for a sample to be kept.
        batch_size: inference batch size (was an undefined global before).

    Returns:
        A ``pseudo_dataset`` instance containing the filtered samples and
        their predicted labels.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Construct a data loader; shuffle=False keeps predictions aligned with
    # the original sample order, which the index filtering below relies on.
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # Make sure the model is in eval mode.
    model.eval()
    # Define softmax function.
    softmax = nn.Softmax(dim=-1)
    masks = []        # per-batch boolean tensors: confidence > threshold
    pred_labels = []  # predicted class index for every sample, in order
    # Iterate over the dataset by batches. (tqdm was removed: it was never
    # imported in this file and raised a NameError.)
    for img, _ in data_loader:
        # torch.no_grad() accelerates the forward pass during inference.
        with torch.no_grad():
            logits = model(img.to(device))
        # Probability distributions over classes.
        probs = softmax(logits)
        pred_labels.extend(probs.argmax(dim=-1).tolist())
        # Keep one boolean tensor per batch; concatenated below.
        masks.append(torch.max(probs, dim=1)[0] > threshold)
    # Concatenate batch masks into one full-length mask (empty dataset → empty mask).
    keep = torch.cat(masks) if masks else torch.zeros(0, dtype=torch.bool)
    # Move to CPU before indexing: the mask lives on `device`, arange on CPU.
    indices = torch.arange(len(dataset))[keep.cpu()]
    # Use a distinct local name; the original rebinding of `pseudo_dataset`
    # shadowed the class, and `pseudo_labels` here was an undefined name.
    filtered = pseudo_dataset(dataset, indices, pred_labels)
    if len(dataset) > 0:
        print('using {0:.2f}% unlabeled data'.format(100 * len(filtered) / len(dataset)))
    # Turn off eval mode (restore training mode for the caller).
    model.train()
    return filtered