直接上代码:
import torch, random  # NOTE(review): `random` appears unused in this file — candidate for removal
import torch.nn as nn
import torch.optim as optim
# Fix the RNG seed so weight init (and therefore training) is reproducible.
torch.manual_seed(42)
class RBFN(nn.Module):
    """Radial basis function network with a Gaussian kernel.

    The hidden layer computes ``exp(-beta * ||x - c_i||^2)`` for every
    center ``c_i``; the output layer is a single linear map applied to
    the concatenation ``[input, radial activations]``.
    """

    def __init__(self, centers, n_out=3):
        """
        :param centers: initial RBF centers, shape [num_centers, data_dim];
                        trained jointly with the rest of the network.
        :param n_out: number of output units.
        """
        super(RBFN, self).__init__()
        self.n_out = n_out
        self.num_centers = centers.size(0)   # number of hidden (RBF) nodes
        self.dim_centure = centers.size(1)   # dimensionality of each center
        self.centers = nn.Parameter(centers)
        # Fixed kernel width. Registered as a buffer (not a plain tensor, as
        # the original had it) so it moves with .to(device) and is saved in
        # state_dict; still not trainable, matching the original behavior.
        self.register_buffer("beta", torch.ones(1, self.num_centers) * 10)
        # The linear layer sees the raw input concatenated with the radial
        # activations, hence num_centers + dim_centure input features.
        self.linear = nn.Linear(self.num_centers + self.dim_centure,
                                self.n_out, bias=True)
        self.initialize_weights()  # run automatically at construction

    def kernel_fun(self, batches):
        """Return Gaussian radial activations, shape [batch, num_centers].

        :param batches: input batch, shape [batch, data_dim] (or anything
                        that flattens to data_dim per sample).
        """
        x = batches.view(batches.size(0), -1)                  # [B, D]
        c = self.centers.view(self.num_centers, -1)            # [C, D]
        # Squared Euclidean distances via broadcasting:
        # [B, 1, D] - [1, C, D] -> [B, C, D] -> sum over D -> [B, C].
        # Equivalent to the repeat()-based version but allocates no copies.
        dist_sq = (x.unsqueeze(1) - c.unsqueeze(0)).pow(2).sum(2)
        return torch.exp(-self.beta * dist_sq)

    def forward(self, batches):
        """Class scores of shape [batch, n_out]."""
        radial_val = self.kernel_fun(batches)
        return self.linear(torch.cat([batches, radial_val], dim=1))

    def initialize_weights(self):
        """Initialize all Linear submodules: N(0, 0.02) weights, zero bias.

        The original also had Conv2d/ConvTranspose2d branches, but this
        module contains no conv layers, so they were dead code.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)
                m.bias.data.zero_()

    def print_network(self):
        """Print the module structure and its total parameter count."""
        num_params = sum(p.numel() for p in self.parameters())
        print(self)
        print('Total number of parameters: %d' % num_params)
# centers = torch.rand((5,8))
# rbf_net = RBFN(centers)
# rbf_net.print_network()
# rbf_net.initialize_weights()
if __name__ == "__main__":
    # Toy experiment: nine 2-D points, three classes encoded as +1/-1 targets.
    data = torch.tensor([[0.25, 0.75], [0.75, 0.75], [0.25, 0.5], [0.5, 0.5], [0.75, 0.5],
                         [0.25, 0.25], [0.75, 0.25], [0.5, 0.125], [0.75, 0.125]],
                        dtype=torch.float32)
    label = torch.tensor([[-1, 1, -1], [1, -1, -1], [-1, -1, 1], [-1, -1, 1], [-1, -1, 1],
                          [1, -1, -1], [-1, 1, -1], [-1, 1, -1], [1, -1, -1]],
                        dtype=torch.float32)
    print(data.size())

    # Use the first 8 samples as the initial RBF centers.
    centers = data[0:8, :]
    rbf = RBFN(centers, 3)
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(rbf.parameters(), lr=0.1, momentum=0.9)

    for i in range(10000):
        optimizer.zero_grad()
        # Call the module itself, not .forward(), so hooks are honored.
        y = rbf(data)
        loss = loss_fn(y, label)
        loss.backward()
        optimizer.step()
        # Log periodically instead of flooding stdout with 10k lines.
        if i % 1000 == 0 or i == 9999:
            print(i, "\t", loss.item())

    # Final fit check — no gradient tracking needed for inference.
    with torch.no_grad():
        y = rbf(data)
    print(y)
    print(label)
说明:代码在https://goodgoodstudy.blog.csdn.net/article/details/105756137上进行了小修改(原代码应该是错的),并加了一个自己的实验。