线性回归
1 步骤
- 构建一个类,叫做 LinearRegression
- 在这个类中定义模型
- 计算 MSE 均方误差损失函数
- 定义优化器
- 反向传播
- 预测
举个例子,我们有个汽车公司,如果车价格越低低,我们可以卖更多的车。
car_price_np = np.array([3,4,5,6,7,8,9], dtype=np.float32).reshape(-1,1)
car_price_tensor = Variable(torch.from_numpy(car_price_np))
number_of_car_sell_np = np.array([ 7.5, 7, 6.5, 6.0, 5.5, 5.0, 4.5], dtype=np.float32).reshape(-1,1)
number_of_car_sell_tensor = Variable(torch.from_numpy(number_of_car_sell_np))
2 定义线性回归
class LinearRegression(nn.Module):
def __init__(self, input_size, output_size):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(input_size, output_size)
def forward(self, x):
return self.linear(x)
3 将网络实例化,定义一些参数
input_dim = 1
output_dim = 1
model = LinearRegression(input_dim, output_dim)
mse = nn.MSELoss()
lr = 0.02
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
loss_list = []
total_iter = 1001
4 拟合数据集
for iteration in range(total_iter):
optimizer.zero_grad()
results = model(car_price_tensor)
loss = mse(results, number_of_car_sell_tensor)
loss.backward()
optimizer.step()
loss_list.append(loss.data)
if (iteration % 100 == 0):
print("epoch {}, loss {}".format(iteration, loss.data))
逻辑回归
- 线性回归对于分类任务表现不好
- 可以使用逻辑回归进行分类
- 逻辑回归=线性回归+logistic 函数(softmax)
逻辑回归的步骤,这里使用 MNIST 数据集:
- 准备数据集
- 使用 Mnist
- 数据需要归一化
- 需要划分训练集和测试集
- DataLoader 将 dataset 和 sampler 合并到一起
- 建立逻辑回归模型
- 在 pytorch 中,逻辑回归函数是在损失函数中的
- 创建模型
- 创建损失函数:交叉熵损失函数,其中有包括 softmax
- 创建优化器
- 训练模型
- 预测
1 加载数据集
train = pd.read_csv("./train.csv", dtype=np.float32)
targets_numpy = train.label.values
features_numpy = train.loc[:, train.columns != "label"].values / 255
features_train, features_test, targets_train, targets_test = train_test_split(
features_numpy, targets_numpy, test_size = 0.2, random_state = 42
)
featuresTrain = torch.from_numpy(features_train)
targetsTrain = torch.from_numpy(targets_train).type(torch.LongTensor)
featuresTest = torch.from_numpy(features_test)
targetsTest = torch.from_numpy(targets_test).type(torch.LongTensor)
train = torch.utils.data.TensorDataset(featuresTrain, targetsTrain)
test = torch.utils