Python基础

内建数据结构与函数

数据结构

可变对象的值发生变化时对象的内存地址不变,如列表、字典、集合、Numpy数组,不可变对象的值发生变化内存地址也会发生变化,如元组、字符串、数值。

元组(tuple)

不可变对象,长度固定,同一个元组的元素类型可以是不同的。

最简单的建立方式是用逗号分隔元素,也可用小括号将元素括起来。

元组一旦创建则各位置上的对象不能被改变,但是对象内部的元素可以被改变。

若将元组赋值给变量,则元组会被拆包。

tup1 = 1, 2, 3           # 等于tup = (1, 2, 3)
tup2 = ('t', 1, [1, 2])  # 元组的元素可以是不同类型的
tup1 + tup2              # 可用+连接两个元组
tup3 = tuple([1, 2, 3])  # 任意序列(包括字符串)或迭代器都能用tuple函数转化成元组
a, b, c = tup1           # 元组拆包
列表(list)

长度可变,内容可更改,同一个元组的元素类型可以是不同的。

用中括号或list函数创建,用append方法增加元素,用insert方法插入元素到指定位置,用pop方法移除指定位置元素,用remove方法移除第一个符合要求的值。

用sort方法给列表元素排序。

用in关键字检查一个值是否在列表中,not in关键字检查一个值是否不在列表中。

可用+连接两个列表(这样会创建新列表),也可用extent方法给列表添加多个元素(这种方式比前者更快)。

list1 = [1, 2, 3]              # 用中括号创建
list2 = list(range(0, 10, 2))  # 用list函数创建, 作用对象也可以是元组

list1.append(3)                # 添加元素到列表结尾
list1.insert(2, 10)            # 在位置2处添加元素10(原来的位置2及之后的元素整体往后移一位)
list1.pop(2)                   # 移除位置2的元素
list1.remove(3)                # 移除第一个等于3的元素

list1.sort()                   # 数值型列表排序

1 in list1                     # 判断对象是否是列表元素
1 not in list1                 # 判断对象是否不是列表元素

list1 + list2                  # 连接两个列表, 计算代价较大, 大型列表的构建中最好用extend方法添加元素
list1.extend([1, '2'])         # 给列表添加多个元素

a, b, c = list1                # 列表拆包
字典(dictionary)

又称哈希表或关联数组,是键值对作为元素的序列。

创建时最外面是大括号,里面键和值之间用冒号分,键值对之间用逗号分隔。

列表和元组的元素通过索引来访问,而字典的元素通过键来访问。

用in关键字判断字典是否含有一个键。

用keys()方法和values()方法获得字典的键和值。

用update方法合并字典(给字典加入新值)。

用del关键字或pop方法删除字典的键值对。

注意字典的键必须是不可变对象。

dict1 = {'a': 'value', 'b': [1, 2, 3]}

print(dict1['b'])     # 通过键访问值
>> [1, 2, 3]

dict1[3] = 'val3'     # 通过键在字典中插入值, 这里3是键, 'val3'是对应的值
print(dict1)
>> {'a': 'value', 'b': [1, 2, 3], 3: 'val3'}

'a' in dict1          # 检查字典是否含有一个键(类似还有not in)
>> True

list(dict1.keys())    # 字典的键
list(dict1.values())  # 字典的值

dict1.update({'b': 1, 4: [1, 2]})  # 添加的内容中有和原字典键相同的键'b', 此时添加的内容覆盖原内容
>> {'a': 'value', 'b': 1, 3: 'val3', 4: [1, 2]}

del dict1[3]          # 删除 3: 'val3' 这一键值对
dict1.pop('b')        # 删除 'b': [1, 2, 3] 这一键值对, 此命令除删除键值对外还会返回值[1, 2, 3]
集合(set)

集合是无序且元素唯一的数据类型,类似于只有键没有值的字典。

用set函数或大括号内列出元素的方式创建。

集合对象支持数学上的交并差等运算。

set1 = set([1, 2, 3, 2, 1, 4])
set2 = {1, 2, 2, 3}

内建序列函数

序列函数enumerate

有时我们需要在遍历序列的同时追踪元素的索引,enumerate函数能够起到这样的作用,其为python内置函数,作用在一个序列上能够产生一个序列(数据类型为enumerate),序列的每个元素是(索引, 元素值)形式的元组,用在 for 循环当中,语法为enumerate(iterable, start=0):

for i, val in enumerate([1, 2, 3], start=0):  # start表示索引i从什么值开始, 可省略, 默认为0
    print(i, val)

>> 0 1  
   1 2  
   2 3
序列排序函数sorted

作用在任意序列上,返回一个已经排序的新建列表。

序列配对函数zip

将若干序列配对形成一个以元组为元素的列表:

seq1 = ['a', 'b', 'c']
seq2 = [1, 2, 3]
seq3 = (True, False)

z1 = zip(seq1, seq2)        # 将作用到的序列的元素按索引相同者合并为元组再组成列表
list(z1)
>> [('a', 1), ('b', 2), ('c', 3)]

z2 = zip(seq1, seq2, seq3)  # 作用在长短不一的序列上, 返回结果的长度等于最短序列的长度
list(z2)
>> [('a', 1, True), ('b', 2, False)]
序列倒置函数reverse

生成器,将序列倒置。

seq = ['a', 'b', 'c', 1, True]
list(reversed(seq))
>> [True, 1, 'c', 'b', 'a']
range

使用格式为range(start, stop, step),其返回一个数字组成的可迭代对象(不是列表),元素以start开始,每次前进step,最后一个元素小于stop(即不包含stop),另有两种用法,range(start, stop)表示step=1的range(start, stop, step),range(stop)表示start=0且step=1的range(start, stop, step)。

函数

函数定义

使用def关键字定义函数,返回值使用return关键字:

# 函数定义格式: def 函数名(自变量1, 自变量2, ...): 函数体 return 结果
def fun(x, y, z=1.0):
    f = (x + y) / z
    return f

# 以下调用方式都可以
fun(1, 2, z=3)
fun(1, 2, 3)
fun(1, 2)
fun(x=1, y=2, z=3)
fun(y=2, x=1, z=3)

# 函数自变量的位置可以什么都不写
def fun():
    f = 1
    return f

# 函数返回多个值
def fun(x):
    f1 = x + 1
    f2 = x + 2
    return f1, f2

上面的x和y叫做位置参数,z叫做关键字参数,后者用于指定默认值和可选参数,且必须跟在位置参数后面,调用函数时,若不写x,y等参数名则在对应位置上写上它们的值,若写上参数名则不必按照定义函数时参数的书写顺序。z的取值可写可不写(不写就用默认值)。

希望定义的函数返回多个值,则可以在最后return的时候将多个结果用逗号分开,回忆前面元组的定义方式可知这样实际上是将多个结果组成一个元组来作为函数输出。类似地也可以将结果组成列表、字典等作为输出。

注意函数也可以作为参数传入其他函数中,如:

def fun1(x):
    f = x + 1
    return f

def fun2(x):
    f = x + 2
    return f

def fun3(x, funs):
    f = x
    for fun in funs:
        f = fun(f)
    return f

fun3(x=1, funs=[fun1, fun2])
>> 4
全局变量与局部变量

在函数内部声明的变量是局部变量。它们只在函数内部可见,函数外部无法直接访问。在函数外部声明的变量是全局变量。它们在整个程序中都是可见的,可以在任何函数内部访问(除非在函数内部被同名的局部变量覆盖)。如果想在函数内部修改全局变量,你需要使用global关键字来声明该变量是全局的。否则,Python会将其视为一个新的局部变量。

python本身没有内置方法判断变量是全局的还是局部的,但是如果你在一个函数中并且没有使用global关键字,那么任何你创建的变量都是局部变量。而在函数外部创建的变量就是全局变量。

匿名函数

又称lambda函数,是通过单个语句生成的函数,使用lambda关键字定义。由于匿名函数代码量小因此简单的功能尽量使用匿名函数来定义。如:

f = lambda x: x + 1

print(f(1.5))
>> 2.5
函数是否改变输入的参数

函数是否改变输入的参数(自变量)取决于参数传递的方式(值传递或引用传递)以及函数内部对参数的操作。

值传递:对于不可变类型(如数值、字符串和元组),Python实际上是通过值传递的。这意味着当你将一个不可变对象传递给函数时,函数会接收到这个对象的一个副本,而不是原始对象本身。因此,在函数内部对参数所做的任何修改都不会影响到原始对象。

引用传递:对于可变类型(如列表、字典和集合),Python是通过引用传递的。这意味着当你将一个可变对象传递给函数时,函数会接收到这个对象的引用,即指向原始对象的指针。因此,在函数内部对参数所做的修改会影响到原始对象。

全局变量与局部变量:如果函数内部修改了一个全局变量的值(使用global关键字),那么这个修改会影响到全局作用域中的变量。但请注意,在函数内部直接给一个变量赋值(而不是修改它的属性或元素),通常会创建一个新的局部变量,而不是修改全局变量。

x = 10

def change_x():
    x = 20  # 创建一个新的局部变量x,但不修改全局变量x

change_x()
print(x)
>> 10
函数定义中加入注解

一般来说定义函数第一行的形式为def fun(val):,有时为了代码的可读性还可以对传入函数的参数以及函数返回值加上注解表明其数据类型,如def fun(val: int) -> bool:实际上和不加注解的效果是一样的,但是能够告诉用户val参数应当是整型的,而该函数的返回值是布尔值。

索引

1. 现有n维序列对象arr(列表、numpy数组、tensor),列表索引用arr[i1][i2]...[in],numpy数组和tensor索引用arr[i1][i2]...[in]或arr[i1,i2,...,in]。

2. 对numpy数组或tensor,上面的i1可以用整数值一维列表、numpy数组、tensor代替。

3. 由于python的索引是从0开始的,因此索引中i代表第i-1个元素。

4. -i代表倒数第i个元素。

5. i1可用a:b:c代替,a、b、c分别代表开始位置、结束位置+1、跳跃间隔,可以使用省略模式,a:b表示c=1时的a:b:c,:b:c表示a=0时的a:b:c,a::c表示直到最后一个元素为止,a、b、c都可取负数,c取负数代表每次倒退c个位置。::-1表示翻转列表。

类(class)

描述具有某些特点的对象(对象就是类的实例,这里实例指具体化表达,是一个具体的对象,即类是一个抽象模板,实例是用类创建出来的具体的对象),用于定义对象的结构,在类中可以定义对象的属性和方法(所谓方法就是在类中定义的函数)。类是对象的图纸,同时类也是一个对象(用于创建对象的对象)。

用class关键字定义,内部定义两类函数,一是名为__init__的函数,二是其他名称的函数,这些函数的第一个参数一定是self(表示这个类自身),可以没有其他参数。__init__函数用来定义我们认为类必须具备的属性,不用写return,内部语句为self.attr = attr。

用类创建实例的方式为inst = classname(attr, value),其中attr是__init__中定义的类属性,value是属性取值,假如类中没有__init__或__init__只有一个self参数或__init__的所有参数都有默认值,则可用inst = classname()来创建实例。

除了给类附加各种属性的__init__外,还可以给类写各种方法,用于对类的属性进行操作,除第一个参数必须是self外,其他和定义普通函数一样,如

class classname:

    def  __init__(self, name, score):
        self.name = name
        self.score = score

    def print_class(self, temp):
        if temp == 'name':
            print(self.name)
        else:
            print(self.score)

调用类方法的时候不用在括号中写self,假如方法的参数仅有一个self则用inst.print_class()的形式调用,否则就在括号中写self后面的那些参数,如inst.print_class('score')。类定义的时候若一个类方法fun1调用了另一个类方法fun2则应当在fun1的函数体内部以self.fun2的形式调用fun2。

Numpy

多维数组对象

Numpy的基本对象为numpy数组,又称数组或ndarray,其每个元素都是同种类型(同质),多用于处理数值类数据。

数组的属性:每个数组都有一个shape属性和一个dtype属性,前者表示数组的每个维的长度(以元组形式呈现),后者表示数组元素的数据类型,除非所有元素都是整数值,否则数组的默认数据类型是float64。数组还有ndim属性,表示维数。还有size属性,表示数组元素个数。

创建数组:基本方式为将array函数作用在任意序列型对象上,用嵌套序列的方式创建多维数组。

全0数组和全1数组:有一些函数可以创建具有特定形状的特殊数组,如zeros和ones函数创建元素全为0和1的数组,二者需要作用在一个整数或整数组成的元组上(表示数组形状)。zeros_like和ones_like函数作用在ndarray上返回和参数数组形状相同的元素全为0或1的ndarray。

转换数组的数据类型:astype方法可以转换ndarray的数据类型,用这一方式原数组不会发生变化,而是会产生新的数组。

range的数组版本:arange函数,参数和range相同,但是返回的是ndarray。

数组算数:两个形状相同的数组之间的算数操作等于二者对应元素的算术操作,而一个标量和一个数组的算术操作等于这个标量和数组每个元素进行算数操作,形状相同的两个数组进行比较会产生一个形状与二者相同的布尔型数组。不同尺寸的数组之间的操作会用到广播特性。

转置:arr.transpose()和arr.T都能得到数组arr的转置,前者是方法,后者是属性。

排序:用数组的sort方法可将数组排序,注意排过之后原数组就变了。sort方法只能对向量进行排序,对高维数组,需要用axis=i传入被排序的轴。

针对一维数组的集合操作:unique返回数组的所有元素取值(排序后的),每个元素只列一次。in1d(arr1, arr2)检查数组arr1的值是否在数组arr2中,返回和第一个数组同等长度的布尔型数组。此外还有取交集的intersect1d、取并集的union1d等。

拼接矩阵:将矩阵A、B竖着拼接np.vstack((A, B)),横着拼接np.hstack((A, B))

import numpy as np

# 创建一维数组
arr1 = np.array([1, 2, 3])
print(arr1.shape)
>> (3,)
print(arr1.dtype)
>> dtype('int32')
print(arr1.ndim)
>> 1

# 创建多维数组
arr2 = np.array([[1,2.0], [3, 4]])
print(arr2.shape)
>> (2, 2)
print(arr2.dtype)
>> dtype('float64')
print(arr2.ndim)
>> 2

# 用函数创建特殊数组
arr3 = np.ones(10)
arr3 = np.ones((2,2))
arr3 = np.zeros((2,3,4))
print(arr3.shape)
>> (2, 3, 4)

# 转换数组数据类型
arr4 = np.array(['1', '2', '3'])
arr4.astype(np.float64)

# arange
print(np.arange(1, 10, 2))
>> [1 3 5 7 9]

# 排序
arr5 = np.random.randn(5)
arr5.sort()
arr6 = np.random.randn(2, 3)
arr6.sort(axis=1)

切片索引

数据的索引方式见前面的“索引”一节,需要注意的是单个索引是对原数组的值的复制,而切片索引是原数组的视图,我们令一个对象等于arr[i],改变这个对象的取值不会改变原数组,但是若令一个对象等于数组切片,则改变这个对象的元素时原数组也会相应变化,若我们希望得到数组切片的拷贝而非视图则应该使用数组的copy方法。

可用布尔值数组对数组进行索引,需要注意布尔值数组的长度必须和它索引的那个数组的轴索引长度一致。and和or关键字对布尔值数组无效,只能使用&和|。

import numpy as np

arr = np.arange(10)
print(arr)
>> [0 1 2 3 4 5 6 7 8 9]

a1 = arr[3]
a1 = -1
print(arr)
>> array([ 0,  1,  2, 3, 4, 5,  6,  7,  8,  9])

a2 = arr[3:6]
a2[:] = -1
print(arr)
>> array([ 0,  1,  2, -1, -1, -1,  6,  7,  8,  9])

a3 = arr[3:6].copy()
a3[:] = -1
print(arr)
>> array([ 0,  1,  2, 3, 4, 5,  6,  7,  8,  9])

逐元素数组函数

将函数作用到数组的每个元素上,常用的有:

import numpy as np

arr = np.random.randn(2, 2)

np.abs(arr)             # 对每个元素取绝对值
np.sqrt(arr)            # 对每个元素取平方根
np.exp(arr)             # 对每个元素x取e^x
np.log(arr)             # 对每个元素取log
np.sign(arr)            # 计算每个元素的符号值(正值变成1, 负值变成-1, 0变成0)
np.ceil(arr)            # 对每个元素取平方根
np.floor(arr)           # 对每个元素x取e^x

arr1 = np.random.randn(2, 2)
arr2 = np.random.randn(2, 2)

np.add(arr1, arr2)      # 对应位置元素求和
np.maximum(arr1, arr2)  # 对应位置元素求最大值
np.minimum(arr1, arr2)  # 对应位置元素求最小值

np.where

三元函数,第一个参数是布尔型数组代表条件,第二、三个参数是标量或与第一个参数形状相同的数组,条件成立时返回第二个参数否则返回第三个参数。

此函数也可以仅传入一个参数即条件,此时返回的结果是数组中符合条件的元素的坐标。

import numpy as np

arr = np.random.randn(3, 3)

np.where(arr>0, 1, -1)   # 将数组中大于0的元素变成1, 小于等于0的变成-1
np.where(arr>0, 1, arr)  # 将数组中大于0的元素变成1, 其他不变
np.where(arr>0)          # 返回arr中大于0的元素的坐标

统计学数组方法

常用的有sum、mean、min、max、argmin、argmax、std、var等(标准差和方差默认分母是n,可以设置成n-1),这些数组方法可以不传入参数,此时是对整个数组进行操作返回一个标量(如arr.sum()返回整个数组的和),也可以传入axis=i参数,表示对哪一个轴进行函数操作。

布尔值数组

布尔值数组在上节的方法中True和False分别作为1和0处理,因此可以用(arr>0).sum()这样的操作计算True的个数。布尔值数组有两个常用方法any和all,现有布尔值数组arrbool,用arrbool.any()和arrbool.all()返回其中是否含有True和是否全是True。

数组文件的读取和保存

使用np.save和np.load函数,前者的参数是文件名和数组,或者的参数是文件名。此时数组文件后缀名是npy,用这种方式保存的数组读取后仍然是原来的数组。

可以使用np.savez将多个数组保存进同一个文件中,此时数组文件后缀名是npz。加载此数组文件会得到一个字典对象,用字典的键来得到之前的数组。可用np.savez_compressed函数将数组存入已经存在的npz文件中。

import numpy as np

arr = np.random.randn(3, 3)

np.save('array', arr)  # 此处文件名可以写后缀.npy也可不写
np.load('array.npy')

arr1 = np.random.randn(3, 3)
arr2 = np.random.randn(2, 3)

np.savez('arrayz', a1=arr1, a2=arr2)
a = np.load('arrayz.npz')
a['a1']

arr3 = np.random.randn(2, 2)
np.savez_compressed('arrayz.npz', a3=arr3)

线性代数运算

矩阵乘法用数组方法dot或函数np.dot实现。

计算矩阵行列式np.det。

np.diag以一维数组的形式返回矩阵对角元素,或将一维序列转换成以此为对角线的矩阵。

另有求逆函数inv、求特征值函数eig、QR分解函数qr、求迹函数trace、解线性方程组函数solve等。

Pandas

处理表格型或异质型数据,基本对象为series和dataframe。

Series

一维序列型对象,用pd.Series(值序列, 标签序列)建立,Series有两个重要属性即值(values)和索引(index)。字典可以直接转化为Series。

可以使用标签来索引,也可以使用值的条件表达式进行布尔值索引。注意python中大部分索引是不包含尾部的,但是Series索引包含。

两个Series对象进行操作会自动对齐。

import pandas as pd

# 创建
series1 = pd.Series([1, 2, 3, 4], index=['a', 'b', 'c', 'd'])  # 标准创建方式
series2 = pd.Series([1, 2, 3, 4])                              # 不写index时默认其为range(5)
series3 = pd.Series({'a': 1, 'b': 2, 'c': 3, 'd': 4})          # 用字典创建

# 两个重要属性
print(series1.values)
>> [1 2 3 4]
print(series1.index)
>> Index(['a', 'b', 'c', 'd'], dtype='object')
series3.index = ['a', '2', '3', 4]

# 索引
print(series1[['a', 'c']])                                     # 标签索引
>> a    1
   c    3
   dtype: int64
print(series1[series1>2])                                      # 布尔值索引
>> c    3
   d    4
   dtype: int64

# 对齐
series4 = pd.Series([1.5, 1, 3, 6], index=['f', 'b', 'c', 'g'])
print(series1 + series4)
>> a    NaN
   b    3.0
   c    6.0
   d    NaN
   f    NaN
   g    NaN
   dtype: float64

DataFrame

数据框是矩阵形式的数据表,不同列可以是不同的数据类型,既有行索引也有列索引。

建立数据框的方法很多,常用pd.DataFrame函数作用在一个字典data上建立,需要注意的是其每列需要等长。创建时可选columns参数指定列的顺序,若传入columns的参数中有不是data中的键的则数据框中还会有这一列但是全部是缺失值。可选index参数指定每一行的行名称(索引)。

数据框有index和column属性(相当于行名和列名)。数据框的values属性会以二维ndarray的形式返回数据框内容。

head方法返回数据框前五行。

用DataFrame["列名"]或DataFrame.列名的方式索引一列,可以列索引的方式给列赋值,若该索引不存在,则创建新列(但DataFrame.列名不能用于创建新列),注意数据框索引是视图,因此原数据的变化会体现在数据框上,若不想这样则应当使用copy方法。

用del关键字删除列。

DataFrame.T是数据框转置。

用loc方法按行、列名索引,用iloc方法按行序号和列序号索引。

两个DataFrame对象进行操作会自动对齐,行和列都会对齐。

import pandas as pd

# 创建
data = {'c1': [1,2,3,4,5], 'c3': ['a', 'b', 'c', 'd', 'e'], 'c2': range(5)}
cols = ['c1', 'c2', 'c3']
rows = range(1, 6)
frame = pd.DataFrame(data, columns=cols, index=rows)

# loc, iloc索引, 注意后面是中括号
frame.loc[1:3, 'c1':'c3']
frame.iloc[0:3, [0,2]]

常用操作

reindex方法重排数据框或Series(将行顺序重排),创建新对象而不改变原始对象。之前不存在的行或列会被创建成缺失值。对数据框该方法可通过index和columns参数重排行和列。

drop方法根据行名或列名删除条目,通过axis=1或axis="columns"参数来选择删除列,不写axis参数则默认删除行。drop方法默认创建新的对象而不改变原始对象,若想直接改变原对象则添加inplace=True参数。

import pandas as pd

# 创建
data = {'c1': [1,2,3,4,5], 'c3': ['a', 'b', 'c', 'd', 'e'], 'c2': range(5)}
cols = ['c1', 'c2', 'c3']
rows = range(1, 6)
frame = pd.DataFrame(data, columns=cols, index=rows)

# reindex调整顺序
frame.reindex([2,1,3,5,4])
frame.reindex(columns=['c3','c1','c2'])

# drop删除条目
frame.drop(5)                # 等价于frame.drop(5, axis=0)
frame.drop(5, inplace=True)  # 更改原数据框
frame.drop('c1', axis=1)

Scikit-learn (sklearn)

Python的第三方机器学习库。

示例数据集

Scikit-learn提供了一些标准数据集作为示例,包括波士顿房价数据集,鸢尾花数据集,手写数字数据集(包含了1797个手写数字的图像数据)等。

from sklearn import datasets

# 鸢尾花数据集
iris = datasets.load_iris()                    # 数据集iris, 字典形式
print(iris.keys())
>> dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename', 'data_module'])
iris_data = iris['data']                       # 150×4的二维ndarray, 行代表样品, 列代表特征
iris_target = iris['target']                   # 150长度的一维ndarray, 因变量(花的种类)

# 波士顿房价数据集从某一版本开始被移除, 用另一个房价数据集代替
housing = datasets.fetch_california_housing()  # 房价数据集, 字典形式
print(housing.keys())
>> dict_keys(['data', 'target', 'frame', 'target_names', 'feature_names', 'DESCR'])
housing_data = housing['data']                 # 20640×8的二维ndarray, 行代表样品, 列代表特征
housing_target = housing['target']             # 20640长度的一维ndarray, 因变量(房价)

# 手写数字数据集
digits = datasets.load_digits()                # 数据集digits, 字典形式
print(digits.keys())
>> dict_keys(['data', 'target', 'frame', 'feature_names', 'target_names', 'images', 'DESCR'])
digits_data = digits['data']                   # 1797×64的二维ndarray, 行代表样品, 列代表特征
digits_target = digits['target']               # 1797长度的一维ndarray, 因变量(0-9之间的数字)

交叉验证数据集划分

sklearn中用sklearn.model_selection.train_test_split函数划分数据集,参数首先是需要划分的数据集,可以是多个。其次是test_size或train_size,二者相加为1因此指定其一即可。random_state指定随机数种子保证结果可重复性,可以不指定。shuffle指定是否将数据集顺序打乱再划分,默认True。stratify表示是否按照某个因素进行分层抽样。返回结果是列表,存储被划分后的数据集。

import numpy as np
from sklearn.model_selection import train_test_split

X = np.random.randn(5,3,2)
y = np.array(range(0, 5))
z = np.random.rand(5)

result = train_test_split(X, y, z, test_size=0.4, random_state=1)
X_train, X_val, y_train, y_val, z_train, z_val = result

注意以上方法只能将数据集划分成两部分,若要进行k折交叉验证则需要重复用上述函数划分k-1次。第一次对原始数据集进行划分,test_size=1/k,得到了训练集1(占原数据集的(k-1)/k)和测试集1,第二次对训练集1进行划分,test_size=1/(k-1),得到训练集2(占原数据集的(k-2)/k)和测试集2,依此类推,直到第k-1次对训练集k-2进行划分,test_size=1/2,得到训练集k-1(同时也是测试集k-1),每个测试集都占原数据集的1/k。

train_test_split返回的是划分后的数据而不会返回划分的索引,若希望得到划分索引则应当以索引定义一个数组,作为train_test_split的参数传入。

  • 10
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值