模型的压缩

彩票假说:先随机初始化权重并完成训练,再对训练好的网络进行剪枝;若训练效果好,则把剪枝后保留的连接重置回最初的随机初始化值重新训练;若效果不好,则更换一组随机初始化并重复上述过程。
网络剪枝:剪枝后性能会下降(期望不要下降太多),应分多次逐步剪枝,一次不要剪太多。
网络剪枝的两种方式:权重剪枝:难以实现、难以加速,剪枝后网络结构变得不规整。
神经元剪枝:容易实现、容易加速,网络结构保持规整。

知识蒸馏

教师网络:大规模网络模型,集成网络,极深网络模型
学生网络:简单网络模型

代码

模型

class Student(nn.Module):
    """Small CNN used as the student network for CIFAR-10 distillation.

    Four 3x3 conv layers (spatial size halved once by max-pooling),
    global average pooling, then a single linear classifier.

    Fix vs. original: the original mixed tabs (constructor header) with
    spaces (body), which is a SyntaxError in Python 3; indentation is
    normalized to 4 spaces throughout.

    Args:
        img_c: number of input channels (3 for RGB images).
        ndf: base number of feature maps.
        num_class: number of output classes.
    """

    def __init__(self, img_c=3, ndf=32, num_class=10):
        super(Student, self).__init__()
        self.conv = nn.Sequential(
            # 3x3 / stride-1 / pad-1 convs keep the spatial size unchanged.
            nn.Conv2d(in_channels=img_c, out_channels=ndf, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ndf),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=ndf, out_channels=ndf, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ndf),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # halves height and width

            nn.Conv2d(ndf, 2 * ndf, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(2 * ndf),
            nn.ReLU(inplace=True),
            nn.Conv2d(2 * ndf, 2 * ndf, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(2 * ndf),
            nn.ReLU(inplace=True),
        )
        self.fc = nn.Linear(2 * ndf, num_class)
        # Global average pooling maps any spatial size down to 1x1, so the
        # classifier works for arbitrary input resolutions.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return class logits of shape (batch, num_class)."""
        out = self.conv(x)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)  # flatten (B, 2*ndf, 1, 1) -> (B, 2*ndf)
        out = self.fc(out)
        return out

模型训练

class Distill_Trainer(object):
    """Runs plain supervised training of a model over CIFAR-10-style data.

    NOTE(review): ``test`` and ``train_distill`` appear later in this file at
    module level but are written as methods (they take ``self``) — they are
    presumably meant to live on this class; confirm the intended indentation.
    """

    def __init__(self, lr=0.005, batch_size=64, num_epoch=1,
                 train_data=None, test_data=None):
        # Hyper-parameters.
        self.lr = lr
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        # Datasets and their loaders (both shuffled, as in the original).
        self.train_data = train_data
        self.test_data = test_data
        self.data_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
        self.test_loader = Data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

    def train(self, model, save_path=None, save_name=None):
        """Train *model* with plain cross-entropy and Adam.

        Fixes vs. original: ``save_path``/``save_name`` were required but
        never used (and the ``__main__`` block calls ``train`` without them,
        which raised TypeError) — they now default to ``None`` and, when both
        are given, the best-accuracy weights are checkpointed; ``best_acc``
        was an unused empty list and now tracks the best epoch accuracy.

        Args:
            model: the network to optimise (must be callable on batches).
            save_path: optional directory for the best checkpoint.
            save_name: optional checkpoint file name.

        Returns:
            list[float]: test accuracy after each epoch (via ``self.test``).
        """
        optim = torch.optim.Adam(model.parameters(), lr=self.lr)
        best_acc = 0.0
        acc_list = []
        for epoch in range(self.num_epoch):
            model.train()
            curr_loss = 0.0
            for i, (bx, by) in enumerate(self.data_loader):
                pre = model(bx)
                loss = F.cross_entropy(input=pre, target=by)
                optim.zero_grad()
                loss.backward()
                optim.step()
                curr_loss += loss.item()
            curr_acc = self.test(model)
            acc_list.append(curr_acc)
            # Persist the best-so-far weights when a destination was supplied.
            if curr_acc > best_acc:
                best_acc = curr_acc
                if save_path is not None and save_name is not None:
                    import os
                    torch.save(model.state_dict(), os.path.join(save_path, save_name))
        return acc_list

计算准确率

def test(self,model):
	acc=0
	model.eval()
	for i,(bx,by) in enumerate(self.test_loader):
		with torch.no_grad():
			pre=model(bx)
			_,pred=torch.max(pre,1)
			acc+=torch.sum(pred==by.data)
	acc=acc.double()/self.test_loader.dataset.__len__()
	return acc.item()

蒸馏训练

def train_distill(self, student_model, teacher_model, save_path=None, save_name=None):
    """Train *student_model* via knowledge distillation from *teacher_model*.

    Loss per batch = KL(student softened || teacher softened) + CE(student, labels),
    with temperature T = 5 (as in the original code).  Written as a method
    (takes ``self``, reads ``self.lr``/``self.num_epoch``/``self.data_loader``
    and calls ``self.test``); presumably belongs on ``Distill_Trainer``.

    Fixes vs. original:
      * ``;`` after the signature -> ``:`` (SyntaxError).
      * the student forward pass sat outside the batch loop and referenced
        ``bx`` before it existed (NameError) — moved inside the loop so the
        student is re-evaluated on every batch.
      * ``F.log_soft`` is not a function -> ``F.log_softmax``.
      * ``save_path``/``save_name`` were required but unused — now optional.

    Returns:
        list[float]: student test accuracy after each epoch.
    """
    T = 5.0  # distillation temperature
    optim = torch.optim.Adam(student_model.parameters(), lr=self.lr)
    acc_list = []
    kl_loss = torch.nn.KLDivLoss()
    for epoch in range(self.num_epoch):
        student_model.train()
        for i, (bx, by) in enumerate(self.data_loader):
            pre = student_model(bx)  # must be recomputed per batch
            with torch.no_grad():  # teacher is fixed; no gradients needed
                out_teacher = teacher_model(bx)
            # Soft-target loss between temperature-softened distributions.
            loss_dis = kl_loss(input=F.log_softmax(pre / T, dim=1),
                               target=F.softmax(out_teacher / T, dim=1))
            # Plus hard-target cross-entropy on the true labels.
            loss = loss_dis + F.cross_entropy(input=pre, target=by)
            optim.zero_grad()
            loss.backward()
            optim.step()
        curr_acc = self.test(student_model)
        acc_list.append(curr_acc)
    return acc_list

执行

if __name__ == '__main__':
    lr = 0.0005
    batch_size = 64
    num_epoch = 1

    # Training-time augmentation for CIFAR-10.
    # Fix: the original Normalize call was malformed
    # (``transforms.Normalize(0.5,0.5,0.5),(0.5,0.5,0.5),]``) —
    # mean and std must each be a 3-tuple.
    transforms_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_data = torchvision.datasets.CIFAR10(
        root='./data/cifar_data',
        train=True,
        transform=transforms_train,
        download=True)

    # Test-time preprocessing (fix: ``Totensor`` -> ``ToTensor``).
    transforms_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    test_data = torchvision.datasets.CIFAR10(
        root='./data/cifar_data',
        train=False,
        transform=transforms_test,
        download=True)

    # Teacher: pretrained ResNet-18 with its classifier resized to 10 classes.
    teach_model = models.resnet18(pretrained=True)
    num_fc = teach_model.fc.in_features
    teach_model.fc = torch.nn.Linear(num_fc, 10)

    student_model = Student(num_class=10)
    norm_model = Student(num_class=10)

    # Fixes: the hyper-parameters defined above were silently unused — pass
    # them to the trainer; repaired typos ``mode=`` -> ``model=``,
    # ``teacher_Model`` -> ``teach_model``, ``studen_model`` ->
    # ``student_model``, ``acc_List_teacher`` -> ``acc_list_teacher``; supply
    # the save arguments that ``train``'s signature requires.
    trainer = Distill_Trainer(lr=lr, batch_size=batch_size, num_epoch=num_epoch,
                              train_data=train_data, test_data=test_data)
    acc_list_normal = trainer.train(model=norm_model,
                                    save_path='.', save_name='normal.pth')
    acc_list_teacher = trainer.train(model=teach_model,
                                     save_path='.', save_name='teacher.pth')
    acc_list_student = trainer.train_distill(student_model=student_model,
                                             teacher_model=teach_model,
                                             save_path='.', save_name='student.pth')
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值