pytorch 实现RBF网络

直接上代码: 

import torch, random
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(42)

class RBFN(nn.Module):
    """
    以高斯核作为径向基函数
    """
    def __init__(self, centers, n_out=3):
        """

        :param centers: shape=[center_num,data_dim]
        :param n_out:
        """
        super(RBFN, self).__init__()
        self.n_out = n_out
        self.num_centers = centers.size(0) # 隐层节点的个数
        self.dim_centure = centers.size(1) # 
        self.centers = nn.Parameter(centers)
        # self.beta = nn.Parameter(torch.ones(1, self.num_centers), requires_grad=True)
        self.beta = torch.ones(1, self.num_centers)*10
        # 对线性层的输入节点数目进行了修改
        self.linear = nn.Linear(self.num_centers+self.dim_centure, self.n_out, bias=True)
        self.initialize_weights()# 创建对象时自动执行


    def kernel_fun(self, batches):
        n_input = batches.size(0)  # number of inputs
        A = self.centers.view(self.num_centers, -1).repeat(n_input, 1, 1)
        B = batches.view(n_input, -1).unsqueeze(1).repeat(1, self.num_centers, 1)
        C = torch.exp(-self.beta.mul((A - B).pow(2).sum(2, keepdim=False)))
        return C

    def forward(self, batches):
        radial_val = self.kernel_fun(batches)
        class_score = self.linear(torch.cat([batches, radial_val], dim=1))
        return class_score

    def initialize_weights(self, ):
        """
        网络权重初始化
        :return:
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.02)
                m.bias.data.zero_()
            elif isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0, 0.02)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)
                m.bias.data.zero_()

    def print_network(self):
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print(self)
        print('Total number of parameters: %d' % num_params)

# centers = torch.rand((5,8))
# rbf_net = RBFN(centers)
# rbf_net.print_network()
# rbf_net.initialize_weights()


if __name__ =="__main__":
    data = torch.tensor([[0.25, 0.75], [0.75,0.75], [0.25,0.5], [0.5,0.5],[0.75,0.5],
                         [0.25,0.25],[0.75,0.25],[0.5,0.125],[0.75,0.125]], dtype=torch.float32)
    label = torch.tensor([[-1,1,-1],[1,-1,-1],[-1,-1,1],[-1,-1,1],[-1,-1,1],
                          [1,-1,-1],[-1,1,-1],[-1,1,-1],[1,-1,-1]], dtype=torch.float32)
    print(data.size())

    centers = data[0:8,:]
    rbf = RBFN(centers,3)
    params = rbf.parameters()
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(params,lr=0.1,momentum=0.9)

    for i in range(10000):
        optimizer.zero_grad()

        y = rbf.forward(data)
        loss = loss_fn(y,label)
        loss.backward()
        optimizer.step()
        print(i,"\t",loss.data)

    # 加载使用
    y = rbf.forward(data)
    print(y.data)
    print(label.data)

说明:代码在https://goodgoodstudy.blog.csdn.net/article/details/105756137上进行了小修改(原代码应该是错的),并加了一个自己的实验。

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值