import random

import numpy as np
import torch
from torch import nn
from torchvision import transforms
# Number of slots in the per-label lookup table. Every value in `label` must be
# < NUM_CLASSES (it is used as a list index). Presumably the dataset's class
# count -- confirm.
NUM_CLASSES = 98
# Iterations of the sampling loop; 128 is two passes over the 64 labels below.
NUM_ITERS = 128

label = torch.tensor([56, 91, 89, 0, 62, 34, 21, 67, 88, 12, 70, 90, 18, 60, 9, 37, 18, 14,
                      28, 14, 56, 75, 50, 39, 68, 77, 4, 7, 67, 57, 97, 70, 4, 18, 94, 12,
                      94, 92, 40, 85, 88, 37, 7, 14, 86, 81, 63, 3, 44, 16, 78, 68, 37, 33,
                      86, 0, 73, 85, 50, 57, 37, 59, 14, 37])

# For every label value c present in `label`, store the positions in `label`
# whose value is NOT c. Label values that never occur keep an empty list.
# NOTE(review): despite the name, class_images[c] holds indices of *other*
# labels (the comparison is `!=`), which suits negative sampling; switch to
# `==` if same-class indices were intended.
label_np = label.numpy()  # convert once rather than once per label value
class_images = [[] for _ in range(NUM_CLASSES)]
for c in label.unique().tolist():  # derived from the data, not a hand-copied list
    class_images[c] = np.flatnonzero(label_np != c)
print(class_images[89])

# Each iteration takes the label at position i (wrapping around `label`),
# fetches its candidate index array and draws one candidate at random.
# Only the final iteration's values are printed.
num_labels = len(label)
for i in range(NUM_ITERS):
    units = class_images[int(label[i % num_labels])]
    units_random = [random.choice(units)]
print(units)
print(units_random)
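# Output:
# print(class_images[89]) -- every index except 2, the only position labelled 89:
# [ 0 1 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
#  25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
#  49 50 51 52 53 54 55 56 57 58 59 60 61 62 63]
# print(units) -- last iteration (label 37): every index except 15, 41, 52, 60, 63:
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 18 19 20 21 22 23 24
#  25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 42 43 44 45 46 47 48 49
#  50 51 53 54 55 56 57 58 59 61 62]
# print(units_random) -- a random pick, so this value varies between runs:
# [19]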