保存网络中的参数到txt

这段代码定义了一个VGG网络结构(VGG_SNIP),包括VGG-C和VGG-D变体,并使用SNIP论文中的配置。网络结构中使用了卷积、池化、ReLU和批量归一化层。之后,代码加载了预训练的VGG16权重到模型中。最后,遍历模型的所有参数,将它们保存到单独的文件中,不同形状的参数以不同的方式写入文件。
摘要由CSDN通过智能技术生成
import os
import scipy.io as io

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

# from pytorchtools import EarlyStopping

from torchvision.datasets import MNIST, CIFAR10, CIFAR100
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms


VGG_CONFIGS = {
    # M for MaxPool, Number for channels
    'D': [
        64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
        512, 512, 512, 'M'
    ],
    'E': [
        64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M'
    ],
}


class VGG_SNIP(nn.Module):
    """
    This is a base class to generate three VGG variants used in SNIP paper:
        1. VGG-C (16 layers)
        2. VGG-D (16 layers)
        3. VGG-like

    Some of the differences:
        * Reduced size of FC layers to 512
        * Adjusted flattening to match CIFAR-10 shapes
        * Replaced dropout layers with BatchNorm
    """

    def __init__(self, config, num_classes=100):
        super().__init__()

        self.features = self.make_layers(VGG_CONFIGS[config], batch_norm=True)

        self.classifier = nn.Sequential(
            nn.Linear(512, 512),  # 512 * 7 * 7 in the original VGG
            nn.ReLU(True),
            nn.BatchNorm1d(512),  # instead of dropout
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),  # instead of dropout
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def make_layers(config, batch_norm=False):  # TODO: BN yes or no?
        layers = []
        in_channels = 3
        for v in config:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [
                        conv2d,
                        nn.BatchNorm2d(v),
                        nn.ReLU(inplace=True)
                    ]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        x = F.log_softmax(x, dim=1)
        return x

net = VGG_SNIP('D').cuda()
net = torch.nn.DataParallel(net).cuda()
checkpoint = torch.load(r'E:\code\practicing\snip-master\VGG16_2.pth.tar')
net.load_state_dict(checkpoint['state_dict'])



for name, parameter in net.named_parameters():
    with open('C:\\Users\\Pong\\Desktop\\for_ly\\{}.txt'.format(name), 'w') as thisfile:
        if (len(parameter.shape))==1:
            for channel in parameter:
                channel = str(channel.cpu().detach().numpy())
                thisfile.write(channel)
                thisfile.write("\n")
        if (len(parameter.shape))==2:
            for channel in parameter:
                channel = channel.cpu().detach().numpy()
                np.savetxt(thisfile, channel, fmt='%f', delimiter = ',')
                thisfile.write("\n")
        elif (len(parameter.shape))==4:
            for channel in parameter:
                for filter in channel:
                    filter = filter.cpu().detach().numpy()
                    filter = np.array(filter)
                    np.savetxt(thisfile, filter, fmt='%f', delimiter = ',')
                    thisfile.write("\n")
    # c = np.loadtxt('/home/Velocitymodel/speedfile/b.txt', delimiter = ',').reshape((2, 2, 2))  恢复


print(":)")
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值