every blog every motto: The shortest answer is doing.
0. 前言
之前有过这方面小结的打算,后来有事耽误了,正好这次又遇到坑了,遂填之。话不多说,下面进入正文。
本节主要是有关fit_transform、transform、fit区别及其使用的相关小结。要注意使用过程中的reshape
说明: 这方面依然没有进行深入探究,有问题欢迎指正,后续有待完善。
1. 正文
1.1 基本概念
- fit():
- 简单来说,求得训练集的均值、方差、最大值、最小值等属性
- 对数据进行拟合。
- transform():
- 在fit的基础上,进行归一化等。
- fit_transform():
- 是fit和transform的组合
1.2 重要知识点
- 一般先对训练数据fit_transform,对测试数据用transform。
- 测试数据(testdata)是用训练数据(traindata)求得的均值、方差等属性进行后续转换,原因参考文章1
- fit没有训练的意思。不要和tensorflow.Sequential中model.fit混淆,具体例子如点我查看
- 使用时需要reshape(-1,1)
1.3 例子验证
1.3.1 numpy 标准化
a = np.arange(4).reshape(2, 2)
mean = np.mean(a)
std = np.std(a)
print('标准化前:\n', a)
print('a.shape:', a.shape)
print('平均数:', mean)
print('方差:', std)
new_array = (a - mean) / std
print('标准化后: \n', new_array)
结果:
1.3.2 fit_transform标准化(正常操作)
import numpy as np
from sklearn.preprocessing import StandardScaler
a = np.arange(4).reshape(2, 2)
print('标准化前:\n', a)
print('a.shape:', a.shape)
scaler = StandardScaler()
b_reshape = scaler.fit_transform(a.reshape(-1, 1)).reshape(2, 2)
print('通过fit_transform标准化后的:\n', b_reshape)
结果:
1.3.3 fit、transform标准化(正常操作)
import numpy as np
from sklearn.preprocessing import StandardScaler
a = np.arange(4).reshape(2, 2)
print('标准化前:\n', a)
print('a.shape:', a.shape)
scaler = StandardScaler()
scaler.fit(a.reshape(-1, 1))
b_reshape = scaler.transform(a.reshape(-1,1)).reshape(2, 2)
print('通过fit、transform标准化后的:\n', b_reshape)
结果:
1.3.4 fit_transform标准化(不进行reshape)
import numpy as np
from sklearn.preprocessing import StandardScaler
a = np.arange(4).reshape(2, 2)
print('标准化前:\n', a)
print('a.shape:', a.shape)
scaler = StandardScaler()
# scaler.fit(a.reshape(-1, 1))
b_reshape = scaler.fit_transform(a).reshape(2, 2)
print('通过fit_transform标准化后的:\n', b_reshape)
错误结果,如下图:
1.3.5 fit、transform标准化(不进行reshape)
import numpy as np
from sklearn.preprocessing import StandardScaler
a = np.arange(4).reshape(2, 2)
print('标准化前:\n', a)
print('a.shape:', a.shape)
scaler = StandardScaler()
scaler.fit(a)
b_reshape = scaler.transform(a).reshape(2, 2)
print('通过fit、transform标准化后的:\n', b_reshape)
错误结果,如下图:
参考文献
[1] https://blog.csdn.net/yyhhlancelot/article/details/85097656?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.nonecase
[2] https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/Sequential
[3] https://blog.csdn.net/weixin_39190382/article/details/104107836
[4] https://blog.csdn.net/anshuai_aw1/article/details/82498374?utm_medium=distribute.pc_relevant.none-task-blog-baidujs-1
[5] https://blog.csdn.net/liuweiyuxiang/article/details/83028667?utm_medium=distribute.pc_relevant.none-task-blog-baidujs-2
[6] https://blog.csdn.net/hellocsz/article/details/89930616