import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
# Load the pre-split breast-cancer train/test sets from the working directory.
train = pd.read_csv('breast-cancer-train.csv')
test = pd.read_csv('breast-cancer-test.csv')

# Split the feature columns from the classification target.
# The features are transposed to (n_features, n_samples) because the model
# multiplies a (1, 2) weight row by them: tf.matmul(w, x_train).
FEATURE_COLUMNS = ['Clump Thickness', 'Cell Size']
TARGET_COLUMN = 'Type'

x_train = np.float32(train[FEATURE_COLUMNS].T)
# A Series is 1-D, so `.T` on it would be a no-op; the target is simply a
# flat (n_samples,) float32 vector.
y_train = np.float32(train[TARGET_COLUMN])
x_test = np.float32(test[FEATURE_COLUMNS].T)
y_test = np.float32(test[TARGET_COLUMN])
# Fit a linear model y = w @ x + b to the 0/1 target by least squares with
# plain gradient descent; the plot afterwards thresholds its output at 0.5.
# NOTE(review): these are TensorFlow 1.x APIs (tf.random_uniform,
# tf.train.GradientDescentOptimizer, tf.Session). Under TensorFlow 2.x they
# live under tf.compat.v1 and require eager execution to be disabled.
LEARNING_RATE = 0.01
TRAINING_STEPS = 1000
LOG_EVERY = 20  # print the parameters every LOG_EVERY steps

b = tf.Variable(tf.zeros([1]))  # bias: a length-1 variable initialised to 0
w = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))  # (1, 2) weight row, uniform in [-1, 1)
# (1, 2) @ (2, n_samples) -> (1, n_samples); subtracting the flat
# (n_samples,) target broadcasts to (1, n_samples) as well.
y = tf.matmul(w, x_train) + b
loss = tf.reduce_mean(tf.square(y - y_train))  # mean squared error
optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
# Named train_op rather than `train` so it does not shadow the training DataFrame.
train_op = optimizer.minimize(loss)

init = tf.global_variables_initializer()
# The session is intentionally left open: the plotting code that follows
# still needs to read the trained w and b from it.
sess = tf.Session()
sess.run(init)
for step in range(TRAINING_STEPS):
    sess.run(train_op)
    if step % LOG_EVERY == 0:
        # Fetch both variables in a single run instead of two separate calls.
        w_value, b_value = sess.run([w, b])
        print(step, w_value, b_value)
# Split the test samples by their 'Type' label for plotting. A boolean row
# mask and a column list go into one .loc call, avoiding chained indexing.
columns = ['Clump Thickness', 'Cell Size']
test_negative = test.loc[test['Type'] == 0, columns]
test_positive = test.loc[test['Type'] == 1, columns]

# Type 0 samples as red circles, Type 1 samples as black crosses.
plt.scatter(test_negative['Clump Thickness'], test_negative['Cell Size'], marker='o', s=200, c='red')
plt.scatter(test_positive['Clump Thickness'], test_positive['Cell Size'], marker='x', s=200, c='black')
plt.xlabel('Clump Thickness')
plt.ylabel('Cell Size')

# Read the trained parameters once, then release the session's resources.
w_value, b_value = sess.run([w, b])
sess.close()

# Decision boundary: points where the model output equals the 0.5 threshold,
#   w1*x + w2*y + b = 0.5  =>  y = (0.5 - b - w1*x) / w2
lx = np.arange(0, 12)
ly = (0.5 - b_value - lx * w_value[0][0]) / w_value[0][1]
plt.plot(lx, ly, color='green')
plt.show()
运行结果如下（每 20 步输出一次，依次为：步数、权重 w、偏置 b）：
0 [[-0.23947293 0.16662335]] [ 0.05221816]
20 [[-0.09618336 0.20805234]] [ 0.03749228]
40 [[-0.05583316 0.17235593]] [ 0.01368934]
60 [[-0.02602061 0.14622101]] [-0.00531601]
80 [[-0.00399889 0.12711641]] [-0.02054434]
100 [[ 0.01226366 0.11317632]] [-0.03278786]
120 [[ 0.02426958 0.10302609]] [-0.04266377]
140 [[ 0.03312998 0.09565362]] [-0.05065468]
160 [[ 0.03966646 0.09031425]] [-0.05713942]
180 [[ 0.04448639 0.08646055]] [-0.06241645]
200 [[ 0.04803878 0.08369045]] [-0.0667218]
220 [[ 0.05065546 0.08170898]] [-0.0702429]
240 [[ 0.05258162 0.08029999]] [-0.07312901]
260 [[ 0.05399844 0.07930534]] [-0.07549952]
280 [[ 0.0550397 0.0786095]] [-0.07745022]
300 [[ 0.05580419 0.07812826]] [-0.07905825]
320 [[ 0.05636485 0.07780034]] [-0.08038589]
340 [[ 0.05677548 0.07758131]] [-0.08148361]
360 [[ 0.05707579 0.07743898]] [-0.08239238]
380 [[ 0.05729502 0.07735023]] [-0.08314564]
400 [[ 0.05745475 0.07729841]] [-0.08377065]
420 [[ 0.05757084 0.07727169]] [-0.08428973]
440 [[ 0.05765497 0.07726164]] [-0.08472122]
460 [[ 0.05771576 0.07726233]] [-0.08508016]
480 [[ 0.05775951 0.07726966]] [-0.08537898]
500 [[ 0.05779084 0.07728079]] [-0.08562788]
520 [[ 0.05781316 0.0772938 ]] [-0.0858353]
540 [[ 0.05782894 0.07730742]] [-0.08600827]
560 [[ 0.05784002 0.07732085]] [-0.08615255]
580 [[ 0.05784772 0.07733358]] [-0.08627296]
600 [[ 0.05785297 0.07734536]] [-0.08637349]
620 [[ 0.0578565 0.07735603]] [-0.08645742]
640 [[ 0.05785881 0.07736558]] [-0.08652751]
660 [[ 0.05786028 0.07737401]] [-0.0865861]
680 [[ 0.05786116 0.0773814 ]] [-0.08663507]
700 [[ 0.05786163 0.07738784]] [-0.08667598]
720 [[ 0.05786182 0.07739341]] [-0.08671017]
740 [[ 0.05786185 0.07739821]] [-0.08673877]
760 [[ 0.05786176 0.07740233]] [-0.08676267]
780 [[ 0.0578616 0.07740586]] [-0.08678267]
800 [[ 0.05786142 0.07740885]] [-0.08679944]
820 [[ 0.05786123 0.0774114 ]] [-0.08681341]
840 [[ 0.05786103 0.07741357]] [-0.08682512]
860 [[ 0.05786083 0.07741541]] [-0.08683491]
880 [[ 0.05786065 0.07741699]] [-0.0868431]
900 [[ 0.05786048 0.0774183 ]] [-0.08684997]
920 [[ 0.05786035 0.07741939]] [-0.0868557]
940 [[ 0.05786021 0.07742034]] [-0.0868605]
960 [[ 0.05786011 0.07742112]] [-0.08686452]
980 [[ 0.05786001 0.07742178]] [-0.08686788]
效果图如下（红色圆点为 Type 0 测试样本，黑色叉号为 Type 1 测试样本，绿色直线为模型输出等于 0.5 时的决策边界）：