Deep Learning for Computer Vision with Python
Deep Learning for Computer Vision with Python之SGD
前言
随着人工智能的不断发展,机器学习这门技术也越来越重要,很多人都开始学习机器学习。本文介绍机器学习中一个基础的优化方法——随机梯度下降(SGD),并给出一个基于 NumPy 的逻辑回归实现示例。
一、SGD是什么?
相比于批量梯度下降(GD)每次使用整个训练集计算一次梯度、更新一次参数,SGD 每次只取 batch_size 个样本计算梯度并立即更新参数,因此一个 epoch 内会更新多次参数。(严格来说,batch_size > 1 时称为小批量随机梯度下降,即 mini-batch SGD。)
二、使用步骤
1.引入库
代码如下(示例):
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import argparse
2.代码
相比于GD,这里添加了生成器函数 generate_bactch,按顺序产生大小为 batch_size 的数据块(最后一块可能不足 batch_size 个样本)。
代码如下(示例):
# Command-line hyperparameters.
# `lr` and `batch_size` stay positional, so existing invocations such as
# `python sgd.py 0.01 32` keep working. nargs="?" makes them optional so that
# their declared defaults are actually used when omitted; argparse only
# applies `default` to a positional whose nargs is "?" or "*".
parse = argparse.ArgumentParser()
parse.add_argument("lr", type=float, nargs="?", default=0.01, help="input your learning rate")
parse.add_argument("batch_size", type=int, nargs="?", default=24, help="input batch size")
parse.add_argument("--epoch", type=int, default=500, help="input epoch")
arg = parse.parse_args()
def activation_fuction(x):
    """Apply the logistic sigmoid 1 / (1 + e^(-x)) element-wise to ``x``."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
def predict(X, W):
    """Return hard 0/1 labels for the rows of ``X`` under weights ``W``.

    The sigmoid score ``activation_fuction(X.dot(W))`` is thresholded in place:
    scores ``<= 0.5`` become 0 and scores ``> 0.5`` become 1. The returned
    array has the shape and dtype of ``X.dot(W)``'s sigmoid output.
    """
    scores = activation_fuction(X.dot(W))
    # Both masks are computed up front; they are disjoint, so the order of
    # the two assignments below does not matter.
    is_low = scores <= 0.5
    is_high = scores > 0.5
    scores[is_low] = 0
    scores[is_high] = 1
    return scores
def generate_bactch(X, Y, batch_size):
    """Yield consecutive ``(X_batch, Y_batch)`` mini-batches in row order.

    Args:
        X: feature array; batches are sliced along axis 0.
        Y: target array with the same number of rows as ``X``.
        batch_size: number of rows per batch; must be a positive integer.

    Yields:
        Tuples ``(X[start:start + batch_size], Y[start:start + batch_size])``.
        When ``X.shape[0]`` is not a multiple of ``batch_size``, the last batch
        holds the remaining rows.

    Raises:
        ValueError: if ``batch_size`` is not positive. This is raised when the
            generator is first advanced.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    for start in range(0, X.shape[0], batch_size):
        # The slice end is relative to `start`. Slicing to a fixed `batch_size`
        # would yield empty batches after the first one.
        stop = start + batch_size
        yield (X[start:stop], Y[start:stop])
if __name__ == '__main__':
    # Generate a toy dataset: 1000 samples, 2 features, 2 classes.
    (X, Y) = make_blobs(1000, n_features=2, centers=2, cluster_std=1.5, random_state=1)
    # Y comes back 1-D; reshape it into a (n_samples, 1) column so it
    # lines up with the (n_samples, 1) output of X.dot(W).
    Y = Y.reshape(Y.shape[0], 1)
    # Append a column of ones so the bias is learned as the last row of W.
    X = np.c_[X, np.ones((X.shape[0], 1))]
    # Randomly split the samples 50/50 into train and test sets.
    trainx, testx, trainy, testy = train_test_split(X, Y, test_size=0.5, random_state=42)

    # Initialise the weight matrix W: one weight per column of X (incl. bias).
    W = np.random.randn(X.shape[1], 1)
    Loss = []
    for i in range(arg.epoch):
        # Reset each epoch so the recorded value is the mean over this
        # epoch's batches only, not a running average over all epochs so far.
        batchLoss = []
        for (batchx, batchy) in generate_bactch(trainx, trainy, arg.batch_size):
            # Train on the continuous sigmoid output. The thresholded `predict`
            # output is a step function and gives no meaningful gradient.
            preds = activation_fuction(batchx.dot(W))
            error = batchy - preds
            # Sum of squared errors, recorded for monitoring only.
            loss = np.sum(1 / 2 * error ** 2)
            batchLoss.append(loss)
            # X^T (y - sigmoid(XW)) is the negative gradient of the logistic
            # (cross-entropy) loss, so adding it moves W downhill.
            gradient = batchx.T.dot(error)
            W += arg.lr * gradient
        epochloss = np.average(batchLoss)
        Loss.append(epochloss)
        # Report the first epoch and then every 5th epoch.
        if i == 0 or (i + 1) % 5 == 0:
            print(f"epoch {i+1}/{arg.epoch}: {epochloss}")

    # Evaluate hard 0/1 predictions on the held-out test set.
    preds = predict(testx, W)
    print(classification_report(testy.ravel(), preds.ravel()))

    # Colour each test point by its true class: red for 1, blue otherwise.
    color = ['r' if label == 1 else 'b' for label in testy.ravel()]

    plt.style.use("ggplot")
    # Figure 1: scatter plot of the test-set distribution.
    plt.figure()
    plt.title("data")
    plt.scatter(testx[:, 0], testx[:, 1], c=color)
    # Figure 2: mean training loss per epoch.
    plt.figure()
    plt.title("Train Loss")
    plt.xlabel("epoch")
    plt.ylabel("Loss")
    plt.plot(range(0, arg.epoch), Loss)
    plt.show()
总结
本文在 GD 的基础上,通过 generate_bactch 把训练集切分成大小为 batch_size 的小批量,每个小批量计算一次梯度并立即更新权重,从而实现了 SGD。