import h5py
import random
import torch
import numpy as np
import pickle
from torch. utils. data import Dataset
from torch. utils. data import DataLoader
from torchvision import transforms
import torchvision. transforms. functional as TF
from torch import nn
from tensorflow. keras. preprocessing. image import ImageDataGenerator, array_to_img, img_to_array, load_img
def load_image(image_file):
    """Read the ``'img'`` dataset from an HDF5 file and scale it by 1/255.

    Args:
        image_file: Path of the HDF5 file to open read-only.

    Returns:
        The full contents of the ``'img'`` dataset divided by 255.0 and
        cast to ``float32``.
    """
    # The context manager closes the file handle once the dataset has been
    # read into memory; the original left one open handle per call.
    with h5py.File(image_file, 'r') as f:
        img_np = f['img'][()]
    return (img_np / 255.0).astype('float32')
def load_mask(image_path, img_id, attribute='pigment_network'):
    """Read the mask stored for one image/attribute pair from HDF5.

    The file opened is ``image_path + '<img_id>_attribute_<attribute>.h5'``;
    ``attribute='all'`` therefore selects ``<img_id>_attribute_all.h5``.

    Args:
        image_path: Prefix that is string-concatenated with the file name
            (so a directory must already end with a path separator).
        img_id: Identifier formatted into the file name.
        attribute: Attribute name formatted into the file name.

    Returns:
        The full contents of the file's ``'img'`` dataset cast to ``uint8``.
    """
    # The original non-'all' branch formatted an undefined name (`mask_attr`)
    # and raised NameError; `attribute` is the value that was meant. With
    # that fixed, both branches build the same file name pattern.
    mask_file = image_path + '%s_attribute_%s.h5' % (img_id, attribute)
    # Close the handle as soon as the dataset is in memory.
    with h5py.File(mask_file, 'r') as f:
        mask_np = f['img'][()]
    # NOTE(review): the cast is applied for every attribute, including
    # 'all' -- the source had lost its indentation, confirm this is intended.
    return mask_np.astype('uint8')
class SkinDataset(Dataset):
    """Dataset yielding ``(image, mask, indicator)`` triples read from HDF5.

    Images and masks are loaded through ``load_image`` / ``load_mask``.
    When ``train`` is true each sample is passed through ``transform_fn``
    for random augmentation.
    """

    def __init__(self, train_test_id, image_path, train_test_split_file, train=True, attribute=None, transform=None, num_classes=None):
        """Select the split and load the per-image indicator table.

        Args:
            train_test_id: Table indexable by a boolean mask, with a
                ``Split`` column and an ``ID`` attribute/column.
            image_path: Prefix that is string-concatenated with file names.
            train_test_split_file: Path of a pickle file whose content is
                stored in ``self.mask_ind``.
            train: If true keep rows whose ``Split`` equals ``'train'``,
                otherwise keep all other rows.
            attribute: Attribute name forwarded to ``load_mask``.
            transform: Stored on the instance; not used by this class.
            num_classes: ``1`` selects the single-mask augmentation path in
                ``transform_fn``; any other value selects the per-channel
                path.
        """
        self.train_test_id = train_test_id
        self.image_path = image_path
        self.attribute = attribute
        self.attr_types = ['pigment_network', 'negative_network', 'streaks', 'milia_like_cyst', 'globules']
        self.train = train
        self.transform = transform
        self.num_classes = num_classes

        # NOTE: pickle.load can execute arbitrary code; only point this at
        # split files from a trusted source.
        with open(train_test_split_file, 'rb') as f:
            self.mask_ind = pickle.load(f)

        split = self.train_test_id['Split']
        keep = (split == 'train') if self.train else (split != 'train')
        self.train_test_id = self.train_test_id[keep].ID.values
        print('Train =', self.train, 'train_test_id.shape: ', self.train_test_id.shape)
        self.n = self.train_test_id.shape[0]

    def __len__(self):
        """Return the number of IDs in the selected split."""
        # The original signature omitted `self`, so len(dataset) raised
        # TypeError.
        return self.n

    def transform_fn(self, image, mask):
        """Apply the same random geometric augmentation to image and mask.

        Both inputs are converted to PIL images, randomly flipped and
        affine-transformed with shared parameters, and converted back to
        arrays. Brightness and saturation jitter is applied to the image
        only. When ``num_classes == 1`` the mask is handled as one image,
        shear is drawn from [-10, 10] and an extra random rotation is
        applied; otherwise every channel of ``mask`` is transformed
        separately, shear is 0 and the result is written back into ``mask``
        in place.

        Args:
            image: Channels-last image array.
            mask: Channels-last mask array.

        Returns:
            ``(image, mask)`` where the image is divided by 255 and cast to
            ``float32`` and the mask is divided by 255 and cast to ``uint8``.
        """
        single_mask = self.num_classes == 1

        image = array_to_img(image, data_format='channels_last')
        if single_mask:
            mask_imgs = [array_to_img(mask, data_format='channels_last')]
        else:
            # One PIL image per mask channel.
            mask_imgs = [
                array_to_img(mask[:, :, i, np.newaxis], data_format='channels_last')
                for i in range(mask.shape[-1])
            ]

        if random.random() > 0.5:
            image = TF.hflip(image)
            mask_imgs = [TF.hflip(m) for m in mask_imgs]
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask_imgs = [TF.vflip(m) for m in mask_imgs]

        # The affine parameters are drawn once and shared so the image and
        # its mask(s) stay aligned.
        angle = random.randint(0, 90)
        translate = (random.uniform(0, 100), random.uniform(0, 100))
        scale = random.uniform(0.5, 2)
        max_shear = 10 if single_mask else 0
        shear = random.uniform(-max_shear, max_shear)
        image = TF.affine(image, angle, translate, scale, shear)
        mask_imgs = [TF.affine(m, angle, translate, scale, shear) for m in mask_imgs]

        # Photometric jitter touches the image only. The original passed
        # `saturation_factor=` to adjust_brightness, which is not a
        # parameter of that function and raised TypeError.
        image = TF.adjust_brightness(image, brightness_factor=random.uniform(0.8, 1.2))
        image = TF.adjust_saturation(image, saturation_factor=random.uniform(0.8, 1.2))

        if single_mask:
            # The single-mask path applies a second rotation on top of the
            # affine one.
            angle = random.randint(0, 90)
            image = TF.rotate(image, angle)
            mask_imgs = [TF.rotate(m, angle) for m in mask_imgs]

        image = img_to_array(image, data_format='channels_last')
        if single_mask:
            mask = img_to_array(mask_imgs[0], data_format='channels_last')
        else:
            for i, mask_img in enumerate(mask_imgs):
                mask[:, :, i] = img_to_array(mask_img, data_format='channels_last')[:, :, 0].astype('uint8')

        # NOTE(review): the normalisation is applied after both branches;
        # the source had lost its indentation, confirm this is intended.
        image = (image / 255).astype('float32')
        mask = (mask / 255).astype('uint8')
        return image, mask

    def __getitem__(self, index):
        """Load, optionally augment, and return the sample at ``index``.

        Args:
            index: Position into the selected split's ID array.

        Returns:
            ``(img_np, mask_np, ind)``: the ``float32`` image, the mask and
            the ``uint8`` values of the ``attr_types`` columns of
            ``self.mask_ind`` at row label ``index``.
        """
        img_id = self.train_test_id[index]
        # The original chained assignment (`image_file = self.image_path =
        # ...`) overwrote self.image_path with the bare file name, dropping
        # the path prefix for this and every later sample.
        image_file = self.image_path + "%s.h5" % img_id
        img_np = load_image(image_file)
        mask_np = load_mask(self.image_path, img_id, self.attribute)
        if self.train:
            img_np, mask_np = self.transform_fn(img_np, mask_np)
        img_np = img_np.astype('float32')
        # NOTE(review): the lookup uses the positional `index`, not
        # `img_id`; confirm the pickle's row labels line up with positions
        # in the filtered split.
        ind = self.mask_ind.loc[index, self.attr_types].values.astype('uint8')
        return img_np, mask_np, ind