python-按相同顺序打乱伪标签和数据集-np.random.get_state()和np.random.set_state()函数

1 问题描述

        因为我在进行网络训练时使用到了伪标签,又鉴于打乱数据集能减轻过拟合,提升模型性能。从而将dataloader加载数据集时shuffle设置为True,但又要按照相同顺序打乱伪标签,试了多种方法均失败。

2 解决办法

        将state = np.random.get_state()放在训练代码的随机种子之后和for epoch循环中的最后一行,将np.random.set_state(state)放在打乱标签的前一行。

3 证明

        下面是实验代码:(目的是使a和b的顺序相同,即说明打乱顺序相同)

import numpy as np
np.random.seed(42)
state = np.random.get_state()#getstate函数
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
b = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
c = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
d = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
for i in range(0, 10):
    np.random.shuffle(a)
    np.random.shuffle(c)
    np.random.set_state(state)#setstate函数
    np.random.shuffle(b)
    np.random.shuffle(d)
    print(a)
    print(c)
    print(b)
    print(d)
    print(" ", i)
    state = np.random.get_state()#getstate函数

最上面的getstate函数用来捕捉for循环内第一次shuffle;for循环内第一个setstate用来加载;而第二个getstate用来捕捉for循环内第一次shuffle,这是考虑到在数据集设置为shuffle时,数据会在循环开始之前被打乱,故将getstate放在for循环内代码的最后一行。

下图是执行结果,可以看到a与b相同,c与d相同,也就是getstate之后第一次打乱与setstate的第一次打乱相同,第二次与第二次相同。

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值