传统耦合神经网络(pcnn)算法的实现(python):
参数的设定没有具体参考,这是一篇文献中的解释:
# coding:utf-8 #
from PIL import Image
from pylab import *
from scipy import signal as sg
from PCNN.noise import sp_noise
import numpy as np
class Pcnn_class():
def PCNN(self, img_arr, iteration_num,
Af = 0.60, Al = 1.0, Atop = 0.80,
Vf = 0.2, Vl = 1, Vtop = 3000.0,
Beta = 0.1):
# def PCNN(self, img_arr, iteration_num,
# Af = 0.06931, Al = 1.47, Atop = 0.80,
# Vf = 20, Vl = 1.9, Vtop = 3000.0,
# Beta = 0.16):
# 获得矩阵的维度
x_len, y_len = img_arr.shape
print(" shape: ",x_len,y_len)
# 定义神经元矩阵,建立n=0时的F、L矩阵(都为0矩阵)
W = np.array([[0.7070, 1, 0.7070], [1, 0, 1], [0.7070, 1, 0.7070]])
M = np.array([[0.7070, 1, 0.7070], [1, 0, 1], [0.7070, 1, 0.7070]])
Y = np.zeros_like(img_arr, dtype="float64")
F = np.zeros_like(img_arr, dtype="float64")
L = np.zeros_like(img_arr, dtype="float64")
top = np.zeros_like(img_arr, dtype="float64")
# 定义点火过程
for i in range(0, iteration_num):
K = sg.convolve2d(Y, M, boundary='symm', mode="same")
print ("第 %s 次迭代,K矩阵:" % i, K)
F = exp(-Af) * F + Vf * K + img_arr
L = exp(-Al) * L + Vl * K
# print "第 %s 次迭代,L矩阵:" % i, "\n", L
U = F * (1 + Beta * L)
top = exp(-Atop) * top + Vtop * Y
for x_axis in range(0, x_len):
for y_axis in range(0, y_len):
if (U[x_axis, y_axis] > top[x_axis, y_axis]):
Y[x_axis, y_axis] = 1.0
else:
Y[x_axis, y_axis] = 0.0
print ("第 %s 次迭代完成。\n" % i)
print(len(Y[Y == 0]))
print(len(Y[Y == 1]))
return Y
#####################################################################
# 获得图像矩阵
if __name__ == "__main__":
img = Image.open(r'C:\Users\dell\Desktop\bs\test-14-MR.jpg').convert('L')
img = np.array(img)
noise_img = sp_noise(img, 0.05)
p=Pcnn_class()
img_out = p.PCNN(img_arr=img, iteration_num=10)
print ("done!")
fig,(ax1,ax2,ax3)=plt.subplots(1,3,figsize=(10,6)) #建立1行2列的图fig
ax1.imshow(noise_img,cmap='gray') #显示原始的图
ax1.set_axis_off() #不显示坐标轴
ax1.set_title('PIC1_L')
ax2.imshow(img_out,cmap='gray') #显示原始的图
ax2.set_axis_off() #不显示坐标轴
ax2.set_title('PIC2_L')
ax3.imshow(img, cmap='gray') # 显示原始的图
ax3.set_axis_off() # 不显示坐标轴
ax3.set_title('origil')
plt.show()