1 问题描述
因为我在进行网络训练时使用到了伪标签,又鉴于打乱数据集能减轻过拟合,提升模型性能。从而将dataloader加载数据集时shuffle设置为True,但又要按照相同顺序打乱伪标签,试了多种方法均失败。
2 解决办法
将state = np.random.get_state()放在训练代码的随机种子之后和for epoch循环中的最后一行,将np.random.set_state(state)放在打乱标签的前一行。
3 证明
下面是实验代码:(目的是使a和b的顺序相同,即说明打乱顺序相同)
import numpy as np
np.random.seed(42)
state = np.random.get_state()#getstate函数
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
b = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
c = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
d = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
for i in range(0, 10):
np.random.shuffle(a)
np.random.shuffle(c)
np.random.set_state(state)#setstate函数
np.random.shuffle(b)
np.random.shuffle(d)
print(a)
print(c)
print(b)
print(d)
print(" ", i)
state = np.random.get_state()#getstate函数
最上面的getstate函数用来捕捉for循环内第一次shuffle;for循环内第一个setstate用来加载;而第二个getstate用来捕捉for循环内第一次shuffle,这是考虑到在数据集设置为shuffle时,数据会在循环开始之前被打乱,故将getstate放在for循环内代码的最后一行。
下图是执行结果,可以看到a与b相同,c与d相同,也就是getstate之后第一次打乱与setstate的第一次打乱相同,第二次与第二次相同。