PyTorch 项目 Student-Teacher Anomaly Detection:修改训练方式,将大图划分成多块载入

项目代码:

https://github.com/denguir/student-teacher-anomaly-detection

其实也可以直接从大图中随机裁剪(crop)出一块区域,再从该区域中裁剪 65×65 的 patch。

但这里我们选择把大图均分成 2×2 共 4 块区域,分别作为训练输入。

# Data lives under the parent of the current working directory.
before_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
print("before_dir", before_dir)

# Augmentation applied independently to every tile. No Resize is needed here:
# AnomalyDataset resizes each full image itself before cutting it into tiles.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = AnomalyDataset(
    csv_file=os.path.join(before_dir, 'data/{}/{}.csv'.format(DATASET, DATASET)),
    root_dir=os.path.join(before_dir, 'data/{}/img'.format(DATASET)),
    transform=train_transform,
    type='train',
    label=0,
)
import os
import numpy as np
import pandas as pd
import torch
from PIL import Image
from einops import rearrange
from torchvision import transforms, utils
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import cv2

def cut_image(image, rows=2, cols=2):
    """Split an image into a ``rows`` x ``cols`` grid of equally sized tiles.

    Tiles are returned in row-major order (left to right, then top to
    bottom); with the default 2x2 grid that is top-left, top-right,
    bottom-left, bottom-right.

    Args:
        image: object exposing ``size`` as ``(width, height)`` and
            ``crop((left, upper, right, lower))``, e.g. a ``PIL.Image.Image``.
        rows: number of tile rows (default 2).
        cols: number of tile columns (default 2).

    Returns:
        A list of ``rows * cols`` cropped images. If the width/height is not
        divisible by ``cols``/``rows``, the leftover right/bottom pixels are
        dropped so every tile has the same size (needed for batching).

    Raises:
        ValueError: if ``rows`` or ``cols`` is less than 1.
    """
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be >= 1, got rows={rows}, cols={cols}")
    width, height = image.size
    tile_w = width // cols
    tile_h = height // rows
    # PIL crop boxes are (left, upper, right, lower).
    return [
        image.crop((c * tile_w, r * tile_h, (c + 1) * tile_w, (r + 1) * tile_h))
        for r in range(rows)
        for c in range(cols)
    ]

class AnomalyDataset(Dataset):
    """Anomaly-detection dataset that yields each image as a grid of tiles.

    Rows are read from a CSV (the ``image_name`` and ``label`` columns are
    used here) and optionally filtered by ``**constraint`` column=value pairs,
    e.g. ``type='train', label=0``. Each image is resized to ``image_size``,
    split into tiles by ``cut_image``, and ``transform`` (if given) is applied
    to every tile independently.

    ``__getitem__`` returns ``{'image': [tile, ...], 'label': [label, ...]}``.
    With PyTorch's default collate function, ``batch['image']`` is therefore a
    list holding one batched tensor per tile position.
    """

    # (height, width) given to transforms.Resize before tiling; override in a
    # subclass to change the working resolution.
    image_size = (576, 768)

    def __init__(self, csv_file, root_dir, transform=None, **constraint):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.frame_list = self._get_dataset(csv_file, constraint)
        self.resize = transforms.Resize(self.image_size)

    def _get_dataset(self, csv_file, constraint):
        """Load the CSV and keep only rows whose columns equal every value in ``constraint``."""
        df = pd.read_csv(csv_file)
        if not constraint:
            return df
        # Comparing a DataFrame with a Series aligns the Series index to columns.
        mask = (df[list(constraint)] == pd.Series(constraint)).all(axis=1)
        return df.loc[mask]

    @staticmethod
    def _bgr_to_rgb(array):
        """Reorder OpenCV's BGR/BGRA channel layout to RGB/RGBA; other arrays pass through."""
        if array.ndim == 3 and array.shape[2] == 3:
            return array[..., [2, 1, 0]]
        if array.ndim == 3 and array.shape[2] == 4:
            return array[..., [2, 1, 0, 3]]
        return array

    def __len__(self):
        return len(self.frame_list)

    def __getitem__(self, idx):
        """Load image ``idx`` and return its tiles, each paired with the image's label.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.frame_list.iloc[idx]
        img_name = os.path.join(self.root_dir, row['image_name'])
        label = row['label']

        # -1 == cv2.IMREAD_UNCHANGED: keep the file's own channel count and depth.
        image_array = cv2.imread(img_name, -1)
        if image_array is None:
            raise FileNotFoundError(f"cv2 could not read image: {img_name}")

        # NOTE(review): astype('uint8') wraps rather than rescales 16-bit
        # pixel values — confirm the inputs are 8-bit.
        # OpenCV stores colour images as BGR(A); PIL expects RGB(A).
        image_array = self._bgr_to_rgb(image_array.astype('uint8'))
        image = Image.fromarray(image_array).convert('RGB')
        image = self.resize(image)

        sample = {'image': [], 'label': []}
        for tile in cut_image(image):
            if self.transform is not None:
                tile = self.transform(tile)
            sample['image'].append(tile)
            sample['label'].append(label)
        return sample


def train_student_on_tiles(student, teacher, dataloader, optimizer, student_loss,
                           t_mu, t_var, device):
    """Run one training epoch of ``student`` over tiled batches.

    Each batch from ``AnomalyDataset`` carries ``batch['image']`` as a list of
    per-tile tensors. Every tile gets its own optimisation step, and gradients
    are zeroed before each step so one tile's gradient never leaks into the
    next tile's update.

    Args:
        student: trainable module whose outputs should match the teacher's.
        teacher: frozen module; evaluated under ``torch.no_grad()``.
        dataloader: iterable of batches shaped like ``AnomalyDataset`` output.
        optimizer: optimizer over ``student``'s parameters.
        student_loss: callable ``(targets, outputs) -> scalar tensor``.
        t_mu, t_var: statistics used to normalise teacher outputs
            (presumably mean/variance of teacher features — confirm at the caller).
        device: device the tiles are moved to.

    Returns:
        The sum of ``loss.item()`` over every tile of every batch.
    """
    running_loss = 0.0
    for batch in dataloader:
        for tile in batch['image']:
            inputs = tile.to(device)
            optimizer.zero_grad()

            with torch.no_grad():
                targets = (teacher(inputs) - t_mu) / torch.sqrt(t_var)
            outputs = student(inputs)
            loss = student_loss(targets, outputs)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
    return running_loss


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    DATASET = "mydata"
    dataset = AnomalyDataset(csv_file=f'../data/{DATASET}/{DATASET}.csv',
                             root_dir=f'../data/{DATASET}/img',
                             transform=transforms.Compose([
                                 transforms.Resize((256, 256)),
                                 transforms.RandomCrop((256, 256)),
                                 transforms.ToTensor()]),
                             type='train',
                             label=0)

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    for i, batch in enumerate(dataloader):
        # batch['image'] / batch['label'] are lists: one batched tensor per tile.
        print(i,
              [tile.size() for tile in batch['image']],
              [lab.size() for lab in batch['label']])
        # display one random tile of one random sample from the 4th batch
        if i == 3:
            tile_idx = np.random.randint(0, len(batch['image']))
            tiles = batch['image'][tile_idx]
            labels = batch['label'][tile_idx]
            n = np.random.randint(0, len(labels))

            image = rearrange(tiles[n, :, :, :], 'c h w -> h w c')
            label = labels[n]

            plt.title(f"Sample #{n} - {'Anomalous' if label else 'Normal'}")
            plt.imshow(image)
            plt.show()
            break

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值