EllSeg训练实验记录&代码解析&技巧tips

用 nohup 把输出重定向到文件后,文件中迟迟看不到结果

原因是 Python 在 stdout 被重定向时会对输出做缓冲;加入 -u 参数(禁用输出缓冲)即可让日志实时写入文件:

 nohup python3 -u ExtractOpenEDS_seg.py --path2ds="../../../DataSet" > OpenEDS_status.out 2>&1 &
数据集train、test分组
python /curObjects/datasetSelections.py

该脚本会将 DS_selections 写入 pkl 文件,供后续的 createDataloaders_baseline.py 读取使用。

创建train和test文件

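cd curObjects && python createDataloaders_baseline.py

(需在 curObjects 目录下运行:下面脚本中的 sys.path.append('..')、相对路径 '../../../DataSet' 以及保存路径 os.getcwd()/baseline 都依赖当前工作目录。)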

如果只需为 OpenEDS 这一个数据集生成训练文件,按如下方式修改 createDataloaders_baseline.py 即可:

sys.path.append('..')
import CurriculumLib as CurLib
from CurriculumLib import DataLoader_riteyes

path2data = '../../../DataSet'
path2h5 = os.path.join(path2data, 'All')
# When a pickle for this dataset already exists, reuse its image lists so the
# train/valid/test selection stays identical across runs.
keepOld = True

AllDS = CurLib.readArchives(os.path.join(path2data, 'MasterKey'))
list_ds = ['OpenEDS']

# Generate objects per dataset
for setSel in list_ds:

    # Train/valid objects built from the archives selected by 'train'
    # (presumably the OpenEDS_train_* archives -- confirm selSubset's matching rule).
    AllDS_cond = CurLib.selSubset(AllDS, ['train'])
    dataDiv_obj = CurLib.generate_fileList(AllDS_cond, mode='vanilla', notest=True)
    trainObj = DataLoader_riteyes(dataDiv_obj, path2h5, 0, 'train', True, (480, 640), scale=0.5)
    validObj = DataLoader_riteyes(dataDiv_obj, path2h5, 0, 'valid', False, (480, 640), scale=0.5)

    # Test object built from the archives selected by 'validation'.
    AllDS_cond = CurLib.selSubset(AllDS, ['validation'])
    dataDiv_obj = CurLib.generate_fileList(AllDS_cond, mode='none', notest=True)
    testObj = DataLoader_riteyes(dataDiv_obj, path2h5, 0, 'test', False, (480, 640), scale=0.5)

    path2save = os.path.join(os.getcwd(), 'baseline', 'cond_'+setSel+'.pkl')
    # Create baseline/ on first use instead of failing on the dump below.
    os.makedirs(os.path.dirname(path2save), exist_ok=True)

    if keepOld and os.path.exists(path2save):
        print('Preserving old selections ...')

        # This ensures that the original selection remains the same
        with open(path2save, 'rb') as f:
            trainObj_orig, validObj_orig, testObj_orig = pickle.load(f)
        trainObj.imList = trainObj_orig.imList
        validObj.imList = validObj_orig.imList
        testObj.imList = testObj_orig.imList

    # Both branches write the (possibly restored) objects; `with` closes the file.
    with open(path2save, 'wb') as f:
        pickle.dump((trainObj, validObj, testObj), f)
开始训练
./runLocal.sh

代码解析

ExtractOpenEDS_seg.py
sys.path.append('..')

from helperfunctions import ransac, ElliFit, my_ellipse
from helperfunctions import generateEmptyStorage, getValidPoints

def mypause(interval):
    """Let the active figure process GUI events for ``interval`` seconds.

    Redraws the active figure's canvas only when the figure is marked stale,
    then runs the canvas event loop. Returns None without doing anything when
    the current backend is not interactive or there is no active figure.
    """
    if plt.rcParams['backend'] not in matplotlib.rcsetup.interactive_bk:
        return
    active = matplotlib._pylab_helpers.Gcf.get_active()
    if active is None:
        return
    fig_canvas = active.canvas
    if fig_canvas.figure.stale:
        fig_canvas.draw()
    fig_canvas.start_event_loop(interval)

parser = argparse.ArgumentParser()
# --noDisp is an int flag: any non-zero value (the default) disables the preview.
parser.add_argument('--noDisp',
                    help='Set to 0 to display labelled images while extracting (default 1: no display)',
                    type=int,
                    default=1)
parser.add_argument('--path2ds',
                    help='Path to dataset',
                    type=str,
                    default='/media/rakshit/Monster/Datasets')

args = parser.parse_args()
noDisp = bool(args.noDisp)
print('No graphics' if noDisp else 'Showing figures')

# Try interactive GUI backends in order of preference; keep the first that loads.
gui_env = ['Qt5Agg', 'WXAgg', 'TKAgg', 'GTKAgg']
for gui in gui_env:
    try:
        print("testing: {}".format(gui))
        # matplotlib.use() no longer accepts `warn` (removed in matplotlib 3.3);
        # passing it raised TypeError, which made every candidate "fail" silently.
        matplotlib.use(gui, force=True)
        from matplotlib import pyplot as plt
        break
    except Exception as err:
        # Best effort: an unusable backend just means "try the next one".
        print("  {} unavailable: {}".format(gui, err))
        continue

print("Using: {}".format(matplotlib.get_backend()))
import matplotlib.pyplot as plt
plt.ion()

def _crop_rows_around_labels(image, label_map, crop_h=300, out_size=(640, 480)):
    """Crop ``crop_h`` full-width rows around the labelled region, then resize.

    The window is centred on the midpoint between the first and last rows of
    ``label_map`` that contain a non-zero label. It is clamped so it stays
    inside the image, which means every crop is exactly ``crop_h`` rows when
    the image is tall enough. ``image`` is resized with Lanczos interpolation.
    ``label_map`` is resized with nearest-neighbour so its values stay valid
    label ids. ``out_size`` is cv2's (width, height).
    Assumes ``image`` and ``label_map`` have the same number of rows.
    NOTE(review): the crop keeps the full width, so the resize can change the
    aspect ratio -- confirm this is intended.

    Returns (image, label_map), or (None, None) if ``label_map`` is all zero.
    """
    rows = np.where(label_map)[0]
    if rows.size == 0:
        return None, None
    centre = int(0.5*(np.max(rows) + np.min(rows)))
    top = min(max(centre - crop_h//2, 0), max(image.shape[0] - crop_h, 0))
    bot = top + crop_h
    image = cv2.resize(image[top:bot, :], out_size, interpolation=cv2.INTER_LANCZOS4)
    label_map = cv2.resize(label_map[top:bot, :], out_size, interpolation=cv2.INTER_NEAREST)
    return image, label_map


def _fit_ellipse_or_fallback(points, n_labelled, min_labelled, part):
    """RANSAC-fit an ellipse to ``points``; return (model_obj, fit_error).

    The fit runs only if more than ``min_labelled`` pixels carry the label and
    ``points`` is an array (getValidPoints returns [] when it finds none).
    Otherwise a stand-in object whose ``model`` and ``Phi`` are all -1 is
    returned with an infinite error, so the caller's threshold rejects it.
    """
    if n_labelled > min_labelled and not isinstance(points, list):
        fitted = ransac(points, ElliFit, 15, 40, 5e-3, 15).loop()
        return fitted, my_ellipse(fitted.model).verify(points)
    print('Not enough {} points'.format(part))
    fallback = type('model', (object, ), {})
    fallback.model = np.array([-1, -1, -1, -1, -1])
    fallback.Phi = np.array([-1, -1, -1, -1, -1])
    return fallback, np.inf


ds_num = 0
PATH_OPENEDS = os.path.join(args.path2ds, 'OpenEDS')
PATH_DIR = os.path.join(args.path2ds, 'OpenEDS', 'Semantic_Segmentation_Dataset')
PATH_DS = os.path.join(args.path2ds, 'All')
PATH_MASTER = os.path.join(args.path2ds, 'MasterKey')

print('Extracting OpenEDS')

# Don't append the test set (and tolerate its absence).
# NOTE(review): os.listdir order is filesystem-dependent, and so is the
# ds_num suffix each folder's archive name receives.
listDir = [d for d in os.listdir(PATH_DIR) if d != 'test']
for dirCond in listDir:
    ds_name = 'OpenEDS_{}_{}'.format(dirCond, ds_num)

    print('Opening the {} folder'.format(dirCond))

    # Read the per-user list of segmentation images
    path2json = os.path.join(PATH_OPENEDS, 'OpenEDS_{}_userID_mapping_to_images.json'.format(dirCond))
    with open(path2json, 'r') as f:
        im2UID = json.load(f)

    PATH_IMAGES = os.path.join(PATH_DIR, dirCond, 'images')
    PATH_LABELS = os.path.join(PATH_DIR, dirCond, 'labels')
    PATH_FITS = os.path.join(PATH_DIR, dirCond, 'fits')
    # Empty storage for every field: Data holds images, masks and fits, while
    # keydict holds the per-image key information. This unifies the format
    # across datasets.
    Data, keydict = generateEmptyStorage(name='OpenEDS', subset=dirCond)

    i = 0  # number of images accepted from this folder
    if not noDisp:
        fig, plts = plt.subplots(1, 1)

    for pData in im2UID:
        # Image file names listed for each person (key spelled as in the JSON)
        for imName_full in pData['semantic_segmenation_images']:
            imName, _ = os.path.splitext(imName_full)

            # Load the image in grayscale (flag 0) and its label map
            I = cv2.imread(os.path.join(PATH_IMAGES, imName_full), 0)
            if I is None:
                print('Could not read image: {}'.format(imName_full))
                continue
            LabelMat = np.load(os.path.join(PATH_LABELS, imName+'.npy'))

            # Make sure images are 640x480
            I, LabelMat = _crop_rows_around_labels(I, LabelMat)
            if I is None:
                print('Empty label map: {}'.format(imName))
                continue

            # Fit pupil (label 3) and iris (label 2) ellipses
            pupilPts, irisPts = getValidPoints(LabelMat)
            model_pupil, pupil_fit_error = _fit_ellipse_or_fallback(
                pupilPts, np.sum(LabelMat == 3), 150, 'pupil')
            model_iris, iris_fit_error = _fit_ellipse_or_fallback(
                irisPts, np.sum(LabelMat == 2), 200, 'iris')

            # Do not save images without an acceptable pupil AND iris fit.
            # Checked before rasterising so rejected fits are never drawn.
            if not pupil_fit_error < 0.1:
                print('Not recording pupil. Unacceptable fit.')
                print('Pupil fit error: {}'.format(pupil_fit_error))
            if not iris_fit_error < 0.1:
                print('Not recording iris. Unacceptable fit.')
                print('Iris fit error: {}'.format(iris_fit_error))
            if not (pupil_fit_error < 0.1 and iris_fit_error < 0.1):
                print('Found bad mask: {}'.format(imName))
                continue

            pupil_loc = model_pupil.model[:2]

            # Draw mask no skin: rasterise both fitted ellipses, clipping the
            # pixel indices to the image bounds.
            rr, cc = drawEllipse(pupil_loc[1],
                                 pupil_loc[0],
                                 model_pupil.model[3],
                                 model_pupil.model[2],
                                 rotation=-model_pupil.model[-1])
            pupMask = np.zeros_like(I)
            pupMask[rr.clip(0, I.shape[0]-1), cc.clip(0, I.shape[1]-1)] = 1
            rr, cc = drawEllipse(model_iris.model[1],
                                 model_iris.model[0],
                                 model_iris.model[3],
                                 model_iris.model[2],
                                 rotation=-model_iris.model[-1])
            iriMask = np.zeros_like(I)
            iriMask[rr.clip(0, I.shape[0]-1), cc.clip(0, I.shape[1]-1)] = 1

            if not (np.any(pupMask) and np.any(iriMask)):
                print('Found bad mask: {}'.format(imName))
                continue
            # Iris-only pixels = 2, pixels inside both ellipses = 3 (pupil)
            mask_woSkin = 2*iriMask + pupMask

            # Add model information
            keydict['archive'].append(ds_name)
            keydict['resolution'].append(I.shape)
            keydict['pupil_loc'].append(pupil_loc)

            # Append images and label map
            Data['Images'].append(I)
            Data['Info'].append(imName_full)  # image file name
            Data['Masks'].append(LabelMat)
            Data['Masks_noSkin'].append(mask_woSkin)
            Data['pupil_loc'].append(pupil_loc)

            # Append fits
            Data['Fits']['pupil'].append(model_pupil.model)
            Data['Fits']['iris'].append(model_iris.model)

            keydict['Fits']['pupil'].append(model_pupil.model)
            keydict['Fits']['iris'].append(model_iris.model)

            if not noDisp:
                if i == 0:
                    # `angle` is passed by keyword: it is keyword-only in
                    # recent matplotlib releases.
                    cE = Ellipse(tuple(pupil_loc),
                                 2*model_pupil.model[2],
                                 2*model_pupil.model[3],
                                 angle=np.rad2deg(model_pupil.model[4]))
                    cL = Ellipse(tuple(model_iris.model[0:2]),
                                 2*model_iris.model[2],
                                 2*model_iris.model[3],
                                 angle=np.rad2deg(model_iris.model[4]))
                    cE.set_facecolor('None')
                    cE.set_edgecolor((1.0, 0.0, 0.0))
                    cL.set_facecolor('None')
                    cL.set_edgecolor((0.0, 1.0, 0.0))
                    cI = plts.imshow(I)
                    cM = plts.imshow(mask_woSkin, alpha=0.5)
                    plts.add_patch(cE)
                    plts.add_patch(cL)
                    plt.show()
                    plt.pause(.01)
                else:
                    cE.center = tuple(pupil_loc)
                    cE.angle = np.rad2deg(model_pupil.model[4])
                    cE.width = 2*model_pupil.model[2]
                    cE.height = 2*model_pupil.model[3]
                    cL.center = tuple(model_iris.model[0:2])
                    cL.width = 2*model_iris.model[2]
                    cL.height = 2*model_iris.model[3]
                    cL.angle = np.rad2deg(model_iris.model[-1])
                    cI.set_data(I)
                    cM.set_data(mask_woSkin)
                    mypause(0.01)
            if i % 1000 == 0:
                print('loading {} images.....!!!\n\n\n\n'.format(i))
            i = i + 1
    print('{} images: {}'.format(dirCond, i))

    if i == 0:
        # np.stack cannot stack empty lists; skip saving but keep numbering.
        print('No usable images in {}; nothing saved.'.format(dirCond))
        ds_num = ds_num + 1
        continue

    # Stack data
    Data['Images'] = np.stack(Data['Images'], axis=0)
    Data['Masks'] = np.stack(Data['Masks'], axis=0)
    Data['Masks_noSkin'] = np.stack(Data['Masks_noSkin'], axis=0)
    Data['pupil_loc'] = np.stack(Data['pupil_loc'], axis=0)
    Data['Fits']['pupil'] = np.stack(Data['Fits']['pupil'], axis=0)
    Data['Fits']['iris'] = np.stack(Data['Fits']['iris'], axis=0)

    keydict['resolution'] = np.stack(keydict['resolution'], axis=0)
    keydict['archive'] = np.stack(keydict['archive'], axis=0)
    keydict['pupil_loc'] = np.stack(keydict['pupil_loc'], axis=0)

    # Save data
    dd.io.save(os.path.join(PATH_DS, ds_name+'.h5'), Data)
    scio.savemat(os.path.join(PATH_MASTER, str(ds_name)+'.mat'), keydict, appendmat=True)
    ds_num = ds_num + 1

helperfunctions.py
def getValidPoints(LabelMat, isPartSeg=True):
    '''
    RK: This can only be used specifically for PartSeg.
    Given labels, identify candidate pupil and iris boundary points.
    pupil: label == 3, iris: label == 2; 0 and 1 are the remaining classes
    (presumably skin/background and sclera -- confirm against the dataset).

    An edge pixel (Canny on the normalised label map) is kept as
      * a pupil point if its 3x3 neighbourhood contains neither 0 nor 1,
        i.e. it lies on the 2/3 (iris/pupil) border;
      * an iris point if its neighbourhood contains no 3 and, when
        ``isPartSeg`` is True, also no 0 (edges where label 0 touches the
        iris, e.g. eyelid occlusion, are rejected).
    Edge pixels in row 0 or column 0 are never kept: their slice start wraps
    to -1, which yields an empty window.

    Returns (pupilPts, irisPts): each an (M, 2) integer array of
    (x, y) = (column, row) points, or an empty list when none qualify.
    '''
    if LabelMat.max() <= 0:
        # No positive label: there is nothing to trace, and a zero maximum
        # would make the normalisation below divide by zero.
        return [], []
    im = np.uint8(255*LabelMat.astype(np.float32)/LabelMat.max())
    # Union of the edges of the map and of its inverse. The uint8 sum may
    # wrap around, but only non-zero-ness is used below.
    edges = cv2.Canny(im, 50, 100) + cv2.Canny(255-im, 50, 100)
    rows, cols = np.where(edges)
    pupilPts = []
    irisPts = []
    for r, c in zip(rows, cols):
        window = LabelMat[r-1:r+2, c-1:c+2]
        if window.size == 0:
            continue
        has0 = np.any(window == 0)
        has1 = np.any(window == 1)
        has3 = np.any(window == 3)
        point = np.array((c, r))
        # Pupil border is between 2 and 3, so 0 and 1 must not appear
        if not (has0 or has1):
            pupilPts.append(point)
        # Iris border is between 1 and 2, so 3 must not appear; for PartSeg
        # a 0 (occluding eyelid) also disqualifies the point
        if not (has3 or (isPartSeg and has0)):
            irisPts.append(point)
    pupilPts = np.stack(pupilPts, axis=0) if len(pupilPts) > 0 else []
    irisPts = np.stack(irisPts, axis=0) if len(irisPts) > 0 else []
    return pupilPts, irisPts


class ElliFit():
    """Linear least-squares ellipse fit (ElliFit).

    Construct with ``ElliFit(data=pts)``, where ``pts`` is an (N, 2) array of
    (x, y) points. Pass the optional ``W`` as one weight per point.
    After construction:
      * ``model`` -- [xc, yc, a, b, theta], or [-1, -1, -1, -1, -1] when the
        fit is degenerate or there are 6 points or fewer;
      * ``Phi``   -- the 5 linear ElliFit parameters;
      * ``error`` -- mean of ``fit_error`` over ``data`` (inf if no fit ran).
    ``theta`` rotates the ``a`` axis from +x towards +y; ``fit`` and
    ``fit_error`` share this convention.
    """
    def __init__(self, **kwargs):
        self.data = np.array([])  # Nx2
        self.W = np.array([])
        self.Phi = []
        # np.size counts N*2 values, so more than 6 points are required.
        self.pts_lim = 6*2
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.weighted = bool(np.size(self.W))
        if np.size(self.data) > self.pts_lim:
            self.model = self.fit()
            self.error = np.mean(self.fit_error(self.data))
        else:
            self.model = [-1, -1, -1, -1, -1]
            self.Phi = [-1, -1, -1, -1, -1]
            self.error = np.inf

    def fit(self):
        """Fit the ellipse and return [xc, yc, a, b, theta] (or all -1).

        Code implemented from the paper ElliFit. The data are mean-centred
        for conditioning, and the mean is added back to the centre at the end.
        """
        xm = np.mean(self.data[:, 0])
        ym = np.mean(self.data[:, 1])
        x = self.data[:, 0] - xm
        y = self.data[:, 1] - ym
        # Linear model: Phi0*x^2 + 2*Phi1*x*y - 2*Phi2*x - 2*Phi3*y - Phi4 = -y^2
        X = np.stack([x**2, 2*x*y, -2*x, -2*y, -np.ones((np.size(x), ))], axis=1)
        Y = -y**2
        # X.T * W equals X.T @ diag(W) for a 1-D weight vector.
        XtW = X.T * self.W if self.weighted else X.T
        try:
            self.Phi = np.linalg.solve(XtW.dot(X), XtW.dot(Y))
        except np.linalg.LinAlgError:
            # Singular normal equations (e.g. collinear points); the values
            # below then fail the finiteness check or give a meaningless fit.
            self.Phi = -1*np.ones(5, )
        Phi = self.Phi
        denom = Phi[0] - Phi[1]**2
        # Division by zero / sqrt of negatives produce inf/nan, which are
        # mapped to the -1 sentinel below; silence the warnings they raise.
        with np.errstate(divide='ignore', invalid='ignore'):
            x0 = (Phi[2] - Phi[3]*Phi[1])/denom
            y0 = (Phi[0]*Phi[3] - Phi[2]*Phi[1])/denom
            term1 = 1 + Phi[0]
            term2 = np.sqrt((1 - Phi[0])**2 + 4*Phi[1]**2)
            # Right-hand side K of the centred conic
            # Phi0*u^2 + 2*Phi1*u*v + v^2 = K. The cross term needs x0*y0.
            term3 = Phi[4] + y0**2 + (x0**2)*Phi[0] + 2*Phi[1]*x0*y0
            b = np.sqrt(2*term3/(term1 + term2))
            a = np.sqrt(2*term3/(term1 - term2))
        alpha = 0.5*np.arctan2(2*Phi[1], 1 - Phi[0])
        model = [x0 + xm, y0 + ym, a, b, -alpha]
        if np.all(np.isfinite(model)):
            return model
        return [-1, -1, -1, -1, -1]

    def fit_error(self, data):
        """Return |(u/a)^2 + (v/b)^2 - 1| for each row of ``data``.

        model: xc, yc, a, b, theta. (u, v) are the point's coordinates in the
        ellipse frame: translated by the centre, then rotated by -theta. This
        matches the orientation returned by ``fit``.
        """
        xc, yc, a, b = self.model[0], self.model[1], self.model[2], self.model[3]
        theta = self.model[-1]
        dx = data[:, 0] - xc
        dy = data[:, 1] - yc
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        u = dx*cos_t + dy*sin_t    # along the `a` axis
        v = -dx*sin_t + dy*cos_t   # along the `b` axis
        res = (1/a**2)*u**2 + (1/b**2)*v**2 - 1
        return np.abs(res)


class ransac():
    """RANSAC wrapper around a model class such as ``ElliFit``.

    ``model`` is a class constructed as ``model(data=points)``. It must expose
    an ``error`` attribute (lower is better) and a ``fit_error(points)``
    method that returns one residual per row.
    Typical use: ``ransac(pupilPts, ElliFit, 15, 40, 5e-3, 15).loop()``.
    """
    def __init__(self, data, model, n_min, mxIter, Thres, n_good, rng=None):
        """
        Args:
            data: 2-D array, one sample per row.
            model: model class (the fitting method).
            n_min: number of rows sampled per iteration.
            mxIter: iteration budget (``loop`` runs mxIter + 1 iterations).
            Thres: residual below which a non-sampled row joins the consensus.
            n_good: consensus size that must be exceeded; max(n_min, n_good) is used.
            rng: optional object with a numpy-style ``choice`` method, e.g.
                ``np.random.default_rng(seed)``. Defaults to the global
                ``np.random`` state, as before.
        """
        self.data = data
        self.num_pts = data.shape[0]
        self.model = model  # fitting method
        self.n_min = n_min
        self.D = max(n_min, n_good)
        self.K = mxIter
        self.T = Thres
        self.rng = np.random if rng is None else rng
        # Fit on all data points; RANSAC only replaces this with a lower-error model.
        self.bestModel = self.model(data=data)

    def loop(self):
        """Search random subsets for a lower-error consensus fit; return the best model."""
        if self.num_pts <= self.n_min:
            # Too few points to subsample: the all-points fit is the answer.
            return self.bestModel
        all_idx = np.arange(self.num_pts)
        # Inclusive bound kept from the original `while i <= K` loop.
        for _ in range(self.K + 1):
            # Pick n_min points at random from the dataset
            sample = self.rng.choice(self.num_pts, self.n_min, replace=False)
            in_sample = np.isin(all_idx, sample)
            candidate = self.model(data=self.data[in_sample, :])
            close = candidate.fit_error(self.data[~in_sample, :]) < self.T
            if self.n_min + np.count_nonzero(close) > self.D:
                # Refit on the sample plus every close non-sampled point.
                consensus = in_sample.copy()
                consensus[all_idx[~in_sample][close]] = True
                refined = self.model(data=self.data[consensus, :])
                if refined.error < self.bestModel.error:
                    self.bestModel = refined
        return self.bestModel

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值