修改pytorch和Keras预训练模型路径

1、Pytorch预训练模型路径修改

Pytorch安装目录下有一个hub.py,改文件指定了预训练模型的加载位置。该文件存在于xxx\site-packages\torch,例如我的存在于“C:\ProgramData\Miniconda3\Lib\site-packages\torch”。
打开hub.py文件,找到load_state_dict_from_url函数,其中第二个参数
model_dir用于指定权重文件路径:model_dir (string, optional): directory in which to save the object。将该参数值由None改为权重文件位置即可,例如model_dir=‘D:/Models_Download/torch’。

def load_state_dict_from_url(url, model_dir='D:/Models_Download/torch', map_location=None, progress=True, check_hash=False, file_name=None):
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set.

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """

2、Keras修改预训练模型位置

Keras安装路径内并没有一个文件来定义预训练模型位置,我只能在调用预训练模型的时候指定模型文件的路径(有没有更好的设置方法?)。

base_model = vgg19.VGG19(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False, 
                         weights='D:\\Models_Download\\keras\\vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值