[NAS] DARTS Code Walkthrough

This post walks through the DARTS paper and source code, focusing on the approximate architecture gradient. Working through the equations, it explains the difference between (2) and (3) when computing the gradient with respect to α, i.e. treating the loss as a composite function versus taking only the partial derivative, and presents the approximation used to make the computation tractable.

DARTS paper: https://arxiv.org/pdf/1806.09055.pdf
DARTS source code: https://github.com/quark0/darts
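As background for the approximate architecture gradient mentioned above: the paper avoids solving the inner optimization for $w^*(\alpha)$ exactly and instead takes a single virtual gradient step on the weights,

$$
\nabla_\alpha \mathcal{L}_{val}\bigl(w^*(\alpha), \alpha\bigr)
\;\approx\;
\nabla_\alpha \mathcal{L}_{val}\bigl(w - \xi \nabla_w \mathcal{L}_{train}(w, \alpha), \alpha\bigr).
$$

Applying the chain rule to the right-hand side, with $w' = w - \xi \nabla_w \mathcal{L}_{train}(w, \alpha)$, gives

$$
\nabla_\alpha \mathcal{L}_{val}(w', \alpha)
- \xi\, \nabla^2_{\alpha, w} \mathcal{L}_{train}(w, \alpha)\, \nabla_{w'} \mathcal{L}_{val}(w', \alpha),
$$

and the expensive matrix-vector product in the second term is approximated by finite differences,

$$
\nabla^2_{\alpha, w} \mathcal{L}_{train}(w, \alpha)\, \nabla_{w'} \mathcal{L}_{val}(w', \alpha)
\;\approx\;
\frac{\nabla_\alpha \mathcal{L}_{train}(w^+, \alpha) - \nabla_\alpha \mathcal{L}_{train}(w^-, \alpha)}{2\epsilon},
\qquad
w^{\pm} = w \pm \epsilon\, \nabla_{w'} \mathcal{L}_{val}(w', \alpha).
$$

Setting $\xi = 0$ gives the first-order approximation, which in the code corresponds to running train_search.py without --unrolled.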


Search phase
'''
train_search.py
# Data preparation (CIFAR-10).
During search, the CIFAR-10 training set is re-split 1:1 into a training set and a validation set
'''
train_transform, valid_transform = utils._data_transforms_cifar10(args)
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)

# in the paper, args.train_portion is 0.5
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(args.train_portion * num_train))

train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
      pin_memory=True, num_workers=2)
valid_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
      pin_memory=True, num_workers=2)
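
With train_portion = 0.5, the 50,000 CIFAR-10 training images are split 1:1: the first 25,000 indices feed train_queue (used to update the network weights w) and the remaining 25,000 feed valid_queue (used to update the architecture parameters α); SubsetRandomSampler only shuffles within each half. A quick check of the arithmetic:

import numpy as np

# CIFAR-10 ships 50,000 training images, so with train_portion = 0.5:
num_train = 50000
indices = list(range(num_train))
split = int(np.floor(0.5 * num_train))            # 25000
train_idx, valid_idx = indices[:split], indices[split:num_train]
assert len(train_idx) == len(valid_idx) == 25000  # disjoint halves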
      
'''
train_search.py
Search network
Loss: cross-entropy
Optimizer: SGD with momentum
LR schedule: cosine annealing (CosineAnnealingLR)
'''
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
      model.parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs), eta_min=args.learning_rate_min)
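
For reference, with T_max = args.epochs and eta_min = args.learning_rate_min, CosineAnnealingLR (stepped once per epoch here) anneals the weight learning rate as

$$
\eta_t = \eta_{\min} + \tfrac{1}{2}\,(\eta_{\max} - \eta_{\min})\Bigl(1 + \cos\tfrac{t\pi}{T_{\max}}\Bigr),
$$

where $\eta_{\max}$ = args.learning_rate and $t$ is the current epoch.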

'''
train_search.py
Build the search network
Build the Architect, which optimizes the architecture parameters
'''
# in model_search.py
model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
# in architect.py
architect = Architect(model, args)
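
Putting the pieces together, the search loop in train_search.py alternates the two updates: each step first updates α on a batch from valid_queue via architect.step (this is where the approximate architecture gradient is used), then updates w on a batch from train_queue with the SGD optimizer above. The sketch below is a simplified version of that loop, not verbatim repo code (device transfers, gradient clipping, logging and the scheduler are omitted; lr is the current weight learning rate taken from the scheduler):

# simplified sketch of the inner search loop in train_search.py
for step, (input, target) in enumerate(train_queue):
    # a batch from the *validation* half drives the alpha update
    input_search, target_search = next(iter(valid_queue))

    # update the architecture parameters alpha
    # (args.unrolled=True selects the second-order approximation)
    architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)

    # then update the weights w on the training half as usual
    optimizer.zero_grad()
    logits = model(input)
    loss = criterion(logits, target)
    loss.backward()
    optimizer.step()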

'''
model_search.py
In the paper:
# C : 16 (initial channels)
# num_classes : 10 (CIFAR_CLASSES)
# criterion : cross-entropy
# layers : 8
'''
class Network(nn.Module):
	def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3):
		super(Network, self).__init__()
		self._C = C
		self._num_classes = num_classes
		self._layers = layers
		self._criterion = criterion
		self._steps = steps
		self._multiplier = multiplier

		C_curr = stem_multiplier*C
		# stem: initial conv + BN
		self.stem = nn.Sequential(
			nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
			nn.BatchNorm2d(C_curr)
		)

		C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
		self.cells = nn.ModuleList()
		reduction_prev = False
		# build `layers` cells (8 here),
		# split into normal cells and reduction cells (which double the channels)
		for i in range(layers):
			if i in [layers//3, 2*layers//3]:
				# with 8 cells, the cells at indices layers//3 = 2 and 2*layers//3 = 5 are
				# reduction cells; passing through a reduction cell doubles the channels
				C_curr *= 2
				reduction = True
			else:
				reduction = False
			cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
			reduction_prev = reduction
			self.cells += [cell]
			C_prev_prev, C_prev = C_prev, multiplier*C_curr

		# after stacking the cells, attach the classification head
		self.global_pooling = nn.AdaptiveAvgPool2d(1)
		self.classifier = nn.Linear(C_prev, num_classes)

		# initialize the alphas
		self._initialize_alphas()

	# create a fresh Network and copy the alpha parameters into it
	def new(self):
		model_new = Network(self._C, self._num_classes, self._layers, self._criterion).cuda()
		for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
			x.data.copy_(y.data)
		return model_new

	def forward(self, input):
		s0 = s1 = self.stem(input)
		for i, cell in enumerate(self.cells):
			# reduction cells and normal cells use different shared alpha parameters
			if cell.reduction:
				# softmax normalization of the 14x8 alphas: for each of the 14 edges,
				# softmax over its 8 candidate ops
				weights = F.softmax(self.alphas_reduce, dim=-1)
			else:
				weights = F.softmax(self.alphas_normal, dim=-1)
			# cell-to-cell wiring: s0 is the output of the cell before last,
			# s1 is the output of the previous cell
			s0, s1 = s1, cell(s0, s1, weights)
		out = self.global_pooling(s1)
		logits = self.classifier(out.view(out.size(0), -1))
		return logits

	def _loss(self, input, target):
		logits = self(input)
		return self._criterion(logits, target) 
	
	# initialize the alphas
	def _initialize_alphas(self):
		# 14 edges in total: the 4 intermediate nodes have 2+3+4+5 incoming edges
		k = sum(1 for i in range(self._steps) for n in range(2+i))
		num_ops = len(PRIMITIVES)
		# two (14, 8) alpha matrices:
		# one for the normal cell, one for the reduction cell
		self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
		self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
		self._arch_parameters = [
			self.alphas_normal,
			self.alphas_reduce,
			]

	def arch_parameters(self):
		return self._arch_parameters
	
	def genotype(self):
		def _parse(weights):
			gene = []
			n = 2
			start = 0
			for i in range(self._steps):
				# for each intermediate node
				end = start + n
				# all edge weights feeding this node (2, 3, 4, 5 rows for nodes 0-3)
				W = weights[start:end].copy()
				# for each node, keep the 2 incoming edges whose strongest (non-'none') op
				# weight is largest, i.e. decide which two predecessor nodes it connects to
				# note: this only picks the connected nodes; the op on each edge is chosen below
				edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
				# for each of the two retained edges, pick its best op
				# this selection rule feels rather crude: judging purely by the alpha weights, the
				# second-best op on edge 1 might well be better than the best op on edge 2. it also
				# rules out keeping several ops on the same edge, which I think could be reasonable too.
				# later papers improve on this selection strategy, e.g. Fair DARTS (covered in a later post)
				for j in edges:
					k_best = None
					for k in range(len(W[j])):
						if k != PRIMITIVES.index('none'):
							if k_best is None or W[j][k] > W[j][k_best]:
								k_best = k
					# record the best op together with the edge it sits on (which predecessor node it connects to)
					# every node keeps two edges with one op each, so a cell always has 2*4 = 8 operations;
					# this is hard-coded and not very flexible!
					gene.append((PRIMITIVES[k_best], j))
				start = end
				n += 1
			return gene

		# softmax-normalize, then apply the rule above to pick the best op on each edge
		gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
		gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())
		# range(2, 6): concatenate the outputs of the 4 intermediate nodes
		concat = range(2+self._steps-self._multiplier, self._steps+2)
		genotype = Genotype(
			normal=gene_normal, normal_concat=concat,
			reduce=gene_reduce, reduce_concat=concat
		)
		return genotype
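
To make the selection rule concrete, here is a small standalone illustration (not from the repo) that applies the same top-2-edge / best-op logic to a random (14, 8) alpha matrix, assuming the standard 8-op PRIMITIVES list from genotypes.py:

import torch
import torch.nn.functional as F

# the 8 candidate ops used in the official repo (genotypes.py)
PRIMITIVES = ['none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect',
              'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5']

steps = 4
alphas = 1e-3 * torch.randn(14, len(PRIMITIVES))   # same shape as alphas_normal
weights = F.softmax(alphas, dim=-1).numpy()        # row-wise softmax over the ops

gene, start, n = [], 0, 2
none_idx = PRIMITIVES.index('none')
for i in range(steps):
    W = weights[start:start + n]                   # rows for the edges feeding node i
    # keep the two edges with the largest non-'none' weight
    edges = sorted(range(i + 2),
                   key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != none_idx))[:2]
    for j in edges:
        candidates = [k for k in range(len(W[j])) if k != none_idx]
        k_best = max(candidates, key=lambda k: W[j][k])   # best op on this edge
        gene.append((PRIMITIVES[k_best], j))
    start, n = start + n, n + 1

print(gene)   # 8 (op, predecessor) pairs, 2 per intermediate node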
		
'''
model_search.py
Cell implementation. The alpha parameters are shared across cells, split into normal cells and reduction cells.
Passing through a reduction cell halves the spatial size of the feature map.
'''
# one MixedOp per edge in the cell
class MixedOp(nn.Module):
	def __init__(self, C, stride):
		super(MixedOp, self).__init__()
		self._ops = nn.ModuleList()
		# the 8 candidate ops in PRIMITIVES
		for primitive in PRIMITIVES:
			# instantiate every candidate op on this edge
			op = OPS[primitive](C, stride, False)
			if 'pool' in primitive:
				# in the original repo, pooling ops are followed by a BatchNorm
				op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
			self._ops.append(op)

	def forward(self, x, weights):
		# continuous relaxation: the edge output is the alpha-weighted sum of all candidate ops
		return sum(w * op(x) for w, op in zip(weights, self._ops))
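
As a quick sanity check of the continuous relaxation, a MixedOp applied to a feature map with softmaxed weights returns a tensor of the same shape. This is a hypothetical usage snippet, assuming OPS (operations.py) and PRIMITIVES (genotypes.py) are importable:

import torch
import torch.nn.functional as F

mixed = MixedOp(C=16, stride=1)
x = torch.randn(2, 16, 32, 32)                  # (batch, channels, height, width)
alpha_edge = 1e-3 * torch.randn(len(PRIMITIVES))
w = F.softmax(alpha_edge, dim=-1)               # one weight per candidate op
out = mixed(x, w)
print(out.shape)                                # torch.Size([2, 16, 32, 32])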