flyai下载预训练的keras模型

  1. 进入FlyAI预训练模型地址
  2. 找到需要的keras模型,相应链接后确定
    在这里插入图片描述3.得到复制后的内容
# 必须使用该方法下载模型,然后加载
from flyai.utils import remote_helper
path = remote_helper.get_remote_date("https://www.flyai.com/m/v0.8|NASNet-mobile.h5")

直接在本地运行会出错。需要修改以下两个文件:

  • remote_helper.py。把os.path.join(sys.path[0], 'data', 'input', 'model')改成自己的地址(保存模型的地址)
    在这里插入图片描述
  • download.py中的地址改成自己的地址
    在这里插入图片描述
  1. 下载模型放到本地的.keras文件夹下,即可在任意位置直接调用。

附:remote_helper.py代码

import sys

import hashlib
import json
import os
import platform
import uuid
from os.path import join

from flyai.processor.download import download_model

__DOMAIN = "https://www.flyai.com"


def __genearteMD5(str):
    hl = hashlib.md5()
    hl.update(str.encode(encoding='utf-8'))
    return hl.hexdigest()


def __check_dir(str):
    if " " in str:
        return False
    return all(ord(c) < 128 for c in str)


def __get_home_path():
    sys = platform.system()
    if sys == "Windows":
        if not __check_dir(os.environ['HOMEPATH']):
            path = join("C://", '.flyai', "")
        else:
            path = join(os.environ['HOMEPATH'], '.flyai', "")
    else:
        path = join(os.environ['HOME'], '.flyai', "")
    if not os.path.exists(path):
        os.makedirs(path)
    return path


def __get_mac():
    try:
        address = hex(uuid.getnode())[2:]
        return '-'.join(address[i:i + 2] for i in range(0, len(address), 2))
    except:
        return "unknown"


def __get_token():
    GOOS = platform.system()
    if GOOS == "Windows":
        file_path = os.path.join(os.environ['HOMEPATH'], '.flyai_flyai')
    else:
        file_path = os.path.join(os.environ['HOME'], '.flyai_flyai')
    if os.path.exists(file_path):
        file = open(file_path, 'r')
        token = file.read()
        return token
    else:
        file_path = join(__get_home_path(), "." + __genearteMD5(__get_mac() + __DOMAIN))
        if os.path.exists(file_path):
            file = open(file_path)
            login_data = json.loads(file.read())
            return login_data['token']
        else:
            file_path = os.path.join(sys.path[0], 'train.json')
            if os.path.exists(os.path.join(sys.path[0], 'train.json')):
                file = open(file_path)
                login_data = json.loads(file.read())
                return login_data['token']


def get_remote_date(remote_name):
    if "http" in remote_name:
        token = __get_token()
        if token is not None:
            return download_model(remote_name + "?token=" + __get_token(),
                                  os.path.join(sys.path[0], 'data', 'input', 'model'), is_print=True)
        else:
            return None

评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值