加载已训练好的PyTorch模型

from __future__ import division
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from utilscopy import *
from collections import OrderedDict
import torch.optim as optim
from utilscopy import *
import random
import pandas as pd
from basic_structure import IGNNK
import argparse
import sys
import os
import time


def parse_args(args):
    '''Parse training options user can specify in command line.
    Specify hyper parameters here

    Returns
    -------
    argparse.Namespace
        the output parser object
    '''
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Parse argument used when training IGNNK model.",
        epilog="python IGNNK_train.py DATASET, for example: python IGNNK_train.py 'metr' ")

    # Requird input parametrs
    parser.add_argument(
        'dataset', type=str, default='sanxia',
        help='Name of the datasets, select from metr, nrel, ushcn, sedata or pems'
    )

    # optional input parameters
    parser.add_argument(
        '--n_o', type=int, default=7,
        help='sampled space dimension'
    )
    parser.add_argument(
        # '--h',type=int,default=24,
        '--h', type=int, default=24,
        help='sampled time dimension'
    )
    parser.add_argument(
        '--z', type=int, default=100,
        help='hidden dimension for graph convolution'
    )
    parser.add_argument(
        '--K', type=int, default=2,
        help='If using diffusion convolution, the actual diffusion convolution step is K+1'
    )
    parser.add_argument(
        # '--n_m',type=int,default=50,
        '--n_m', type=int, default=0,  # do not know y parameter here works,but it works
        help='number of mask node during training'
    )
    parser.add_argument(
        # '--n_u',type=int,default=50,
        '--n_u', type=int, default=0,
        help='target locations, n_u locations will be deleted from the training data'
    )
    parser.add_argument(
        '--max_iter', type=int, default=80,
        help='max training episode'
    )
    parser.add_argument(
        '--learning_rate', type=float, default=0.00001,
        help='the learning_rate for Adam optimizer'
    )
    parser.add_argument(
        '--E_maxvalue', type=int, default=371,
        help='the max value from experience'
    )
    parser.add_argument(
        '--batch_size', type=int, default=4,
        help='Batch size'
    )
    parser.add_argument(
        '--to_plot', type=bool, default=True,
        help='Whether to plot the RMSE training result'
    )
    return parser.parse_known_args(args)[0]
def plot_o(o_list, filename='预测结果.pdf'):
    """Plot the predicted (model output) series and save it to a file.

    Parameters
    ----------
    o_list : array-like
        Values to plot. If 2-D, matplotlib draws each column as its own line.
        In the ``__main__`` script this is the rescaled model output for the
        last 90 days.
    filename : str, optional
        Output path. Defaults to ``'预测结果.pdf'`` ("prediction results").
    """
    fig, ax = plt.subplots()

    ax.plot(o_list, label='t', linewidth=3.5)
    ax.set_xlabel('the last 90 days', fontsize=20)
    ax.set_ylabel(' horizontal displacement.', fontsize=20)
    ax.tick_params(axis="x", labelsize=14)
    ax.tick_params(axis="y", labelsize=14)
    ax.legend(fontsize=16)
    ax.grid(True)
    # Operate on this figure explicitly rather than pyplot's "current figure".
    fig.tight_layout()
    fig.savefig(filename)
    # pyplot keeps every figure alive until closed; release it so repeated
    # calls do not accumulate memory.
    plt.close(fig)
def plot_X(o_list, filename='观测结果.pdf'):
    """Plot the observed series and save it to a file.

    Parameters
    ----------
    o_list : array-like
        Values to plot. If 2-D, matplotlib draws each column as its own line.
        In the ``__main__`` script this is the raw observed data for the
        last 90 days.
    filename : str, optional
        Output path. Defaults to ``'观测结果.pdf'`` ("observation results").
    """
    fig, ax = plt.subplots()

    ax.plot(o_list, label='t', linewidth=3.5)
    ax.set_xlabel('the last 90 days', fontsize=20)
    ax.set_ylabel(' horizontal displacement.', fontsize=20)
    ax.tick_params(axis="x", labelsize=14)
    ax.tick_params(axis="y", labelsize=14)
    ax.legend(fontsize=16)
    ax.grid(True)
    # Operate on this figure explicitly rather than pyplot's "current figure".
    fig.tight_layout()
    fig.savefig(filename)
    # pyplot keeps every figure alive until closed; release it so repeated
    # calls do not accumulate memory.
    plt.close(fig)


if __name__ == "__main__":
    # Inference script:
    #   1. restore a trained IGNNK model from a saved state_dict;
    #   2. reconstruct the last `horizon` time steps window by window;
    #   3. save plots of the prediction and of the observed data.
    checkpoint_path = 'IGNNK_sanxia_400iter_2021-10-16_10_54_10_x_40%_RM.pth'
    # IGNNK constructor args. Presumably (h, z, K) as in the training options;
    # they must match the values used when the checkpoint was trained.
    time_window = 30       # time steps per model input
    hidden_dim = 100
    diffusion_steps = 2
    horizon = 90           # number of most recent time steps to reconstruct
    maxvalue = 371         # normalisation constant (same value as --E_maxvalue default)

    # map_location='cpu' lets a checkpoint saved from a GPU run load on a
    # CPU-only machine. NOTE: torch.load unpickles the file; only load trusted
    # checkpoints.
    state_dict = torch.load(checkpoint_path, map_location='cpu')

    # Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
    # Strip that prefix so the keys match a plain IGNNK instance.
    # When no key carries the prefix, the dict is left untouched.
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict):
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    my_model = IGNNK(time_window, hidden_dim, diffusion_steps)
    # Strict loading (the default) raises if any key is missing or unexpected.
    my_model.load_state_dict(state_dict)
    # Inference mode: disables dropout / batch-norm updates, if the model has any.
    my_model.eval()

    A, X = load_landslide_data()
    X = X.transpose()
    X = X[-horizon:, :]

    # The diffusion matrices depend only on A, so build them once, outside the loop.
    A_q = torch.from_numpy((calculate_random_walk_matrix(A).T).astype('float32'))
    A_h = torch.from_numpy((calculate_random_walk_matrix(A.T).T).astype('float32'))

    # Output has the same shape as the observed slice; the node count is taken
    # from the data rather than hard-coded.
    o = np.zeros(X.shape)
    with torch.no_grad():  # no gradients are needed for inference
        for start in range(0, X.shape[0] - time_window + 1, time_window):
            stop = start + time_window
            T0 = X[start:stop, :] / maxvalue
            T0 = torch.from_numpy(np.expand_dims(T0, axis=0).astype('float32'))
            imputation_ver = my_model(T0, A_q, A_h)
            o[start:stop, :] = imputation_ver[0].cpu().numpy()
    o = o * maxvalue
    plot_o(o)
    plot_X(X)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值