# -*- coding: utf-8 -*-
import numpy as np
import copy
import torch
import torch.nn as nn
import torch
import torchvision
import torchvision.transforms as transforms
import os

# Mini-batch size used by both CIFAR-10 loaders below.
BATCH_SIZE = 128
# Directory that already contains the CIFAR-10 data (download=False below).
# Override with the CIFAR10_ROOT environment variable instead of editing code.
DATA_ROOT = os.environ.get('CIFAR10_ROOT', '/Users/dengxq/Downloads')

# Training-time augmentation: random 32x32 crop from a 4-pixel padded image,
# random horizontal flip, then per-channel normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.2, 0.2, 0.2)),
])
# Test-time preprocessing: same normalisation, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.2, 0.2, 0.2)),
])

trainset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=False, transform=transform_train)
# Shuffle so that each epoch presents the GA with different mini-batches
# instead of re-evaluating fitness on an identical batch sequence.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
class Net(torch.nn.Module):
    """Convolutional classifier used as one individual of the GA population.

    Args:
        layers: optional feature extractor to use instead of the default
            convolutional stack (previously this argument was ignored).
        fc: optional classifier head to use instead of ``nn.Linear(512, 10)``
            (previously this argument was ignored).
    """

    def __init__(self, layers=None, fc=None):
        super().__init__()
        # Build defaults in the same order as before (layers, then fc) so the
        # default-initialised weights draw from the RNG identically.
        self.layers = layers if layers is not None else self._default_layers()
        self.fc = fc if fc is not None else nn.Linear(512, 10)

    @staticmethod
    def _default_layers():
        """Return the default conv stack; a 3x32x32 input yields a 512x1x1 map."""
        return torch.nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),          # 32x32 -> 16x16
            nn.Dropout(p=0.05, inplace=False),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),          # 16x16 -> 8x8
            nn.Dropout(p=0.05, inplace=False),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),          # 8x8 -> 4x4
            nn.Dropout(p=0.05, inplace=False),
            nn.Conv2d(256, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.AvgPool2d(4, 4),          # 4x4 -> 1x1
        )

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        x = self.layers(x)
        # Flatten per sample. Unlike view(-1, 512), a wrong feature size now
        # surfaces as a shape error in fc instead of corrupting the batch dim.
        x = torch.flatten(x, 1)
        return self.fc(x)

    def set_layer(self, layers, fc):
        """Replace the feature extractor and the classifier head."""
        self.layers = layers
        self.fc = fc
class ElitismGA:
    """Elitist genetic algorithm that evolves a population of ``Net`` models.

    No gradient training is done. Each training mini-batch is one generation:
    every individual is scored on the batch, the ``elite_num`` most accurate
    ones are carried over unchanged, the remaining slots are filled with
    crossover children of roulette-wheel-selected parents, and each non-elite
    child is mutated with probability ``p_mutation``.

    NOTE(review): models are never switched to ``.eval()``, so fitness and
    test accuracy are both measured in training mode (Dropout active,
    BatchNorm using batch statistics and updating its running stats).
    Confirm this is intended.
    """

    def __init__(self, _pop_size, _r_mutation, _p_mutation,
                 _epochs, _elite_num, _mating_pool_size, _batch_size=32):
        """
        Args:
            _pop_size: number of individuals in the population.
            _r_mutation: for an individual chosen for mutation, probability
                that each of its Conv2d layers receives Gaussian weight noise.
            _p_mutation: probability that a non-elite child is mutated.
            _epochs: number of passes over ``trainloader`` in ``train``.
            _elite_num: number of best individuals copied unchanged into the
                next generation.
            _mating_pool_size: number of roulette-wheel draws that form the
                mating pool each generation.
            _batch_size: stored as ``self.batch_size`` but not read by this
                class; the data loaders determine the real batch size.
        """
        self.pop_size = _pop_size
        self.r_mutation = _r_mutation
        self.p_mutation = _p_mutation
        self.epochs = _epochs
        self.elite_num = _elite_num
        self.mating_pool_size = _mating_pool_size
        self.batch_size = _batch_size
        self.chroms = []                # current population
        self.evaluation_history = []    # one summary dict per generation
        self.stddev = 0.5               # std of the mutation noise
        self.criterion = nn.CrossEntropyLoss()
        self.model = None               # best individual of the last evaluation

    def initialization(self):
        """Create a fresh population of ``pop_size`` randomly initialised ``Net``s."""
        # Rebuild instead of appending so repeated train() calls do not grow
        # the population beyond pop_size.
        self.chroms = [Net() for _ in range(self.pop_size)]
        print('network initialization({}) finished.'.format(self.pop_size))

    def train(self):
        """Evolve the population over ``epochs`` passes of ``trainloader``."""
        print('Elitism GA is training...')
        self.initialization()
        with torch.no_grad():  # no gradients are needed for a GA
            for _ in range(self.epochs):
                for batch_x, batch_y in trainloader:
                    evaluation_result = self.evaluation(batch_x, batch_y, False)
                    self.selection(evaluation_result)

    def test(self):
        """Return (and print) the best model's accuracy in percent on ``testloader``.

        Raises:
            RuntimeError: if no model has been selected yet (``train`` not run).
        """
        if self.model is None:
            raise RuntimeError('test() called before train(): no model has been selected yet.')
        print('------ Test Start -----')
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                output = self.model(images)
                _, predicted = torch.max(output, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print('Accuracy of the model is: %.4f %%' % accuracy)
        return accuracy

    def selection(self, evaluation_result):
        """Replace the population with the next generation.

        Elites (highest ``train_acc``) go first, unchanged. Crossover children
        of mating-pool parents fill the rest. Then every non-elite slot is
        mutated with probability ``p_mutation``.

        Raises:
            ValueError: if children are needed but the mating pool is empty.
        """
        sorted_evaluation = sorted(evaluation_result, key=lambda e: e['train_acc'])
        # Explicit start index: the old ``[-elite_num:]`` returned the WHOLE
        # list when elite_num == 0, silently disabling crossover.
        n_elites = max(0, min(self.elite_num, len(sorted_evaluation)))
        elites = [e['pop'] for e in sorted_evaluation[len(sorted_evaluation) - n_elites:]]
        print('Elites: {}'.format(elites))
        children = [self.chroms[i] for i in elites]
        mating_pool = np.array([self.roulette_wheel_selection(evaluation_result)
                                for _ in range(self.mating_pool_size)])
        if len(children) < self.pop_size and len(mating_pool) == 0:
            raise ValueError('mating_pool_size must be positive when elite_num < pop_size')
        pairs = []
        print('mating_pool')
        print(mating_pool)
        while len(children) < self.pop_size:
            pair = [np.random.choice(mating_pool) for _ in range(2)]
            pairs.append(pair)
            children.append(self.crossover(pair))
        print('Pairs: {}'.format(pairs))
        print('Cross over finished.')
        self.replacement(children)
        # Elites occupy indices [0, n_elites) and are never mutated.
        for i in range(n_elites, self.pop_size):
            if np.random.rand() < self.p_mutation:
                self.chroms[i] = self.mutation(i)

    def crossover(self, _selected_pop):
        """Return a child of the two individuals indexed by ``_selected_pop``.

        The child is a deep copy of the first parent in which each Conv2d
        layer is replaced, with probability 0.5, by the matching Conv2d of a
        deep copy of the second parent. Every other layer and ``fc`` come from
        the first parent. Equal indices yield a plain deep copy.
        """
        if _selected_pop[0] == _selected_pop[1]:
            return copy.deepcopy(self.chroms[_selected_pop[0]])
        chrom1 = copy.deepcopy(self.chroms[_selected_pop[0]])
        chrom2 = copy.deepcopy(self.chroms[_selected_pop[1]])
        for idx, (layer1, layer2) in enumerate(zip(chrom1.layers, chrom2.layers)):
            # One draw per Conv2d, same order/threshold as before: keep
            # parent 1's layer when rand < 0.5, otherwise take parent 2's.
            if isinstance(layer1, nn.Conv2d) and np.random.rand() >= 0.5:
                chrom1.layers[idx] = layer2
        return chrom1

    def mutation(self, _selected_pop):
        """Return a mutated deep copy of individual ``_selected_pop``.

        Each Conv2d layer is, with probability ``r_mutation``, perturbed by
        adding N(0, ``stddev``) noise to its weight (bias is untouched).
        """
        chrom = copy.deepcopy(self.chroms[_selected_pop])
        for layer in chrom.layers:
            if isinstance(layer, nn.Conv2d) and np.random.rand() < self.r_mutation:
                weight = layer.weight
                noise = np.random.normal(0, self.stddev, tuple(weight.shape)).astype(np.float32)
                # In-place update keeps the same Parameter and works for any
                # device, unlike round-tripping through .numpy().
                with torch.no_grad():
                    weight.add_(torch.from_numpy(noise).to(device=weight.device, dtype=weight.dtype))
        print('Mutation({}) finished.'.format(_selected_pop))
        return chrom

    def replacement(self, _child):
        """Overwrite the population in place with ``_child``."""
        self.chroms[:] = _child
        print('Replacement finished.')

    def evaluation(self, batch_x, batch_y, _is_batch=True):
        """Score the first ``pop_size`` individuals on one mini-batch.

        Appends a summary to ``evaluation_history`` and sets ``self.model`` to
        the most accurate individual (ties go to the highest index).
        ``_is_batch`` is accepted for compatibility and is not used.

        Returns:
            list of dicts ``{'pop', 'train_loss', 'train_acc'}`` (rounded to 4
            decimals), one per individual, in population order.
        """
        cur_evaluation = []
        with torch.no_grad():
            for i in range(self.pop_size):
                output = self.chroms[i](batch_x)
                train_loss = self.criterion(output, batch_y).item()
                _, predicted = torch.max(output, 1)
                correct = (predicted == batch_y).sum().item()
                train_acc = 100 * correct / batch_y.size(0)
                cur_evaluation.append({
                    'pop': i,
                    'train_loss': round(train_loss, 4),
                    'train_acc': round(train_acc, 4),
                })
        # Stable sort + last element => highest accuracy, last index on ties.
        best_fit = sorted(cur_evaluation, key=lambda e: e['train_acc'])[-1]
        self.evaluation_history.append({
            'iter': len(self.evaluation_history) + 1,
            'best_fit': best_fit,
            'avg_fitness': np.mean([e['train_acc'] for e in cur_evaluation]).round(4),
            'evaluation': cur_evaluation,
        })
        latest = self.evaluation_history[-1]
        print('\nIter: {}'.format(latest['iter']))
        print('Best_fit: {}, avg_fitness: {:.4f}'.format(latest['best_fit'], latest['avg_fitness']))
        self.model = self.chroms[best_fit['pop']]
        return cur_evaluation

    def roulette_wheel_selection(self, evaluation_result):
        """Draw one population index with probability proportional to ``train_acc``.

        When every accuracy is zero the wheel has no area. In that case draw
        uniformly instead of always returning the last individual.
        """
        sorted_evaluation = sorted(evaluation_result, key=lambda e: e['train_acc'])
        cum_acc = np.cumsum([e['train_acc'] for e in sorted_evaluation])
        if cum_acc[-1] <= 0:
            return sorted_evaluation[np.random.randint(len(sorted_evaluation))]['pop']
        rand = np.random.rand() * cum_acc[-1]
        for e, acc in zip(sorted_evaluation, cum_acc):
            if rand < acc:
                return e['pop']
        return sorted_evaluation[-1]['pop']
if __name__ == '__main__':
    # Script entry point: evolve the population, then report test accuracy.
    g = ElitismGA(
        _pop_size=100,
        _p_mutation=0.1,
        _r_mutation=0.1,
        _epochs=20,
        _elite_num=20,
        _mating_pool_size=40,
        # Keep the recorded batch size in sync with the data loaders, which
        # are built with BATCH_SIZE (previously 32 was recorded while the
        # loaders actually used BATCH_SIZE).
        _batch_size=BATCH_SIZE,
    )
    g.train()
    g.test()
自动下载样本集时出错，因此代码中使用 download=False，从本地目录读取 CIFAR-10 数据集。
请参考：（原文此处的参考链接缺失）
运行结果：
Elitism GA is training...
network initialization(100) finished.
Iter: 1
Best_fit: {'pop': 99, 'train_loss': 2.2962, 'train_acc': 17.1875}, avg_fitness: 10.1484
Elites: [25, 33, 36, 70, 87, 22, 84, 88, 90, 50, 58, 69, 1, 29, 38, 75, 0, 48, 96, 99]
mating_pool
[45 0 48 8 58 46 58 50 85 58 65 40 43 51 25 39 21 93 50 13 11 49 96 38
35 90 80 46 26 90 9 73 82 55 69 93 30 2 28 99]
Pairs: [[46, 11], [50, 0], [38, 39], [85, 2], [48, 9], [90, 69], [48, 13], [90, 26], [21, 28], [43, 28], [48, 58], [96, 51], [55, 50], [26, 25], [8, 58], [9, 35], [73, 2], [45, 50], [35, 90], [96, 58], [2, 13], [69, 65], [45, 50], [2, 46], [90, 90], [38, 99], [93, 50], [99, 46], [85, 90], [65, 58], [46, 58], [90, 9], [46, 73], [46, 90], [49, 58], [13, 58], [21, 80], [0, 13], [80, 48], [26, 9], [46, 43], [2, 46], [11, 85], [58, 46], [48, 58], [9, 51], [49, 50], [65, 50], [45, 46], [48, 90], [65, 40], [69, 55], [46, 49], [51, 9], [73, 82], [28, 50], [55, 58], [35, 38], [90, 2], [39, 65], [55, 45], [85, 69], [9, 93], [58, 21], [93, 55], [0, 58], [65, 99], [50, 13], [21, 11], [11, 69], [93, 11], [55, 80], [2, 90], [2, 48], [43, 26], [0, 99], [50, 28], [35, 96], [9, 30], [90, 85]]
Cross over finished.
Replacement finished.
Mutation(26) finished.
Mutation(28) finished.
Mutation(36) finished.
Mutation(46) finished.
Mutation(62) finished.
Iter: 2
Best_fit: {'pop': 83, 'train_loss': 2.2934, 'train_acc': 20.3125}, avg_fitness: 10.2188
Elites: [50, 14, 48, 60, 65, 2, 30, 35, 72, 82, 42, 9, 17, 18, 37, 43, 81, 68, 75, 83]
mating_pool
[50 81 23 85 33 86 22 64 15 44 14 40 35 30 15 73 20 87 50 37 99 84 9 7
2 56 30 72 56 75 0 19 28 17 49 47 37 2 76 42]
Pairs: [[9, 0], [72, 20], [9, 64], [64, 99], [35, 23], [2, 23], [50, 15], [72, 30], [81, 28], [42, 44], [15, 40], [33, 20], [20, 37], [44, 15], [35, 86], [84, 19], [0, 99], [7, 2], [56, 42], [50, 19], [50, 72], [87, 14], [35, 99], [84, 2], [20, 30], [37, 7], [9, 40], [37, 84], [15, 50], [15, 50], [30, 64], [75, 86], [81, 87], [47, 9], [76, 50], [14, 75], [37, 19], [9, 75], [37, 17], [72, 7], [72, 2], [19, 73], [37, 56], [44, 14], [84, 50], [15, 17], [86, 49], [19, 42], [15, 81], [30, 37], [72, 19], [30, 15], [20, 40], [2, 64], [35, 56], [81, 87], [49, 73], [9, 49], [15, 28], [0, 47], [50, 30], [47, 56], [44, 7], [85, 50], [56, 73], [84, 23], [17, 42], [30, 99], [22, 30], [23, 15], [81, 42], [49, 87], [73, 2], [73, 15], [50, 44], [73, 14], [17, 9], [20, 50], [56, 75], [75, 73]]
Cross over finished.
Replacement finished.
Mutation(22) finished.
Mutation(29) finished.
Mutation(40) finished.
Mutation(45) finished.
Mutation(53) finished.
Mutation(62) finished.
Mutation(63) finished.
Mutation(66) finished.
Mutation(74) finished.
Mutation(75) finished.
Mutation(84) finished.
Mutation(91) finished.
Mutation(97) finished.
Iter: 3
Best_fit: {'pop': 11, 'train_loss': 2.3081, 'train_acc': 17.1875}, avg_fitness: 11.1172
Elites: [35, 43, 46, 80, 89, 1, 5, 8, 24, 55, 87, 6, 22, 40, 42, 88, 12, 73, 13, 11]
mating_pool
[68 8 94 17 57 79 59 10 36 9 8 99 35 13 79 34 46 87 85 21 82 85 71 52
61 87 83 45 32 8 68 84 24 1 92 89 5 24 65 31]
Pairs: [[89, 36], [82, 46], [8, 46], [68, 79], [17, 85], [65, 99], [57, 46], [1, 32], [79, 17], [83, 79], [82, 13], [92, 71], [94, 32], [65, 92], [83, 24], [1, 99], [35, 87], [24, 8], [79, 10], [85, 68], [34, 85], [61, 34], [61, 83], [82, 13], [61, 13], [8, 82], [8, 8], [79, 89], [79, 1], [5, 1], [65, 8], [68, 8], [79, 83], [8, 9], [57, 79], [57, 36], [85, 99], [79, 17], [1, 92], [13, 34], [36, 84], [85, 10], [65, 46], [85, 36], [85, 79], [61, 89], [71, 45], [36, 52], [24, 46], [94, 35], [71, 57], [24, 45], [52, 71], [87, 32], [5, 83], [71, 17], [71, 83], [94, 71], [10, 24], [79, 99], [24, 65], [87, 46], [8, 5], [82, 8], [84, 79], [68, 46], [94, 10], [46, 9], [9, 21], [10, 13], [8, 10], [59, 36], [89, 36], [59, 85], [68, 8], [5, 85], [8, 84], [89, 79], [8, 85], [21, 59]]
Cross over finished.
Replacement finished.
Mutation(58) finished.
Mutation(64) finished.
Iter: 4
Best_fit: {'pop': 17, 'train_loss': 2.287, 'train_acc': 18.75}, avg_fitness: 10.3750
Elites: [7, 8, 11, 36, 54, 61, 75, 76, 4, 92, 28, 67, 95, 1, 10, 47, 49, 52, 18, 17]
mating_pool
[10 6 10 2 83 8 48 16 3 1 47 29 92 6 51 33 2 64 33 78 74 69 55 27
78 62 25 49 6 15 28 10 2 38 99 33 51 24 30 49]
Pairs: [[51, 27], [1, 16], [2, 10], [64, 74], [74, 30], [51, 2], [28, 51], [47, 28], [33, 27], [10, 6], [30, 78], [83, 2], [83, 1], [48, 25], [10, 8], [99, 48], [16, 6], [99, 49], [78, 99], [29, 27], [2, 6], [25, 49], [64, 8], [10, 6], [25, 49], [10, 83], [3, 30], [62, 49], [51, 1], [64, 2], [28, 55], [74, 69], [10, 64], [2, 30], [15, 15], [29, 10], [69, 33], [74, 51], [55, 8], [78, 51], [74, 24], [99, 10], [24, 16], [49, 51], [6, 6], [49, 78], [2, 24], [10, 92], [33, 78], [33, 48], [83, 3], [30, 10], [30, 62], [55, 64], [55, 51], [33, 30], [6, 78], [92, 33], [27, 30], [1, 6], [99, 6], [33, 49], [51, 47], [8, 47], [10, 1], [6, 49], [6, 92], [29, 49], [10, 6], [78, 51], [48, 10], [49, 3], [64, 51], [16, 55], [47, 10], [55, 30], [2, 6], [51, 33], [69, 33], [78, 29]]
Cross over finished.
Replacement finished.
Mutation(26) finished.
Mutation(38) finished.
Mutation(43) finished.
Mutation(45) finished.
Mutation(51) finished.
Mutation(86) finished.
Mutation(89) finished.
Mutation(93) finished.
Mutation(98) finished.
Iter: 5
Best_fit: {'pop': 18, 'train_loss': 2.2814, 'train_acc': 22.6562}, avg_fitness: 10.5703
Elites: [17, 29, 39, 52, 54, 56, 84, 98, 14, 15, 22, 46, 60, 73, 87, 88, 86, 45, 91, 18]
mating_pool
[22 9 80 1 62 48 94 51 63 24 66 74 69 58 54 23 31 52 51 42 8 95 26 51
4 44 18 80 6 11 59 27 86 39 97 74 13 91 75 24]
Pairs: [[51, 62], [59, 63], [86, 27], [58, 11], [94, 23], [39, 69], [4, 39], [91, 39], [94, 86], [69, 24], [26, 6], [94, 52], [94, 24], [26, 59], [1, 39], [9, 51], [31, 51], [74, 95], [86, 80], [74, 9], [9, 23], [51, 6], [75, 91], [13, 31], [66, 59], [44, 95], [1, 69], [18, 6], [27, 75], [44, 9], [44, 4], [51, 75], [6, 80], [80, 80], [4, 4], [97, 80], [75, 13], [66, 52], [97, 94], [51, 11], [6, 51], [13, 51], [24, 48], [24, 54], [62, 80], [31, 39], [74, 18], [74, 48], [51, 86], [22, 97], [24, 48], [1, 13], [66, 66], [24, 74], [22, 22], [44, 94], [66, 66], [22, 24], [4, 22], [52, 39], [22, 59], [63, 59], [95, 51], [74, 97], [8, 80], [97, 22], [95, 23], [44, 91], [97, 94], [91, 52], [42, 48], [58, 23], [80, 51], [62, 22], [51, 13], [66, 39], [1, 13], [97, 6], [48, 13], [51, 80]]
Cross over finished.
Replacement finished.
Mutation(30) finished.
Mutation(53) finished.
Mutation(59) finished.
Mutation(62) finished.
Mutation(74) finished.
Mutation(75) finished.
Mutation(80) finished.
Mutation(87) finished.
Mutation(92) finished.
Mutation(99) finished.
Iter: 6
Best_fit: {'pop': 18, 'train_loss': 2.2928, 'train_acc': 17.9688}, avg_fitness: 9.8828
Elites: [73, 92, 25, 51, 70, 81, 99, 4, 68, 89, 14, 19, 32, 41, 28, 43, 62, 0, 16, 18]
mating_pool
[38 21 74 93 81 35 1 38 89 92 63 27 61 17 68 29 16 68 0 73 0 33 52 2
36 37 89 5 12 32 65 22 89 93 37 33 88 36 4 81]
Pairs: [[1, 38], [37, 89], [61, 36], [88, 35], [81, 88], [38, 22], [21, 38], [68, 27], [27, 93], [17, 68], [36, 63], [68, 65], [65, 0], [35, 38], [74, 2], [65, 33], [35, 65], [89, 12], [81, 0], [93, 12], [65, 74], [36, 4], [32, 36], [89, 35], [63, 12], [12, 33], [29, 61], [61, 2], [65, 36], [65, 81], [68, 73], [89, 0], [73, 12], [38, 22], [21, 89], [27, 1], [38, 33], [89, 21], [73, 73], [81, 89], [2, 16], [89, 88], [52, 33], [93, 5], [89, 36], [89, 63], [33, 68], [81, 0], [16, 37], [65, 21], [2, 89], [68, 32], [27, 63], [29, 0], [93, 29], [63, 36], [61, 38], [81, 16], [68, 93], [33, 74], [0, 38], [81, 89], [52, 81], [89, 33], [37, 74], [12, 2], [89, 27], [73, 65], [81, 33], [63, 0], [33, 93], [2, 38], [4, 81], [93, 68], [81, 4], [5, 33], [29, 27], [0, 35], [4, 36], [2, 37]]
Cross over finished.
Replacement finished.
Mutation(34) finished.
Mutation(36) finished.
Mutation(48) finished.
Mutation(59) finished.
Mutation(62) finished.
Mutation(63) finished.
Mutation(70) finished.
Mutation(75) finished.
Mutation(82) finished.