# Reference implementation:
# https://github.com/orobix/retina-unet/blob/f6ff63bb39409942fd3aaf6907dbce74ca143207/src/retinaNN_training.py
#Divide all the full_imgs in patches
def extract_ordered_overlap(full_imgs, patch_h, patch_w, stride_h, stride_w):
    """Split every image into an ordered grid of (possibly overlapping) patches.

    Parameters
    ----------
    full_imgs : ndarray, shape (N, C, H, W) with C == 1 (gray) or 3 (RGB).
    patch_h, patch_w : int
        Patch height and width.
    stride_h, stride_w : int
        Vertical / horizontal step between consecutive patch origins.

    Returns
    -------
    ndarray, shape (N * patches_per_image, C, patch_h, patch_w); patches are
    ordered image by image, row-major within each image.

    Raises
    ------
    AssertionError
        If the input is not 4D, the channel count is not 1 or 3, or the
        strides do not tile the image exactly (no leftover border allowed).
    """
    assert (len(full_imgs.shape) == 4)  # 4D arrays (N, C, H, W)
    assert (full_imgs.shape[1] == 1 or full_imgs.shape[1] == 3)  # check the channel is 1 or 3
    img_h = full_imgs.shape[2]  # height of the full image
    img_w = full_imgs.shape[3]  # width of the full image
    # the stride must cover the image exactly, otherwise border pixels are lost
    assert ((img_h - patch_h) % stride_h == 0 and (img_w - patch_w) % stride_w == 0)
    # hoist the per-axis patch counts once; the original recomputed these
    # expressions in the prints and again in both loop headers
    n_patches_h = (img_h - patch_h) // stride_h + 1
    n_patches_w = (img_w - patch_w) // stride_w + 1
    N_patches_img = n_patches_h * n_patches_w
    N_patches_tot = N_patches_img * full_imgs.shape[0]
    print ("Number of patches on h : " + str(n_patches_h))
    print ("Number of patches on w : " + str(n_patches_w))
    print ("number of patches per image: " + str(N_patches_img) + ", totally for this dataset: " + str(N_patches_tot))
    patches = np.empty((N_patches_tot, full_imgs.shape[1], patch_h, patch_w))
    iter_tot = 0  # running index over the total number of patches
    for i in range(full_imgs.shape[0]):  # loop over the full images
        for h in range(n_patches_h):
            for w in range(n_patches_w):
                patches[iter_tot] = full_imgs[i, :,
                                              h * stride_h:(h * stride_h) + patch_h,
                                              w * stride_w:(w * stride_w) + patch_w]
                iter_tot += 1  # total
    assert (iter_tot == N_patches_tot)
    return patches  # array with all the full_imgs divided in patches
def recompone_overlap(preds, img_h, img_w, stride_h, stride_w):
    """Reassemble patch predictions into full images, averaging overlaps.

    Inverse of ``extract_ordered_overlap``: each patch is added back at its
    grid position; where patches overlap, the per-pixel sum is divided by the
    number of contributing patches.

    Parameters
    ----------
    preds : ndarray, shape (n_patches, C, patch_h, patch_w), C == 1 or 3.
        Patch predictions with values in [0, 1] (probabilities), ordered as
        produced by ``extract_ordered_overlap``.
    img_h, img_w : int
        Height and width of the reconstructed full images.
    stride_h, stride_w : int
        Strides used when the patches were extracted.

    Returns
    -------
    ndarray, shape (n_full_imgs, C, img_h, img_w) of averaged values in [0, 1].

    Raises
    ------
    AssertionError
        If shapes are inconsistent, some pixel is covered by no patch, or the
        averaged values fall outside [0, 1].
    """
    assert (len(preds.shape) == 4)  # 4D arrays (n_patches, C, H, W)
    assert (preds.shape[1] == 1 or preds.shape[1] == 3)  # check the channel is 1 or 3
    patch_h = preds.shape[2]
    patch_w = preds.shape[3]
    N_patches_h = (img_h - patch_h) // stride_h + 1
    N_patches_w = (img_w - patch_w) // stride_w + 1
    N_patches_img = N_patches_h * N_patches_w
    print ("N_patches_h: " + str(N_patches_h))
    print ("N_patches_w: " + str(N_patches_w))
    print ("N_patches_img: " + str(N_patches_img))
    # the patch count must split evenly into whole images
    assert (preds.shape[0] % N_patches_img == 0)
    N_full_imgs = preds.shape[0] // N_patches_img
    print ("According to the dimension inserted, there are " + str(N_full_imgs) + " full images (of " + str(img_h) + "x" + str(img_w) + " each)")
    full_prob = np.zeros((N_full_imgs, preds.shape[1], img_h, img_w))  # sum of patch probabilities per pixel
    full_sum = np.zeros((N_full_imgs, preds.shape[1], img_h, img_w))   # how many patches covered each pixel
    k = 0  # iterator over all the patches
    for i in range(N_full_imgs):
        # reuse the precomputed counts; the original re-derived the same
        # expressions in both loop headers
        for h in range(N_patches_h):
            for w in range(N_patches_w):
                full_prob[i, :, h * stride_h:(h * stride_h) + patch_h, w * stride_w:(w * stride_w) + patch_w] += preds[k]
                full_sum[i, :, h * stride_h:(h * stride_h) + patch_h, w * stride_w:(w * stride_w) + patch_w] += 1
                k += 1
    assert (k == preds.shape[0])
    assert (np.min(full_sum) >= 1.0)  # every pixel must be covered by at least one patch
    final_avg = full_prob / full_sum  # element-wise average over overlapping patches
    print (final_avg.shape)
    assert (np.max(final_avg) <= 1.0)  # max value for a pixel is 1.0
    assert (np.min(final_avg) >= 0.0)  # min value for a pixel is 0.0
    return final_avg