- 打开 FlyAI 预训练模型页面
- 找到需要的 Keras 模型，点击相应链接后确定（复制）
- 得到复制后的内容：
# The model must be downloaded through this helper first, and only then loaded.
# `path` receives whatever get_remote_date returns (None if no FlyAI token is found).
from flyai.utils import remote_helper
path = remote_helper.get_remote_date("https://www.flyai.com/m/v0.8|NASNet-mobile.h5")
直接在本地运行会出错，需要修改以下两个文件：
- `remote_helper.py`：把 `os.path.join(sys.path[0], 'data', 'input', 'model')` 改成自己的地址（保存模型的目录）
- `download.py`：把其中的地址也改成自己的地址
- 把下载的模型放到本地的 `.keras` 文件夹下，即可在任意位置直接调用。
附：`remote_helper.py` 代码
import sys
import hashlib
import json
import os
import platform
import uuid
from os.path import join
from flyai.processor.download import download_model
# Base URL of the FlyAI service. __get_token also mixes it into the MD5 digest
# that names the cached login file under the FlyAI home directory.
__DOMAIN = "https://www.flyai.com"
def __genearteMD5(str):
    """Return the hexadecimal MD5 digest of *str* encoded as UTF-8.

    The parameter name shadows the ``str`` builtin; it is kept unchanged so
    existing keyword callers keep working.
    """
    digest = hashlib.md5(str.encode(encoding='utf-8'))
    return digest.hexdigest()
def __check_dir(str):
    """Tell whether *str* is acceptable as a directory path for FlyAI files.

    A path is accepted only if it contains no space character and every
    character is ASCII (code point below 128). An empty string is accepted.
    """
    has_space = " " in str
    is_ascii = not any(ord(ch) >= 128 for ch in str)
    return is_ascii and not has_space
def __get_home_path():
    """Return the FlyAI home directory (with a trailing separator), creating it if needed.

    On Windows the directory is ``%HOMEPATH%/.flyai/``. If ``HOMEPATH``
    contains spaces or non-ASCII characters (rejected by ``__check_dir``),
    ``C://.flyai/`` is used instead. On other systems it is ``$HOME/.flyai/``.

    Returns:
        The directory path, ending with the platform separator.

    Raises:
        KeyError: if the ``HOMEPATH`` (Windows) or ``HOME`` variable is unset.
    """
    # Named ``system_name`` rather than ``sys`` so the file-level ``sys``
    # module is not shadowed inside this function.
    system_name = platform.system()
    if system_name == "Windows":
        home = os.environ['HOMEPATH']
        if __check_dir(home):
            path = join(home, '.flyai', "")
        else:
            # HOMEPATH is unusable (spaces / non-ASCII): fall back to the C: root.
            path = join("C://", '.flyai', "")
    else:
        path = join(os.environ['HOME'], '.flyai', "")
    # exist_ok=True avoids the check-then-create race of exists() + makedirs(),
    # where a concurrent creator would make makedirs raise FileExistsError.
    os.makedirs(path, exist_ok=True)
    return path
def __get_mac():
    """Return this host's ``uuid.getnode()`` id as dash-separated hex pairs.

    Returns ``"unknown"`` if obtaining or formatting the id raises.

    NOTE(review): ``hex()`` drops leading zeros, so an id such as
    0x0a1b2c3d4e5f renders as ``a1-b2-c3-d4-e5-f`` (misaligned pairs).
    This is deliberately not changed: __get_token hashes this string to
    locate a cached login file, so changing the format would change that
    filename. Confirm against whatever writes that file before "fixing" it.
    """
    try:
        address = hex(uuid.getnode())[2:]
        return '-'.join(address[i:i + 2] for i in range(0, len(address), 2))
    except Exception:
        # Best effort: fall back to a placeholder. Unlike the former bare
        # ``except:``, KeyboardInterrupt/SystemExit now propagate.
        return "unknown"
def __read_json_token(path):
    """Return the ``token`` field of the JSON file at *path* (closes the file)."""
    with open(path) as f:
        return json.load(f)['token']


def __get_token():
    """Return the FlyAI API token from the first credential source that exists.

    Sources are tried in this order:

    1. ``.flyai_flyai`` in ``%HOMEPATH%`` (Windows) or ``$HOME``: its raw
       text is returned verbatim.
    2. ``.<md5(mac + __DOMAIN)>`` inside ``__get_home_path()`` (which creates
       that directory if missing): a JSON file whose ``token`` is returned.
    3. ``train.json`` in ``sys.path[0]``: a JSON file whose ``token`` is returned.

    Returns:
        The token, or ``None`` if none of the files exists.

    Raises:
        KeyError: if the home environment variable is unset or a JSON file
            has no ``token`` key.
        ValueError: if a JSON file is malformed (``json.JSONDecodeError``).
    """
    if platform.system() == "Windows":
        raw_path = os.path.join(os.environ['HOMEPATH'], '.flyai_flyai')
    else:
        raw_path = os.path.join(os.environ['HOME'], '.flyai_flyai')
    if os.path.exists(raw_path):
        # NOTE(review): returned as-is, including any trailing newline; callers
        # append it to a URL, so confirm the file's writer never adds one.
        with open(raw_path, 'r') as f:
            return f.read()

    login_path = join(__get_home_path(), "." + __genearteMD5(__get_mac() + __DOMAIN))
    if os.path.exists(login_path):
        return __read_json_token(login_path)

    train_path = os.path.join(sys.path[0], 'train.json')
    if os.path.exists(train_path):
        return __read_json_token(train_path)
    return None
def get_remote_date(remote_name, model_dir=None):
    """Download a FlyAI remote model and return the result of ``download_model``.

    Args:
        remote_name: Model reference. It is downloaded only if it contains
            ``"http"``, e.g. ``"https://www.flyai.com/m/v0.8|NASNet-mobile.h5"``.
            The token is appended as ``?token=...``.
        model_dir: Directory passed to ``download_model`` as the save location.
            Defaults to ``<sys.path[0]>/data/input/model``, the original
            hard-coded location. Pass your own directory to run outside the
            FlyAI layout instead of editing this file.

    Returns:
        Whatever ``download_model`` returns, or ``None`` when *remote_name*
        does not contain ``"http"`` or no token is available.
    """
    if "http" not in remote_name:
        return None
    token = __get_token()
    if token is None:
        return None
    if model_dir is None:
        # Resolved at call time so it follows the current sys.path[0].
        model_dir = os.path.join(sys.path[0], 'data', 'input', 'model')
    # Reuse the token already read instead of re-reading the credential files.
    return download_model(remote_name + "?token=" + token, model_dir, is_print=True)