PyTorch项目应用实例(一)加载(本地|官方)预训练模型

37 篇文章 7 订阅
27 篇文章 11 订阅

背景:我们需要把模型上传集群运行,所以预训练的模型需要放在文件夹之内进行加载,把环境及配置拷入env之后,不能用文件夹之外的库。预训练的resnet101需要直接放入目录下加载。

目录

一、预训练模型的加载

1.1 模型加载

1.2 加载流程

1.3 模型位置

1.4 缺点

1.5 找到预训练模型位置

二、加载指定位置模型

2.1 例子程序

2.2 把网络模型放入目录下

2.3 我们的程序

三、验证(可不看)

四、集群预训练模型的解决

4.1 相应报错

4.2 加载模型位置

4.3 服务器拷贝及运行


一、预训练模型的加载

1.1 模型加载

直接通过pytorch的models加载模型。

class HGAT_FC(nn.Module):
    def __init__(self, backbone, groups, nclasses, nclasses_per_group, group_channels, class_channels):
        super(HGAT_FC, self).__init__()
        self.groups = groups
        self.nclasses = nclasses
        self.nclasses_per_group = nclasses_per_group
        self.group_channels = group_channels
        self.class_channels = class_channels
        if backbone == 'resnet101':
            model = models.resnet101(pretrained=True)
        elif backbone == 'resnet50':
            model = models.resnet50(pretrained=False)
        else:
            raise Exception()

其中需要导入的库为 torchvision.models

import torch
import torchvision.models as models
from torch import nn
import mymodels.utils as utils
import torch
from torch import nn
import torch.nn.functional as F

1.2 加载流程

import torch
import torchvision.models as models
。。。
        if backbone == 'resnet101':
            model = models.resnet101(pretrained=True)
        elif backbone == 'resnet50':
            model = models.resnet50(pretrained=False)
        else:
            raise Exception()

1.3 模型位置

cd ~是返回home目录。这个表明torch再home目录下安装着。

[xingxiangrui@xx.com ~]$ cd ~/.torch/models
[xingxiangrui@xx.com models]$ pwd
/home/xingxiangrui/.torch/models
[xingxiangrui@xx.com models]$ ls
resnet101-5d3b4d8f.pth

1.4 缺点

如果没有下载过,torchvision会自动联网下载模型。

但是没有网络的情况下或者没有权限的情况下,模型不会下载,因此不能运行,会报错。

requests.exceptions.ConnectionError: ('Connection aborted.', TimeoutError(10060, '由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败。', None, 10060, None))

因此需要用下面的方法,直接从目录之中加载模型。

1.5 找到预训练模型位置

每个环境下,模型位置不一定,如果模型已经下载,需要找到模型存储的位置

如果预训练,则相应语句为:

def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model

对load_url函数进行ctrl+b

找到相应的位置:即如果模型本地有,则从本地加载,如果没有,则从url下载。

def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file.

    The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
    filesytem layout, with a default value ``~/.cache`` if not set.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename).group(1)
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return torch.load(cached_file, map_location=map_location)

设置断点,用调试器找到模型位置:

二、加载指定位置模型

这样就不用担心联网的问题,并且可以指定好相应的模型。

https://blog.csdn.net/u014264373/article/details/85332181

直接从pth文件之中进行加载。

例如

import torch
import torchvision.models as models
 
# pretrained=True就可以使用预训练的模型
net = models.squeezenet1_1(pretrained=False)
pthfile = r'E:\anaconda\app\envs\luo\Lib\site-packages\torchvision\models\squeezenet1_1.pth'
net.load_state_dict(torch.load(pthfile))
print(net)

2.1 例子程序

程序定义直接从目录下面读取文件。

直接从目录下加载

文件放在运行的目录下(语法很可能不对,只是参考):

def gcn_resnet101(num_classes, t, pretrained=True, adj_file=None, in_channel=300):
    # fixme
    model = models.resnet101(pretrained=False)
    if pretrained:
        print('load pretrained model...')
        model.load_state_dict(torch.load('./resnet101-5d3b4d8f.pth'))
    return GCNResnet(model, num_classes, t=t, adj_file=adj_file, in_channel=in_channel)

2.2 把网络模型放入目录下

cp ~/.torch/models/resnet101-5d3b4d8f.pth chun-ML_GCN/

注意,要与程序运行的位置和 load_state_dict的路径一致

2.3 我们的程序

        if backbone == 'resnet101':
            model = models.resnet101(pretrained=False)
            print('load pretrained model...')
            model.load_state_dict(torch.load('./resnet101-5d3b4d8f.pth'))
        elif backbone == 'resnet50':
            model = models.resnet50(pretrained=False)
            print('load pretrained model...')
            model.load_state_dict(torch.load('./resnet50-5d3b4d8f.pth'))

即直接加载运行目录下的resnet101-5d3b4d8f.pth 这个模型。

三、验证(可不看)

这部分是我们对自己程序的验证,其他可以不看。因为每个人模型不一样。

直接按上面的方法进行更改。

general_train.py之中,改为exp_3,hgat_fc.py之中按照上面进行修改。

直接在目录下,env/bin/python general_train.py如果不报错,即可。

四、集群预训练模型的解决

集群预训练模型的解决

4.1 相应报错

看出报错在于集群依然想要加载预训练模型。

Downloading: "http://xxxxxxxr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth" to /home/xxx/.torch/models/se_resnet152-d17c99b7.pth
Traceback (most recent call last):
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 159, in _new_conn
    (self._dns_host, self.port), self.timeout, **extra_kw)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/connection.py", line 80, in create_connection
    raise err
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/connection.py", line 70, in create_connection
    sock.connect(sa)
OSError: [Errno 101] Network is unreachable

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 600, in urlopen
    chunked=chunked)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 354, in _make_request
    conn.request(method, url, **httplib_request_kw)
  File "/home/sxxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 1107, in request
    self._send_request(method, url, body, headers)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/ccccccccc/client.py", line 1152, in _send_request
    self.endheaders(body)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 1103, in endheaders
    self._send_output(message_body)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 934, in _send_output
    self.send(msg)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 877, in send
    self.connect()
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 181, in connect
    conn = self._new_conn()
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 168, in _new_conn
    self, "Failed to establish a new connection: %s" % e)
urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPConnection object at 0x7f03fa52d748>: Failed to establish a new connection: [Errno 101] Network is unreachable

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/xx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/adapters.py", line 449, in send
    timeout=timeout
  File "/home/xxxxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 638, in urlopen
    _stacktrace=sys.exc_info()[2])
  File "/home/xxxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/retry.py", line 398, in increment
    raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPConnectionPool(host='data.lip6.fr', port=80): Max retries exceeded with url: /cadene/pretrainedmodels/se_resnet152-d17c99b7.pth (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f03fa52d748>: Failed to establish a new connection: [Errno 101] Network is unreachable',))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train_se_clsgat.py", line 128, in <module>
    main()
  File "train_se_clsgat.py", line 107, in main
    model = util.get_model(args)
  File "/home/xxx/job/tmp/job-25509/util.py", line 266, in get_model
    class_channels=args.CLASS_CHANNELS)
  File "/home/xxxx/job/tmp/job-25509/models/se_clsgat.py", line 379, in __init__
    model=senet_origin.se_resnet152()
  File "/home/xxx/job/tmp/job-25509/models/senet_origin.py", line 423, in se_resnet152
    initialize_pretrained_model(model, num_classes, settings)
  File "/home/xxx/job/tmp/job-25509/models/senet_origin.py", line 377, in initialize_pretrained_model
    model.load_state_dict(model_zoo.load_url(settings['url']))
  File "/home/slurm/job/tmp/job-25509/torch/lib/python3.5/site-packages/torch/utils/model_zoo.py", line 65, in load_url
    _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/torch/utils/model_zoo.py", line 71, in _download_url_to_file
    u = urlopen(url, stream=True)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/api.py", line 75, in get
    return request('get', url, params=params, **kwargs)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/api.py", line 60, in request
    return session.request(method=method, url=url, **kwargs)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/sessions.py", line 533, in request
    resp = self.send(prep, **send_kwargs)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/sessions.py", line 646, in send
    r = adapter.send(request, **kwargs)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/adapters.py", line 516, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='data.lip6.fr', port=80): Max retries exceeded with url: /cadene/pretrainedmodels/se_resnet152-d17c99b7.pth (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f03fa52d748>: Failed to establish a new connection: [Errno 101] Network is unreachable',))

需要将预训练模型放在目录之下免得集群重复加载。

程序没有运行到加载模型一步。

==== GLOBAL INFO ====
IPLIST: xx.xx.xx.xx
IP0: xx.xx.xx.xx
====================
==== NODE INFO ====
NODE_RNAK: 0
IP0: xx.xx.xx.xx
NODE_IP: xx.xx.xx
===================
{'ADJ_FILE': 'data/data/coco/coco_adj.pkl',
 'ALPHA': 0.8,
 'BACKBONE': 'resnet150',
 'BATCH_SIZE': 16,
 'CLASS_CHANNELS': 256,
 'CPROB': array([[1.00000000e+00, 8.26410144e-01, 7.04392284e-01, ...,
        4.03311258e-01, 4.45312500e-01, 5.40000000e-01],
       [4.18382255e-02, 1.00000000e+00, 1.02719033e-01, ...,
        1.12582781e-02, 0.00000000e+00, 5.71428571e-03],
       [1.34192234e-01, 3.86532575e-01, 1.00000000e+00, ...,
        3.84105960e-02, 7.81250000e-03, 8.57142857e-03],
       ...,
       [1.34812060e-02, 7.43331876e-03, 6.73948408e-03, ...,
        1.00000000e+00, 2.34375000e-02, 8.57142857e-03],
       [1.26178775e-03, 0.00000000e+00, 1.16198001e-04, ...,
        1.98675497e-03, 1.00000000e+00, 2.57142857e-02],
       [8.36764511e-03, 1.74901618e-03, 6.97188008e-04, ...,
        3.97350993e-03, 1.40625000e-01, 1.00000000e+00]]),
 'DATA': 'data/data/coco',
 'DATA_TYPE': 'coco',
 'DEEPMAR_LOSS': <loss.DeepMarWeights object at 0x7f04044800f0>,
 'DEVICE_IDS': [0, 1, 2, 3, 4, 5, 6, 7],
 'EPOCH': 100,
 'EPOCH_STEP': 30,
 'EVALUATE': False,
 'EXP_NAME': 'se_clsgat',
 'GROUPS': 12,
 'GROUP_CHANNELS': 512,
 'IMAGE_SIZE': 448,
 'INP_NAME': 'data/data/coco/coco_glove_word2vec.pkl',
 'IS_SLURM': False,
 'LOSS_TYPE': 'DeepMarLoss',
 'LR': 0.01,
 'LRP': 0.01,
 'LR_SCHEDULER': None,
 'LR_SCHEDULER_PARAMS': None,
 'MAX_EPOCH': 100,
 'MODEL': 'se_clsgat',
 'MOMENTUM': 0.9,
 'NCLASSES': 80,
 'NCLASSES_PER_GROUP': [1, 8, 5, 10, 5, 10, 7, 10, 6, 6, 5, 7],
 'PRINT_FREQ': 10,
 'RESUME': 'checkpoints/coco/se_clsgat/checkpoint.pth.tar',
 'SAVE_MODEL_PATH': 'checkpoints/coco/se_clsgat',
 'START_EPOCH': 0,
 'WEIGHT_DECAY': 1e-05,
 'WEIGHT_FILE': 'data/coco/coco_rate.pkl',
 'WORKERS': 4}
Compose(
    Resize(size=(512, 512), interpolation=PIL.Image.BILINEAR)
    MultiScaleCrop
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
Compose(
    Warp (size=448, interpolation=2)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
[dataset] Done!
[annotation] Done!
[json] Done!
[dataset] Done!
[annotation] Done!
[json] Done!
-------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code.. Per user-direction, the job has been aborted.
-------------------------------------------------------
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[58771,1],0]
  Exit code:    1
--------------------------------------------------------------------------

4.2 加载模型位置

通过上面1.5中的方法设置断点找到模型位置:拷贝过去

本地可以采用这种方法:

(torch041) $ cd /Users/baidu/.cache/torch/checkpoints/
(torch041) baidudeMacBook-Pro:checkpoints baidu$ ls
resnet101-5d3b4d8f.pth		se_resnet152-d17c99b7.pth
(torch041) baidudeMacBook-Pro:checkpoints baidu$ cp se_resnet152-d17c99b7.pth /Users/baidu/Desktop/code/ML_GAT-master/

运行没有报错。

4.3 服务器拷贝及运行

服务器已经知道相应的torch的缓存的地址:

cd ~/.torch/models/
ls
resnet101-5d3b4d8f.pth  resnet50-19c8e357.pth  se_resnet152-d17c99b7.pth

直接更换更改好的

senet_origin

def initialize_pretrained_model(model, num_classes, settings):
    assert num_classes == settings['num_classes'], \
        'num_classes should be {}, but is {}'.format(
            settings['num_classes'], num_classes)
    # model.load_state_dict(model_zoo.load_url(settings['url']))
    print('loading pretrained model from local...')
    model.load_state_dict(torch.load('./se_resnet152-d17c99b7.pth'))
    model.input_space = settings['input_space']
    model.input_size = settings['input_size']
    model.input_range = settings['input_range']
    model.mean = settings['mean']
    model.std = settings['std']
  • 20
    点赞
  • 67
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

祥瑞Coding

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值