项目代码:
https://github.com/denguir/student-teacher-anomaly-detection
其实也可以直接在大图上随机裁剪(crop)一块区域,然后再从中裁剪出 65×65 的 patch;
但这里我们选择先把大图划分成 4 块区域,再分别进行处理。
# The data directory is resolved relative to the parent of the current
# working directory.
before_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
print("before_dir", before_dir)

# Locations of the CSV index and of the image folder for the chosen dataset.
csv_path = os.path.join(before_dir, 'data/{}/{}.csv'.format(DATASET, DATASET))
img_dir = os.path.join(before_dir, 'data/{}/img'.format(DATASET))

# Augmentation pipeline: random flips, tensor conversion, then per-channel
# normalisation with mean 0.5 and std 0.5.
train_transform = transforms.Compose([
    # transforms.Grayscale(num_output_channels=3),
    # transforms.Resize((imH, imW)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# `type` and `label` are passed as extra keyword arguments; they select the
# training rows whose label is 0.
dataset = AnomalyDataset(
    csv_file=csv_path,
    root_dir=img_dir,
    transform=train_transform,
    type='train',
    label=0,
)
import os
import numpy as np
import pandas as pd
import torch
from PIL import Image
from einops import rearrange
from torchvision import transforms, utils
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import cv2
def cut_image(image, rows=2, cols=2):
    """Split an image into a ``rows`` x ``cols`` grid of equally sized tiles.

    Tiles are returned in row-major order (left to right, then top to
    bottom). The tile size is the floor of the image size divided by the
    grid size, so when a dimension is not evenly divisible the leftover
    pixels on the right/bottom edge are not part of any tile.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, upper, right, lower))`` method, e.g. a PIL image.
        rows: Number of tiles along the vertical axis. Defaults to 2.
        cols: Number of tiles along the horizontal axis. Defaults to 2.

    Returns:
        A list of ``rows * cols`` crops, as returned by ``image.crop``.

    Raises:
        ValueError: If ``rows`` or ``cols`` is smaller than 1.
    """
    if rows < 1 or cols < 1:
        raise ValueError(
            "rows and cols must be >= 1, got rows={}, cols={}".format(rows, cols))

    width, height = image.size
    tile_width = width // cols
    tile_height = height // rows

    # Each box is (left, upper, right, lower) in source-image coordinates.
    boxes = [
        (col * tile_width, row * tile_height,
         (col + 1) * tile_width, (row + 1) * tile_height)
        for row in range(rows)
        for col in range(cols)
    ]
    return [image.crop(box) for box in boxes]
class AnomalyDataset(Dataset):
    """Anomaly detection dataset that serves every image as a list of tiles.

    Rows are read from a CSV index and filtered by the keyword constraints
    given to the constructor. Each item is loaded from ``root_dir``, resized
    to ``IM_HEIGHT`` x ``IM_WIDTH``, split with ``cut_image`` and each tile
    is passed through ``transform`` when one is provided.
    """

    # Size every image is resized to before it is split into tiles.
    IM_HEIGHT = 576
    IM_WIDTH = 768

    def __init__(self, csv_file, root_dir, transform=None, **constraint):
        """Build the dataset.

        Args:
            csv_file: Path of the CSV index; ``__getitem__`` reads its
                ``image_name`` and ``label`` columns.
            root_dir: Directory that ``image_name`` entries are joined to.
            transform: Optional callable applied to every tile.
            **constraint: ``column=value`` filters; only rows matching all
                of them are kept.
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.frame_list = self._get_dataset(csv_file, constraint)
        self.resize = transforms.Compose(
            [transforms.Resize((self.IM_HEIGHT, self.IM_WIDTH))])

    def _get_dataset(self, csv_file, constraint):
        """Read ``csv_file`` and keep the rows matching every constraint.

        Args:
            csv_file: Path or file-like object accepted by ``pd.read_csv``.
            constraint: Dict mapping column names to required values.

        Returns:
            The filtered ``DataFrame``; all rows when ``constraint`` is empty.
        """
        df = pd.read_csv(csv_file)
        if not constraint:
            # Nothing to filter on; skip the empty-Series comparison.
            return df
        matches = (df[list(constraint)] == pd.Series(constraint)).all(axis=1)
        return df.loc[matches]

    def __len__(self):
        """Return the number of rows left after filtering."""
        return len(self.frame_list)

    def __getitem__(self, idx):
        """Load one image and return its tiles with the row label.

        Args:
            idx: Positional row index (a tensor index is converted first).

        Returns:
            Dict with ``'image'`` (list of tiles, transformed when a
            transform was given) and ``'label'`` (the row label repeated
            once per tile).

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.frame_list.iloc[idx]
        img_name = os.path.join(self.root_dir, row['image_name'])
        label = row['label']

        # Flag -1 (IMREAD_UNCHANGED) loads the file as stored.
        image_array = cv2.imread(img_name, -1)
        if image_array is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(
                "Could not read image: {}".format(img_name))

        # NOTE(review): OpenCV loads colour images in BGR order while
        # Image.fromarray treats a 3-channel array as RGB, so the channels
        # look swapped here - confirm whether trained models rely on this
        # before changing it.
        image = Image.fromarray(image_array.astype('uint8')).convert('RGB')
        image = self.resize(image)

        tiles = cut_image(image)
        if self.transform is not None:
            tiles = [self.transform(tile) for tile in tiles]

        return {'image': tiles, 'label': [label] * len(tiles)}
if __name__ == '__main__':
    # Smoke test: load a few batches and display one random tile.
    import matplotlib.pyplot as plt
    import sys

    DATASET = "mydata"
    dataset = AnomalyDataset(csv_file=f'../data/{DATASET}/{DATASET}.csv',
                             root_dir=f'../data/{DATASET}/img',
                             transform=transforms.Compose([
                                 #transforms.Grayscale(num_output_channels=3),
                                 transforms.Resize((256, 256)),
                                 transforms.RandomCrop((256, 256)),
                                 transforms.ToTensor()]),
                             type='train',
                             label=0)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    for i, batch in enumerate(dataloader):
        # Every sample holds a list of tiles, so the default collate function
        # returns lists with one batched tensor per tile position instead of
        # a single tensor; sizes therefore have to be read per tile.
        print(i,
              [tile_batch.size() for tile_batch in batch['image']],
              [tile_labels.size() for tile_labels in batch['label']])

        # display the 4th batch (i == 3)
        if i == 3:
            tile_idx = np.random.randint(0, len(batch['image']))
            n = np.random.randint(0, len(batch['label'][tile_idx]))
            image = rearrange(batch['image'][tile_idx][n], 'c h w -> h w c')
            label = batch['label'][tile_idx][n]

            plt.title(f"Sample #{n} - {'Anomalous' if label else 'Normal'}")
            plt.imshow(image)
            plt.show()
            break
for i, batch in tqdm(enumerate(dataloader)):
    # batch['image'] is indexed per tile: each entry is moved to the device
    # and used as one optimisation step.
    # inputs = batch['image'].to(device)
    for tile_batch in batch['image']:
        inputs = tile_batch.to(device)

        # zero the parameters gradient before every step, so gradients of
        # the previous tile are not accumulated into this one
        optimizers[j].zero_grad()

        # forward pass; the normalised teacher output is the regression
        # target and needs no gradient
        with torch.no_grad():
            targets = (teacher(inputs) - t_mu) / torch.sqrt(t_var)
        outputs = student(inputs)
        loss = student_loss(targets, outputs)

        # backward pass
        loss.backward()
        optimizers[j].step()
        running_loss += loss.item()