deeplabv3plus
Image processing code
import os
from PIL import Image
from configs import cfg
import numpy as np
import cv2
import random
# Generate training tiles by cropping the large images
def gendata_pro(unit_sizes=[2048, 1024, 512], begin_ind=0, randis=1, stride_rato=0.5, balance=0, mxnm=10000):
    save_img_dir = os.path.join(cfg.DATA_DIR, 'image')
    save_mask_dir = os.path.join(cfg.DATA_DIR, 'mask')
    Image.MAX_IMAGE_PIXELS = 5000000000
    img = Image.open(cfg.DATA_DIR + 'jingwei_round1_train_20190619/image_2.png')
    img = np.asarray(img)
    print(img.shape)
    anno_map = Image.open(cfg.DATA_DIR + 'jingwei_round1_train_20190619/image_2_label.png')
    anno_map = np.asarray(anno_map)
    print(anno_map.shape)
    length, width = img.shape[0], img.shape[1]
    for unit_size in unit_sizes:
        nullthresh = unit_size * unit_size * 0.7   # skip tiles that are mostly blank
        maxthresh = unit_size * unit_size * 0.3    # a class must cover 30% of the tile to count for balancing
        count_cls = np.zeros(4)
        ind = 0
        if not os.path.exists(save_img_dir + str(unit_size)):
            os.makedirs(save_img_dir + str(unit_size))
        if not os.path.exists(save_mask_dir + str(unit_size)):
            os.makedirs(save_mask_dir + str(unit_size))

        def save_img(x1, x2, y1, y2):
            nonlocal ind  # ind lives in gendata_pro's scope, not at module level
            im = img[x1:x2, y1:y2, :]
            if (im[:, :, 0] == 0).sum() > nullthresh:
                return 0
            save_im = np.array(im[:, :, 0:3])
            save_mask = np.array(anno_map[x1:x2, y1:y2])
            if balance:
                num_cls = np.array([(save_mask == p).sum() for p in range(4)])
                if count_cls[0] < mxnm and num_cls[0] > maxthresh:
                    count_cls[0] += 1
                elif count_cls[1] < mxnm and num_cls[1] > maxthresh:
                    count_cls[1] += 1
                elif count_cls[2] < mxnm and num_cls[2] > maxthresh:
                    count_cls[2] += 1
                elif count_cls[3] < mxnm and num_cls[3] > maxthresh:
                    count_cls[3] += 1
                else:
                    if count_cls.sum() == mxnm * 4:
                        return 1  # every class has enough tiles, stop
                    else:
                        return 0
            bd = ind + begin_ind
            ind = ind + 1
            cv2.imwrite(save_img_dir + str(unit_size) + '/%06d.jpg' % bd, save_im)
            cv2.imwrite(save_mask_dir + str(unit_size) + '/%06d.png' % bd, save_mask)
            return 0
        if randis:
            # random crops: roughly (200000 / unit_size)^2 samples per unit size
            randnum = 200000 * 1.0 / unit_size
            randnum = randnum * randnum
            print(unit_size, randnum)
            for i in range(int(randnum)):
                x1, y1 = random.randint(0, length), random.randint(0, width)
                x2, y2 = x1 + unit_size, y1 + unit_size
                if x2 > length:
                    x2, x1 = length, length - unit_size
                if y2 > width:
                    y2, y1 = width, width - unit_size
                if save_img(x1, x2, y1, y2) == 1:
                    return
        else:
            # sliding-window crops with stride_rato overlap
            x1 = 0
            while x1 < length:
                x2 = x1 + unit_size
                if x2 > length:
                    x2, x1 = length, length - unit_size
                y1 = 0
                while y1 < width:
                    y2 = y1 + unit_size
                    if y2 > width:
                        y2, y1 = width, width - unit_size
                    if save_img(x1, x2, y1, y2) == 1:
                        return
                    if y2 == width:   # reached the right edge; stop before the clamp above repeats forever
                        break
                    y1 += int(unit_size * stride_rato)
                if x2 == length:      # reached the bottom edge
                    break
                x1 += int(unit_size * stride_rato)
# Generate train/test list files
def genlabel(unit_sizes=[2048, 1024, 512]):
    save_img_dir = os.path.join(cfg.DATA_DIR, 'image')
    save_mask_dir = os.path.join(cfg.DATA_DIR, 'mask')
    split_rato = 0.8
    for unit_size in unit_sizes:
        train_txt = open(cfg.DATA_DIR + 'train' + str(unit_size) + '.txt', 'w+')
        test_txt = open(cfg.DATA_DIR + 'test' + str(unit_size) + '.txt', 'w+')
        img_list = os.listdir(save_img_dir + str(unit_size))
        mask_list = os.listdir(save_mask_dir + str(unit_size))
        id_list = [img_name.split('.')[0] for img_name in img_list]
        for id in id_list:
            s = random.random()
            if id + '.jpg' not in img_list:
                raise ValueError('label gen error')
            if id + '.png' not in mask_list:
                raise ValueError('label gen error')
            if s > split_rato:
                test_txt.write('image' + str(unit_size) + '/' + id + '.jpg mask' + str(unit_size) + '/' + id + '.png\n')
            else:
                train_txt.write('image' + str(unit_size) + '/' + id + '.jpg mask' + str(unit_size) + '/' + id + '.png\n')
        train_txt.close()
        test_txt.close()
# Generate the full-size prediction mask for submission
def genbig(unit_sizes=[512], stride_rato=0.5):
    from test import *
    img = cv2.imread(cfg.DATA_DIR + 'jingwei_round1_test_a_20190619/image_4.png')
    length, width = img.shape[0], img.shape[1]
    mask = np.zeros((length, width))
    print(img.shape)
    for unit_size in unit_sizes:
        x1 = 0
        while x1 < length:
            x2 = x1 + unit_size
            if x2 > length:
                x2, x1 = length, length - unit_size
            y1 = 0
            while y1 < width:
                y2 = y1 + unit_size
                if y2 > width:
                    y2, y1 = width, width - unit_size
                im = img[x1:x2, y1:y2, :]
                if im.max() > 0:  # only tiles that contain data are worth predicting
                    pass          # inference is kept commented out here
                    # cv2.imwrite("tmp.jpg", im)
                    # result = pred("tmp.jpg")
                    # mask[x1:x2, y1:y2] = result
                    # mix_img = maskAddImg(im, result)
                    # cv2.imshow('mix_img', mix_img)
                    # cv2.waitKey(500)
                if y2 == width:   # reached the right edge; stop before the clamp above repeats forever
                    break
                y1 += int(unit_size * stride_rato)
            if x2 == length:
                break
            x1 += int(unit_size * stride_rato)
    cv2.imwrite("mask.png", mask)
# Overlay the class mask on an image for visualization
def maskAddImg(img, mask):
    mask_red = np.zeros_like(mask)
    mask_green = np.zeros_like(mask)
    mask_blue = np.zeros_like(mask)
    mask_red[mask == 1] = 1
    mask_green[mask == 2] = 1
    mask_blue[mask == 3] = 1
    mask_img_n = np.stack((mask_red, mask_green, mask_blue), axis=2)
    mix_img = cv2.addWeighted(img, 0.5, mask_img_n * 255, 0.5, 1)
    return mix_img
genbig()
#genlabel()
#gendata_pro()
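Both gendata_pro and genbig walk the big image with the same overlapping window pattern (stride = unit_size * stride_rato, clamped at the borders). A minimal standalone sketch of that pattern; the sliding_windows helper is hypothetical and not part of the code above:

def sliding_windows(length, width, unit_size, stride_rato=0.5):
    # Yield (x1, x2, y1, y2) crops covering a length x width image with overlap.
    stride = int(unit_size * stride_rato)
    x1 = 0
    while x1 < length:
        x2 = min(x1 + unit_size, length)
        x1c = x2 - unit_size              # clamp the last row of windows to the border
        y1 = 0
        while y1 < width:
            y2 = min(y1 + unit_size, width)
            y1c = y2 - unit_size
            yield x1c, x2, y1c, y2
            if y2 == width:
                break
            y1 += stride
        if x2 == length:
            break
        x1 += stride

# Example: tile a 1200 x 900 image with 512-pixel windows and 50% overlap.
# for x1, x2, y1, y2 in sliding_windows(1200, 900, 512):
#     print(x1, x2, y1, y2)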
Config code
import torch
import argparse
import os
import sys
import cv2
import time
class Configuration():
    def __init__(self):
        self.ROOT_DIR = './'
        self.EXP_NAME = 'deeplabv3+tianchi'
        self.DATA_DIR = "D:/lengxia/code/deeplabv3plus/data/"
        self.TXT_LIST = ["test512.txt", "test1024.txt"]
        self.DATA_NAME = 'tianchi'
        self.DATA_AUG = False
        self.DATA_WORKERS = 4
        self.DATA_RESCALE = 512
        self.DATA_RANDOMCROP = 0
        self.DATA_RANDOMROTATION = 0
        self.DATA_RANDOMSCALE = 1
        self.DATA_RANDOM_H = 10
        self.DATA_RANDOM_S = 10
        self.DATA_RANDOM_V = 10
        self.DATA_RANDOMFLIP = 0.5
        self.DATA_SPLIT = 6
        self.MODEL_NAME = 'deeplabv3plus'
        self.MODEL_BACKBONE = 'res101_atrous'
        self.MODEL_OUTPUT_STRIDE = 16
        self.MODEL_ASPP_OUTDIM = 256
        self.MODEL_SHORTCUT_DIM = 48
        self.MODEL_SHORTCUT_KERNEL = 1
        self.MODEL_NUM_CLASSES = 4
        self.MODEL_SAVE_DIR = os.path.join(self.ROOT_DIR, 'model', self.EXP_NAME)
        self.TRAIN_LR = 0.0007
        self.TRAIN_LR_GAMMA = 0.1
        self.TRAIN_MOMENTUM = 0.9
        self.TRAIN_WEIGHT_DECAY = 0.00004
        self.TRAIN_BN_MOM = 0.0003
        self.TRAIN_POWER = 0.9
        self.TRAIN_GPUS = 1
        self.GPUS_ID = [0]
        self.TRAIN_BATCHES = 8
        self.TRAIN_SHUFFLE = True
        self.TRAIN_MINEPOCH = 0
        self.TRAIN_EPOCHS = 60
        self.TRAIN_LOSS_LAMBDA = 0
        self.TRAIN_TBLOG = True
        self.TRAIN_CKPT = None  # os.path.join(self.ROOT_DIR,'model/deeplabv3+voc/deeplabv3plus_res101_atrous_VOC2012_epoch46_all.pth')
        self.LOG_DIR = os.path.join(self.ROOT_DIR, 'log', self.EXP_NAME)
        self.TEST_MULTISCALE = [1.0]  # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
        self.TEST_FLIP = False  # True
        self.TEST_CKPT = os.path.join(self.ROOT_DIR, 'model/deeplabv3+tianchi/deeplabv3plus_res101_atrous_tianchi_itr15000_256_2.pth')
        self.TEST_GPUS = 1
        self.TEST_BATCHES = 1
        self.__check()
        # self.__add_path(os.path.join(self.ROOT_DIR, 'lib'))
    def __check(self):
        if not torch.cuda.is_available():
            raise ValueError('configs.py: cuda is not available')
        if self.TRAIN_GPUS == 0:
            raise ValueError('configs.py: the number of GPU is 0')
        if not os.path.isdir(self.LOG_DIR):
            os.makedirs(self.LOG_DIR)
        if not os.path.isdir(self.MODEL_SAVE_DIR):
            os.makedirs(self.MODEL_SAVE_DIR)

    def __add_path(self, path):
        if path not in sys.path:
            sys.path.insert(0, path)
cfg = Configuration()
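The files named in TXT_LIST are the lists written by genlabel above: one line per sample, image path and mask path relative to DATA_DIR, separated by a space. A small sketch of that format (the file and sample names are illustrative):

# e.g. one line of test512.txt, as written by genlabel:
line = 'image512/000001.jpg mask512/000001.png'
img_rel, mask_rel = line.split()[0], line.split()[1]
img_path = cfg.DATA_DIR + img_rel     # this is how TianchiDataset.__getitem__ builds the full path
mask_path = cfg.DATA_DIR + mask_rel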
Training code
# ----------------------------------------
# Written by Yude Wang
# ----------------------------------------
import torch
import torch.nn as nn
import os
import numpy as np
from configs import cfg
from lib.net.generateNet import generate_net
import torch.optim as optim
from torch.utils.data import DataLoader
from lib.net.loss import MaskCrossEntropyLoss, MaskBCELoss, MaskBCEWithLogitsLoss
from lib.net.sync_batchnorm.replicate import patch_replication_callback
import cv2
from TianchiDataset import TianchiDataset
def collate_fn(batch):
    images = []
    seg = []
    cs = []
    rs = []
    names = []
    for _, sample in enumerate(batch):
        images.append(sample['image'])
        seg.append(sample['segmentation'])
        rs.append(sample['row'])
        cs.append(sample['col'])
        names.append(sample['name'])
        # print(sample['image'].shape)
    return {
        'image': torch.stack(images, 0),
        'segmentation': torch.stack(seg, 0),
    }
def train_net():
    dataset = TianchiDataset(cfg, 'train')
    dataloader = DataLoader(dataset,
                            batch_size=cfg.TRAIN_BATCHES,
                            shuffle=cfg.TRAIN_SHUFFLE,
                            num_workers=cfg.DATA_WORKERS,
                            collate_fn=collate_fn,
                            drop_last=True)
    net = generate_net(cfg)
    if cfg.TRAIN_TBLOG:
        from tensorboardX import SummaryWriter
        tblogger = SummaryWriter(cfg.LOG_DIR)
    print('Use %d GPU' % cfg.TRAIN_GPUS)
    device = torch.device(cfg.GPUS_ID[0])
    if cfg.TRAIN_GPUS > 1:
        net = nn.DataParallel(net, device_ids=cfg.GPUS_ID)
        patch_replication_callback(net)
    net.to(device)
    if cfg.TRAIN_CKPT:
        # load only the weights whose name and shape match the current model
        pretrained_dict = torch.load(cfg.TRAIN_CKPT)
        net_dict = net.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in net_dict) and (v.shape == net_dict[k].shape)}
        net_dict.update(pretrained_dict)
        net.load_state_dict(net_dict)
        # net.load_state_dict(torch.load(cfg.TRAIN_CKPT), False)
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = optim.SGD(
        params=[
            {'params': get_params(net.module if cfg.TRAIN_GPUS > 1 else net, key='1x'), 'lr': cfg.TRAIN_LR},
            {'params': get_params(net.module if cfg.TRAIN_GPUS > 1 else net, key='10x'), 'lr': 10 * cfg.TRAIN_LR}
        ],
        momentum=cfg.TRAIN_MOMENTUM
    )
    itr = cfg.TRAIN_MINEPOCH * len(dataloader)
    max_itr = cfg.TRAIN_EPOCHS * len(dataloader)
    running_loss10 = 0.0
    running_loss100 = 0.0
    net.train()
    for epoch in range(cfg.TRAIN_MINEPOCH, cfg.TRAIN_EPOCHS):
        for i_batch, sample_batched in enumerate(dataloader):
            now_lr = adjust_lr(optimizer, itr, max_itr)
            inputs_batched, labels_batched = sample_batched['image'], sample_batched['segmentation']
            optimizer.zero_grad()
            labels_batched = labels_batched.long().to(device)
            inputs_batched = inputs_batched.to(device)
            predicts_batched = net(inputs_batched)
            loss = criterion(predicts_batched, labels_batched)
            loss.backward()
            optimizer.step()
            running_loss10 += loss.item()
            running_loss100 += loss.item()
            if i_batch % 10 == 1:
                print('epoch:%d/%d\tbatch:%d/%d\titr:%d\tlr:%g\tloss:%g ' %
                      (epoch, cfg.TRAIN_EPOCHS, i_batch, dataset.__len__() // cfg.TRAIN_BATCHES,
                       itr + 1, now_lr, running_loss10 / 10))
                running_loss10 = 0.0
            if cfg.TRAIN_TBLOG and itr % 50 == 0:
                inputs = inputs_batched[0].cpu().numpy() / 2.0 + 0.5
                labels = labels_batched[0].cpu().numpy()
                predicts = torch.argmax(predicts_batched[0], dim=0).cpu().numpy()
                labels_color = dataset.label2colormap(labels).transpose((2, 0, 1))
                predicts_color = dataset.label2colormap(predicts).transpose((2, 0, 1))
                pix_acc = np.sum(labels == predicts) / (cfg.DATA_RESCALE ** 2)
                tblogger.add_scalar('loss', running_loss100 / 50, itr)  # averaged over the 50 iterations between logs
                tblogger.add_scalar('lr', now_lr, itr)
                tblogger.add_scalar('pixel acc', pix_acc, itr)
                tblogger.add_image('Input', inputs, itr)
                tblogger.add_image('Label', labels_color, itr)
                tblogger.add_image('Output', predicts_color, itr)
                # cv2.imshow("label", labels_color.transpose(1, 2, 0))
                # cv2.imshow("pred", predicts_color.transpose(1, 2, 0))
                # cv2.waitKey(2000)
                # cv2.destroyAllWindows()
                running_loss100 = 0.0
            if itr % 5000 == 0:
                save_path = os.path.join(cfg.MODEL_SAVE_DIR, '%s_%s_%s_itr%d.pth' % (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, itr))
                torch.save(net.state_dict(), save_path)
                print('%s has been saved' % save_path)
            itr += 1
    save_path = os.path.join(cfg.MODEL_SAVE_DIR, '%s_%s_%s_epoch%d_all.pth' % (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, cfg.TRAIN_EPOCHS))
    torch.save(net.state_dict(), save_path)
    if cfg.TRAIN_TBLOG:
        tblogger.close()
    print('%s has been saved' % save_path)
def adjust_lr(optimizer, itr, max_itr):
    now_lr = cfg.TRAIN_LR * (1 - itr / (max_itr + 1)) ** cfg.TRAIN_POWER
    optimizer.param_groups[0]['lr'] = now_lr
    optimizer.param_groups[1]['lr'] = 10 * now_lr
    return now_lr
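adjust_lr is the usual "poly" schedule: the base learning rate is scaled by (1 - itr/max_itr)**TRAIN_POWER, and the second parameter group always gets 10x that value. A quick sketch of the decay; max_itr = 10000 is an arbitrary example value, while TRAIN_LR and TRAIN_POWER come from the config above:

base_lr, power, max_itr = 0.0007, 0.9, 10000   # example max_itr, not from the config
for itr in (0, 2500, 5000, 7500, 10000):
    lr = base_lr * (1 - itr / (max_itr + 1)) ** power
    print(itr, lr)   # starts at 7e-4 and decays smoothly towards 0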
def get_params(model, key):
    for m in model.named_modules():
        if key == '1x':
            if 'backbone' in m[0] and isinstance(m[1], nn.Conv2d):
                for p in m[1].parameters():
                    yield p
        elif key == '10x':
            if 'backbone' not in m[0] and isinstance(m[1], nn.Conv2d):
                for p in m[1].parameters():
                    yield p
if __name__ == '__main__':
    train_net()
Test code
# ----------------------------------------
# Written by Yude Wang
# ----------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from configs import cfg
from transform import ToTensor
from lib.net.generateNet import generate_net
from lib.net.sync_batchnorm.replicate import patch_replication_callback
from torch.utils.data import DataLoader
from TianchiDataset import TianchiDataset
class TianchiModel:
    def __init__(self, cfg):
        self.net = generate_net(cfg)
        self.cfg = cfg
        self.tensor = ToTensor()
        print('net initialize')

    def inittest(self):
        if self.cfg.TEST_CKPT is None:
            raise ValueError('test.py: cfg.TEST_CKPT can not be empty in test period')
        print('Use %d GPU' % self.cfg.TEST_GPUS)
        device = torch.device(self.cfg.GPUS_ID[0])
        if self.cfg.TEST_GPUS > 1:
            self.net = nn.DataParallel(self.net)
            patch_replication_callback(self.net)
        self.net.to(device)
        print('start loading model %s' % self.cfg.TEST_CKPT)
        model_dict = torch.load(self.cfg.TEST_CKPT, map_location=device)
        self.net.load_state_dict(model_dict)
        self.net.eval()

    def pred(self, img):
        # img comes from cv2, i.e. BGR channel order
        image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        r, c, _ = image.shape
        sample = {'image': image, 'name': "ok", 'row': r, 'col': c}
        sample = self.tensor(sample)
        inputs_batched = sample['image'].unsqueeze(0).to(self.cfg.GPUS_ID[0])
        predicts = self.net(inputs_batched).to(self.cfg.GPUS_ID[0])
        predicts_batched = predicts.clone()
        predicts_batched = F.interpolate(predicts_batched, size=None, scale_factor=1 / 1.0, mode='bilinear',
                                         align_corners=True)
        result = torch.argmax(predicts_batched, dim=1).cpu().numpy().astype(np.uint8)
        return result[0]
def test_net():
    dataset = TianchiDataset(cfg, 'test')
    dataloader = DataLoader(dataset,
                            batch_size=cfg.TEST_BATCHES,
                            shuffle=False,
                            num_workers=cfg.DATA_WORKERS)
    net = generate_net(cfg)
    print('net initialize')
    if cfg.TEST_CKPT is None:
        raise ValueError('test.py: cfg.TEST_CKPT can not be empty in test period')
    print('Use %d GPU' % cfg.TEST_GPUS)
    device = torch.device('cuda')
    # net = nn.DataParallel(net)
    # patch_replication_callback(net)
    net.to(device)
    print('start loading model %s' % cfg.TEST_CKPT)
    model_dict = torch.load(cfg.TEST_CKPT, map_location=device)
    net.load_state_dict(model_dict)
    net.eval()
    result_list = []
    with torch.no_grad():
        hist = np.zeros((4, 4))
        for i_batch, sample_batched in enumerate(dataloader):
            name_batched = sample_batched['name']
            row_batched = sample_batched['row']
            col_batched = sample_batched['col']
            [batch, channel, height, width] = sample_batched['image'].size()
            multi_avg = torch.zeros((batch, cfg.MODEL_NUM_CLASSES, height, width), dtype=torch.float32).to(device)
            for rate in cfg.TEST_MULTISCALE:
                inputs_batched = sample_batched['image_%f' % rate]
                inputs_batched = inputs_batched.cuda(device)
                predicts = net(inputs_batched).to(device)
                predicts_batched = predicts.clone()
                del predicts
                if cfg.TEST_FLIP:
                    inputs_batched_flip = torch.flip(inputs_batched, [3])
                    predicts_flip = torch.flip(net(inputs_batched_flip), [3]).to(device)
                    predicts_batched_flip = predicts_flip.clone()
                    del predicts_flip
                    predicts_batched = (predicts_batched + predicts_batched_flip) / 2.0
                predicts_batched = F.interpolate(predicts_batched, size=None, scale_factor=1 / rate, mode='bilinear', align_corners=True)
                multi_avg = multi_avg + predicts_batched
                del predicts_batched
            multi_avg = multi_avg / len(cfg.TEST_MULTISCALE)
            result = torch.argmax(multi_avg, dim=1).cpu().numpy().astype(np.uint8)
            predicts_color = dataset.label2colormap(result[0])
            labels_batched = sample_batched['segmentation']
            labels = labels_batched[0].cpu().numpy().astype(np.uint8)
            if labels.max() > 0:
                hist += fast_hist(labels, result[0], 4)
                print(per_class_iu(hist))
            labels_color = dataset.label2colormap(labels)
            # print(predicts_color.shape)
            orimg = cv2.resize(sample_batched['orimg'][0].numpy(), dsize=(512, 512), interpolation=cv2.INTER_CUBIC).astype(np.uint8)
            # print(orimg.shape, predicts_color.shape, type(predicts_color), type(orimg))
            mix_img = maskAddImg(orimg, result[0])
            cv2.imshow('pred', mix_img)
            mix_img = maskAddImg(orimg, labels)
            cv2.imshow('label', mix_img)
            # cv2.imshow("result", predicts_color)
            # cv2.imshow("label", labels_color)
            cv2.waitKey(500)
    print("all", per_class_iu(hist))
def fast_hist(a, b, n):
    # a: ground-truth labels, b: predictions, n: number of classes
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iu(hist):
    miou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    miou[miou != miou] = 1  # classes absent from both label and prediction (NaN) count as 1
    return miou
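fast_hist accumulates an n x n confusion matrix (rows = label, columns = prediction) and per_class_iu converts it to per-class IoU = diag / (row_sum + col_sum - diag). A tiny worked example with made-up labels:

import numpy as np
gt = np.array([0, 0, 1, 1, 2, 3])
pr = np.array([0, 1, 1, 1, 2, 2])
hist = fast_hist(gt, pr, 4)
print(per_class_iu(hist))   # roughly [0.5, 0.667, 0.5, 0.0]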
def maskAddImg(img, mask):
    mask_red = np.zeros_like(mask)
    mask_green = np.zeros_like(mask)
    mask_blue = np.zeros_like(mask)
    mask_red[mask == 1] = 1
    mask_green[mask == 2] = 1
    mask_blue[mask == 3] = 1
    mask_img_n = np.stack((mask_red, mask_green, mask_blue), axis=2)
    # print(img.shape, mask_img_n.shape, type(img), type(mask_img_n))
    mix_img = cv2.addWeighted(img, 0.5, mask_img_n * 255, 0.5, 1)
    return mix_img
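maskAddImg paints classes 1/2/3 into channels 0/1/2 and alpha-blends them with the image at 50% each; cv2.addWeighted needs both inputs to share shape and dtype, so the mask must be uint8 with the same height and width as the image. A minimal usage sketch with synthetic data:

import numpy as np
img = np.zeros((64, 64, 3), dtype=np.uint8)      # dummy image
mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:32, 16:32] = 1                           # pretend class 1 covers a square
overlay = maskAddImg(img, mask)                  # the square appears at half intensity in channel 0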
if __name__ == '__main__':
    test_net()
    # tianchimodel = TianchiModel(cfg)
Dataset code
# ----------------------------------------
# Written by Yude Wang
# ----------------------------------------
from __future__ import print_function, division
import sys
import os
import torch
import cv2
import multiprocessing
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from transform import *
import random
class TianchiDataset(Dataset):
    def __init__(self, cfg, period):
        self.dataset_name = cfg.DATA_NAME
        self.dataset_dir = cfg.DATA_DIR
        self.name_list = []
        self.period = period
        for list_txt in cfg.TXT_LIST:
            with open(self.dataset_dir + list_txt, 'r+') as fp:
                self.name_list.extend(fp.readlines())
        self.rescale = None
        self.centerlize = None
        self.randomcrop = None
        self.randomflip = None
        self.randomrotation = None
        self.randomscale = None
        self.randomhsv = None
        self.multiscale = None
        self.totensor = ToTensor()
        self.cfg = cfg
        self.nums = int(len(self.name_list) / self.cfg.DATA_SPLIT)
        if 'tianchi' in self.dataset_name:
            self.categories = [
                'kaoyan',    # 1
                'yumi',      # 2
                'yirenmi',   # 3
            ]
            self.num_categories = len(self.categories)
            assert (self.num_categories + 1 == self.cfg.MODEL_NUM_CLASSES)
        if cfg.DATA_RESCALE > 0:
            self.rescale = Rescale(cfg.DATA_RESCALE, fix=False)
        if 'train' in self.period:
            if cfg.DATA_RANDOMCROP > 0:
                self.randomcrop = RandomCrop(cfg.DATA_RANDOMCROP)
            if cfg.DATA_RANDOMROTATION > 0:
                self.randomrotation = RandomRotation(cfg.DATA_RANDOMROTATION)
            if cfg.DATA_RANDOMSCALE != 1:
                self.randomscale = RandomScale(cfg.DATA_RANDOMSCALE)
            if cfg.DATA_RANDOMFLIP > 0:
                self.randomflip = RandomFlip(cfg.DATA_RANDOMFLIP)
            if cfg.DATA_RANDOM_H > 0 or cfg.DATA_RANDOM_S > 0 or cfg.DATA_RANDOM_V > 0:
                self.randomhsv = RandomHSV(cfg.DATA_RANDOM_H, cfg.DATA_RANDOM_S, cfg.DATA_RANDOM_V)
        else:
            self.multiscale = Multiscale(self.cfg.TEST_MULTISCALE)
    def __len__(self):
        return self.nums
    def __getitem__(self, idx):
        # only 1/DATA_SPLIT of the list is visited per epoch; the chunk is picked at random per sample
        idx = random.randint(0, self.cfg.DATA_SPLIT - 1) * self.nums + idx
        name = self.name_list[idx]
        img_file = self.dataset_dir + name.split()[0]
        image = cv2.imread(img_file)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        r, c, _ = image.shape
        sample = {'image': image, 'name': name, 'row': r, 'col': c, 'orimg': image}
        seg_file = self.dataset_dir + name.split()[1]
        segmentation = np.array(Image.open(seg_file))
        sample['segmentation'] = segmentation
        if 'train' in self.period:
            if self.cfg.DATA_RANDOM_H > 0 or self.cfg.DATA_RANDOM_S > 0 or self.cfg.DATA_RANDOM_V > 0:
                sample = self.randomhsv(sample)
            if self.cfg.DATA_RANDOMFLIP > 0:
                sample = self.randomflip(sample)
            if self.cfg.DATA_RANDOMROTATION > 0:
                sample = self.randomrotation(sample)
            if self.cfg.DATA_RANDOMSCALE != 1:
                sample = self.randomscale(sample)
            if self.cfg.DATA_RANDOMCROP > 0:
                sample = self.randomcrop(sample)
            if self.cfg.DATA_RESCALE > 0:
                # sample = self.centerlize(sample)
                sample = self.rescale(sample)
        else:
            if self.cfg.DATA_RESCALE > 0:
                sample = self.rescale(sample)
            sample = self.multiscale(sample)
        sample = self.totensor(sample)
        # print(sample['image'].shape, sample['segmentation'].shape, sample['r'], sample['c'])
        return sample
    def label2colormap(self, label):
        m = label.astype(np.uint8)
        r, c = m.shape
        cmap = np.zeros((r, c, 3), dtype=np.uint8)
        cmap[:, :, 0] = (m & 1) << 7 | (m & 8) << 3
        cmap[:, :, 1] = (m & 2) << 6 | (m & 16) << 2
        cmap[:, :, 2] = (m & 4) << 5
        return cmap
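label2colormap is the usual VOC-style bit-interleaving colormap; for the four classes used here it reduces to a handful of fixed colors. A quick check (values follow directly from the bit shifts above):

# label 0 -> (  0,   0, 0)
# label 1 -> (128,   0, 0)
# label 2 -> (  0, 128, 0)
# label 3 -> (128, 128, 0)
import numpy as np
m = np.array([[0, 1], [2, 3]], dtype=np.uint8)
print((m & 1) << 7)   # channel 0: 128 wherever the lowest label bit is set
print((m & 2) << 6)   # channel 1: 128 wherever the second label bit is set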
def check_data():
    from data_process import maskAddImg
    from configs import cfg
    datasets = TianchiDataset(cfg=cfg, period='train')
    for indx in range(len(datasets)):
        sample = datasets.__getitem__(indx)
        img = sample['image']  # dataset returns RGB; cv2 displays BGR
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        mask = sample['segmentation']
        try:
            mix_img = maskAddImg(img, mask)
        except:
            print(sample['name'])
            continue
        mix_img_s = np.concatenate((sample['orimg'], img, mix_img), 1)
        cv2.imshow('mix_img', mix_img_s)
        k = cv2.waitKey(0)
        if k == ord('q'):
            cv2.destroyAllWindows()
            break
    cv2.destroyAllWindows()
#check_data()