【课程作业经验】基于MIndSpore波士顿房价数据预测

基于mindspore实现全连接网络的波士顿数据集房价预测

北京理工大学邱小尧,写本次帖子是基于机器学习实践课程完成的相关使用mindspore深度学习框架完成的任务,写一些分享心得,本次实验我们预采用mindspore进行

数据导入与准备

在这里我们使用准备好的txt文档进行数据读取,其中CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATI,B,LSTAT,MEDV为其属性,

其中MEDV为我们需要预测的房价。

构建可按照如下方法构造DatasetGenerator并依此得到我们的dataset

class DatasetGenerator:
    def __init__(self):
        self.data = data[:,:-1]
        self.label = data[:,-1]

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

dataset_generator = DatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=False)

dataset = dataset.shuffle(buffer_size=15)
dataset = dataset.batch(batch_size=16)
train_dataset,test_dataset=dataset.split([0.8,0.2])
复制

全连接网络模型建立

接下来我们就开始构建全连接网络了仅采用三层全连接,较为简单。

class net(nn.Cell):
    def __init__(self):
        super(net,self).__init__()
        self.fc1 = nn.Dense(13,10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Dense(10,1)

    def construct(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)

        return x
复制

设定优化器以及其他参数

在这里我们采用MSEloss函数,以及Adam优化器,感兴趣的话可以探究其他参数。

net = net()
loss = nn.MSELoss()
learning_rate = 1e-3
# optim = nn.Momentum(net.trainable_params(), learning_rate, 0.9)
# optim = nn.SGD(net.trainable_params(),learning_rate=learning_rate,momentum=0.9)
optim = nn.AdamWeightDecay(net.trainable_params(),learning_rate=learning_rate,weight_decay=1e-5)
cb = LossMonitor()
epochs = 100
复制

模型训练

Mindspore实现的模型训练就很简单了,封装的比较好。

model = Model(net, loss_fn=loss, optimizer=optim)
model.train(epoch=epochs, train_dataset=train_dataset, callbacks=cb)
复制

测试结果

在最终结果上,我们发现神经网络训练的拟合效果并不是很好,大概率因为Boston数据集数据量较少的原因。

 

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值