Python *args参数的作用

用作函数占位符,可以增加扩展性

比如在深度学习中,网络和训练函数部分代码如下:

class SubNet(nn.Module):
	...
	def forward(self, X, *args):
		pred = do_something(X, *args)
		return pred

class Net(nn.Module):
	...
	def forward(self, X, *args):
		do_something()
		return self.subnet(X, *args)

def train_fn(net, loss, optimizer, train_iter, n_epochs):
	...
	for batch in train_iter:
		X, y = batch
		pred = net(X)
		l = loss(pred, y)
		optimizer.zero_grad()
		l.backward()
		optimizer.step()
	...

在这种情况下,模型有几个部分组成,使用如下的代码即可完成训练

subnet = SubNet()
net = Net(subnet)
train_fn(net, ...)

当子网络改进之后,模型要求的参数变多之后

class SubNet2(nn.Module):
	...
	def forward(self, X, W, *args):
		pred = do_something(X, W, *args)
		return pred
	
def train_fn(net, loss, optimizer, train_iter, n_epochs):
	...
	W = some_code()
	for batch in train_iter:
		X, y = batch
		pred = net(X, W)
		l = loss(pred, y)
		optimizer.zero_grad()
		l.backward()
		optimizer.step()
	...

训练的代码只需要将子网络改变成改进后的网络注入到模型中即可像之前一样进行训练(train_fn需要做一些适应性调整)

subnet = SubNet()
net = Net(subnet)
train_fn(net)

在这个时候,train_fn中net给的参数相比第一次的参数增加了,但是不需要对net的定义进行修改,新增加的参数会自动以*args的形式传入net,在调用self.subnet(X, *args)的时候自动进行传递,实现一定程度的解耦

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值