darts论文链接:https://arxiv.org/pdf/1806.09055.pdf
darts源码链接:https://github.com/quark0/darts
search部分
'''
train_search.py
#数据准备(cifar10)。
搜索时,从cifar10的训练集中按照1:1重新划分训练集和验证集
'''
# Data preparation (CIFAR-10) for architecture search.
# The CIFAR-10 *training* set is re-split 1:1: the first half trains the
# network weights w, the second half is the "validation" queue used to
# update the architecture parameters (alphas).
train_transform, valid_transform = utils._data_transforms_cifar10(args)
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
# In the paper args.train_portion is 0.5 (a 50/50 split).
num_train = len(train_data)
indices = list(range(num_train))
# Index at which the index list is cut into the two halves.
split = int(np.floor(args.train_portion * num_train))
# Queue over indices[:split] — trains the network weights w.
# SubsetRandomSampler shuffles within the subset on every epoch.
train_queue = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
    pin_memory=True, num_workers=2)
# Queue over indices[split:] — drives the alpha (architecture) updates.
valid_queue = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
    pin_memory=True, num_workers=2)
'''
train_search.py
搜索网络
损失函数:交叉熵
优化器:带动量的SGD
学习率调整策略:余弦退火调整学习率 CosineAnnealingLR
'''
# Loss: cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss()
# Network-weight optimizer: SGD with momentum and weight decay.
# Only model.parameters() is optimized here — the alphas are created as
# Variables (not nn.Parameters) in the Network, so they are excluded and
# updated separately by the Architect.
optimizer = torch.optim.SGD(
    model.parameters(),
    args.learning_rate,
    momentum=args.momentum,
    weight_decay=args.weight_decay)
# Cosine-annealing schedule: LR decays from args.learning_rate down to
# args.learning_rate_min over args.epochs epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs), eta_min=args.learning_rate_min)
'''
train_search.py
构建搜索网络
构建Architect优化
'''
# Build the over-parameterized search network (defined in model_search.py).
model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
# Architect (architect.py) handles the optimization of the architecture
# parameters (alphas) — presumably the bilevel update from the paper;
# see architect.py for details.
architect = Architect(model, args)
'''
model_search.py
论文中
# C :16
# num_classes :10(CIFAR-10 共10类)
# criterion
# layers:8
'''
class Network(nn.Module):
    """DARTS search network: a stack of `layers` cells whose edges each hold
    a softmax-weighted mixture of the candidate ops.

    Paper settings for CIFAR-10 search: C=16, layers=8, steps=4.
    The architecture parameters (alphas) are plain requires_grad Variables,
    deliberately NOT nn.Parameters, so model.parameters() returns only the
    network weights and the alphas are optimized separately (Architect).
    """

    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3):
        """
        C: initial channel count (16 in the paper).
        num_classes: number of output classes.
        layers: number of stacked cells (8 in the paper).
        criterion: loss used by _loss().
        steps: intermediate nodes per cell (4).
        multiplier: number of node outputs concatenated into the cell output.
        stem_multiplier: channel multiplier for the stem convolution.
        """
        super(Network, self).__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._criterion = criterion
        self._steps = steps
        self._multiplier = multiplier
        C_curr = stem_multiplier*C
        # Stem: plain 3x3 conv + BN lifting the RGB input to C_curr channels.
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        # Build the cell stack: cells at 1/3 and 2/3 of the depth are
        # reduction cells (channels doubled), the rest are normal cells.
        for i in range(layers):
            if i in [layers//3, 2*layers//3]:
                # For layers=8 these are cells 2 and 5 (0-indexed).
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            # A cell outputs multiplier*C_curr channels (node outputs concatenated).
            C_prev_prev, C_prev = C_prev, multiplier*C_curr
        # Classification head on top of the final cell.
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        # Create the architecture parameters (alphas).
        self._initialize_alphas()

    def new(self):
        """Return a fresh Network of the same shape, copying only the
        architecture parameters (used by the Architect's unrolled step)."""
        model_new = Network(self._C, self._num_classes, self._layers, self._criterion).cuda()
        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
            x.data.copy_(y.data)
        return model_new

    def forward(self, input):
        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            # Normal and reduction cells each share their own alpha tensor.
            if cell.reduction:
                # Softmax over the op dimension: each of the 14 edges gets a
                # probability distribution over the 8 candidate ops.
                weights = F.softmax(self.alphas_reduce, dim=-1)
            else:
                weights = F.softmax(self.alphas_normal, dim=-1)
            # s0 is the output of the cell before last, s1 of the last cell.
            s0, s1 = s1, cell(s0, s1, weights)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0),-1))
        return logits

    def _loss(self, input, target):
        """Loss of this network's predictions on (input, target)."""
        logits = self(input)
        return self._criterion(logits, target)

    def _initialize_alphas(self):
        """Initialize alphas with small Gaussian noise (1e-3 * N(0, 1))."""
        # k = edges per cell: node i has (2+i) incoming edges,
        # so 2+3+4+5 = 14 for steps=4.
        k = sum(1 for i in range(self._steps) for n in range(2+i))
        num_ops = len(PRIMITIVES)
        # Shape (k, num_ops) = (14, 8): one row per edge, one column per op.
        # Created as Variables, not nn.Parameters, so they stay out of
        # model.parameters() and are optimized separately.
        self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
        self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
        self._arch_parameters = [
            self.alphas_normal,
            self.alphas_reduce,
        ]

    def arch_parameters(self):
        """Architecture parameters: [alphas_normal, alphas_reduce]."""
        return self._arch_parameters

    def genotype(self):
        """Discretize the current alphas into a Genotype: every intermediate
        node keeps its 2 strongest incoming edges, each with its single best
        non-'none' op."""
        def _parse(weights):
            gene = []
            n = 2
            start = 0
            for i in range(self._steps):
                # Rows [start:end) are the (2+i) incoming edges of node i.
                end = start + n
                W = weights[start:end].copy()
                # Choose the 2 predecessor nodes whose edge has the highest
                # maximum op weight ('none' excluded). Only the edge endpoints
                # are fixed here; the op on each edge is chosen below.
                edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
                # For each of the two kept edges, keep its single best op.
                # NOTE(review): this greedy scheme fixes every cell at exactly
                # 2 ops per node (2*4 = 8 per cell) and may discard a
                # second-best op on one edge that outscores the best op on
                # another; later work (e.g. FairDARTS) revisits this
                # discretization strategy.
                for j in edges:
                    k_best = None
                    for k in range(len(W[j])):
                        if k != PRIMITIVES.index('none'):
                            if k_best is None or W[j][k] > W[j][k_best]:
                                k_best = k
                    # Record (op name, predecessor node index).
                    gene.append((PRIMITIVES[k_best], j))
                start = end
                n += 1
            return gene
        # Softmax-normalize first so rows are comparable probability vectors.
        gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
        gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())
        # Nodes whose outputs are concatenated: range(2, 6) with the defaults.
        concat = range(2+self._steps-self._multiplier, self._steps+2)
        genotype = Genotype(
            normal=gene_normal, normal_concat=concat,
            reduce=gene_reduce, reduce_concat=concat
        )
        return genotype
'''
model_search.py
cell的实现,参数共享,分为normal cell 和 reduction cell
经过reduction cell 特征图减半
'''
# 对于 每一条连接
class MixedOp(nn.Module):
def __init__(self, C, stride):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
# 8种op操作
for primitive in PRIMITIVES:
# 计算每一种操作
op = OPS[primitive](C