pytorch -- linear regression

import numpy as np 
import pandas as pd 
import os
import matplotlib.pyplot as pl
import seaborn as sns
import warnings
data = pd.read_csv('/kaggle/input/vehicle-dataset-from-cardekho/car data.csv')

import torch
inputs = np.asarray(data.values[:,1:3], dtype = 'float32')

targets = np.asarray(data.values[:,3], dtype = 'float32')
targets.shape


inputs = torch.from_numpy(inputs)
targets = torch.from_numpy(targets)
import torch.nn as nn
# Define model
model = nn.Linear(2,1)
print(model.weight)
print(model.bias)
# Import nn.functional
import torch.nn.functional as F
# Define loss function
loss_fn = F.mse_loss
loss = loss_fn(model(inputs), targets)
print(loss)
# Define optimizer
opt = torch.optim.SGD(model.parameters(), lr=1e-5)

# Utility function to train the model
def fit(num_epochs, model, loss_fn, opt, train_dl):
    
    # Repeat for given number of epochs
    for epoch in range(num_epochs):
        
        # Train with batches of data
        for xb,yb in train_dl:
            
            # 1. Generate predictions
            pred = model(xb)
            
            # 2. Calculate loss
            loss = loss_fn(pred, yb)
            
            # 3. Compute gradients
            loss.backward()
            
            # 4. Update parameters using gradients
            opt.step()
            
            # 5. Reset the gradients to zero
            opt.zero_grad()
        
        # Print the progress
        if (epoch+1) % 10 == 0:
            print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))



from torch.utils.data import TensorDataset
# Define dataset
train_ds = TensorDataset(inputs, targets)

from torch.utils.data import DataLoader
# Define data loader
batch_size = 20
train_dl = DataLoader(train_ds, batch_size, shuffle=True)
fit(10000, model, loss_fn, opt, train_dl)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值