torch实现RBF(径向基神经网络)

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import time
"""
参考:https://goodgoodstudy.blog.csdn.net/article/details/105756137
"""

def SaveImage(label,pre,path):
    label = label.view(-1).cpu().detach().numpy()
    pre = pre.view(-1).cpu().detach().numpy()
    plt.rcParams['font.sans-serif'] = 'KaiTi'
    plt.rcParams['axes.unicode_minus'] = False
    fig = plt.figure(dpi=400)
    ax = fig.add_subplot(111)
    ax.plot(label, color='blue', label="实际值")
    ax.plot(pre, color='red', linestyle='--', label='拟合值')
    ax.legend()
    fig.savefig(path, dpi=400)

class RBF(nn.Module):
    def __init__(self,centers,n_out=1):
        super(RBF, self).__init__()
        self.n_out = n_out
        self.num_centers = centers.size(0)

        self.centers = nn.Parameter(centers)
        self.beta = nn.Parameter(torch.ones(1,self.num_centers))
        self.linear = nn.Linear(self.num_centers+n_out,self.n_out)
        self.initialize_weights()

    def kernel_fun(self,batches):
        n_input = batches.size(0)
        c = self.centers.view(self.num_centers,-1).repeat(n_input,1,1)# torch.Size([500, 500, 1])
        x = batches.view(n_input,-1).unsqueeze(1).repeat(1,self.num_centers,1)# torch.Size([500, 500, 1])
        radial_val = torch.exp(-self.beta.mul((c-x).pow(2).sum(2)))
        return radial_val

    def forward(self,x):
        # 计算径向基距离函数
        radial_val = self.kernel_fun(x)
        out = self.linear(torch.cat([x,radial_val],dim=1))
        return out

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.2)
                m.bias.data.zero_()
            elif isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0, 0.2)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.2)
                m.bias.data.zero_()


num_centers = 500
n_out = 1
centers = torch.randn(num_centers,n_out)

model = RBF(centers,n_out=1)
optimizer = optim.Adam(model.parameters(),lr=0.1)
loss_fun = nn.MSELoss()

X_ = torch.linspace(-5,5,500).view(500,1)
Y_ = torch.mul(1.1*(1-X_+X_.pow(2).mul(2)),torch.exp(X_.pow(2).mul(-0.5)))

start = time.time()
epochs = 1000
for epoch in range(epochs):
    avg_loss = 0
    Y_pre = model(X_)
    loss = loss_fun(Y_pre,Y_)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print("epoch:{}\t  loss:{:>.9}".format(epoch+1,loss.item()))
end = time.time()
print("time:",end-start)

Y_pre = model(X_)
SaveImage(Y_,Y_pre,"RBF.png")

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

该用户太丑不配起昵称

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值