idx_train、idx_val、idx_test 的生成：将 0–29 这 30 个索引随机划分为训练集（10 个）、验证集（10 个）和测试集（其余 10 个）。
import torch,random
import numpy as np
def split_indices(num_nodes=30, num_train=10, num_val=10, seed=None):
    """Randomly split the indices ``0..num_nodes-1`` into train/val/test sets.

    ``num_train`` indices are sampled for training. Then ``num_val`` indices are
    sampled from the remaining ones (kept in ascending order) for validation.
    Every index left over goes to the test set, in ascending order.

    Args:
        num_nodes: total number of indices to split.
        num_train: size of the training split.
        num_val: size of the validation split.
        seed: optional seed for a reproducible split. ``None`` draws from the
            global ``random`` module state, exactly like the original script.

    Returns:
        Tuple ``(idx_train, idx_val, idx_test)`` of int64 (long) tensors.

    Raises:
        ValueError: if any size is negative or ``num_train + num_val`` exceeds
            ``num_nodes``.
    """
    if min(num_nodes, num_train, num_val) < 0:
        raise ValueError("num_nodes, num_train and num_val must be non-negative")
    if num_train + num_val > num_nodes:
        raise ValueError(
            f"num_train + num_val ({num_train + num_val}) exceeds num_nodes ({num_nodes})"
        )

    # Use the module-level RNG when no seed is given, so the sequence of
    # random.sample calls matches the original script exactly.
    rng = random if seed is None else random.Random(seed)

    all_idx = list(range(num_nodes))
    idx_train = rng.sample(all_idx, num_train)

    # Keep indices as Python ints throughout. This avoids the float64 detour
    # that np.hstack([[], ...]) caused in the original code.
    taken = set(idx_train)
    remaining = [i for i in all_idx if i not in taken]
    idx_val = rng.sample(remaining, num_val)

    taken.update(idx_val)
    idx_test = [i for i in all_idx if i not in taken]

    return (
        torch.tensor(idx_train, dtype=torch.long),
        torch.tensor(idx_val, dtype=torch.long),
        torch.tensor(idx_test, dtype=torch.long),
    )


if __name__ == '__main__':
    idx_train, idx_val, idx_test = split_indices(30, 10, 10)
    print(idx_train, idx_val, idx_test)
示例输出（划分是随机的，每次运行结果都不同）：
tensor([ 0, 6, 17, 12, 2, 28, 13, 21, 5, 4]) tensor([20, 26, 24, 27, 1, 25, 19, 11, 9, 16]) tensor([ 3, 7, 8, 10, 14, 15, 18, 22, 23, 29])