修改pytorch和Keras预训练模型路径

本文介绍了如何修改Pytorch和Keras预训练模型的默认加载路径。在Pytorch中,可以通过编辑hub.py文件,改变load_state_dict_from_url函数中的model_dir参数来设定权重文件的位置。而在Keras中,需在实例化模型时直接指定权重文件路径。对于Keras,目前没有全局配置方法,但可以在调用模型时指定权重文件。
摘要由CSDN通过智能技术生成

1、Pytorch预训练模型路径修改

Pytorch安装目录下有一个hub.py,改文件指定了预训练模型的加载位置。该文件存在于xxx\site-packages\torch,例如我的存在于“C:\ProgramData\Miniconda3\Lib\site-packages\torch”。
打开hub.py文件,找到load_state_dict_from_url函数,其中第二个参数
model_dir用于指定权重文件路径:model_dir (string, optional): directory in which to save the object。将该参数值由None改为权重文件位置即可,例如model_dir=‘D:/Models_Download/torch’。

def load_state_dict_from_url(url, model_dir='D:/Models_Download/torch', map_location=None, progress=True, check_hash=False, file_name=None):
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set.

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """

2、Keras修改预训练模型位置

Keras安装路径内并没有一个文件来定义预训练模型位置,我只能在调用预训练模型的时候指定模型文件的路径(有没有更好的设置方法?)。

base_model = vgg19.VGG19(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False, 
                         weights='D:\\Models_Download\\keras\\vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值