[Python] 本科毕业设计:基于机器学习的重庆轨道交通客流量时空分析与预测——源码与实现细节

介绍

本科毕业设计:基于机器学习的重庆轨道交通客流量时空分析与预测。

客流量统计计算
  • 对于站点:将轨道交通网络抽象为一个图(每个站点是一个节点),利用弗洛伊德(Floyd)算法求解多源最短路径,得到乘客从站点 A 进站、在站点 D 出站时途经的所有站点,并为途经的每个站点的人流量加 1,由此得到每个站点的日均人流量。
  • 对于线路(如 1 号线、2 号线的客流量):同样使用弗洛伊德算法,求出乘客从站点 A 进站、在站点 D 出站时途经的所有线路,并为每条涉及的线路的人流量加 1。
利用 BP 神经网络进行预测
  • 对每条线路的客流量进行预测,使用的特征:

    • day:工作日、周六、周日
    • han_shu_jia:三种取值——非寒暑假、寒假、暑假
    • week_of_mouth:该周是当月的第几周
    • season_of_year:一年中的哪个季节
  • 只对日均人流量最大的 10 个站点进行人流量预测,使用的特征:

    • type:所属线路
    • day:工作日、周六、周日
    • han_shu_jia:三种取值——寒假、暑假、非假期
    • week_of_mouth:该周是当月的第几周
    • season_of_year:一年中的哪个季节

BPNN 项目完整代码——第一部分(线网建模与客流量统计):

import xlrd
import sys
import queue
import pymysql
import pandas as pd

from datetime import timedelta

# Upper bounds of the nine exit-time buckets used by findTimeType():
# bucket k ends at (5 + 2k):30, i.e. 7:30, 9:30, ..., 23:30.
TimeType = {slot: timedelta(hours=5 + 2 * slot, minutes=30) for slot in range(1, 10)}

# Size of every per-station matrix below; must be >= the number of stations
# read by CreatGraph (it indexes stationMatrix with the station count).
stationCount = 183
# Floyd "via" matrix: pathMatrix[i][j] is an intermediate logical index on the
# shortest i -> j path, j itself for a direct hop, and -1 when unreachable.
pathMatrix = [[-1] * stationCount for _ in range(stationCount)]
# Logical (matrix) index -> physical station ID.
stationMatrix = [0] * stationCount
# Physical station ID -> logical index.  NOTE: shadows the builtin ``map``.
map = {}
StationID_to_Station = {}  # station ID -> Station
LineID_to_Line = {}  # line ID -> Line
# Distance matrix: 1 for a direct hop, 0 on the diagonal once built,
# sys.maxsize for "not (yet) reachable".
graphMatrix = [[sys.maxsize] * stationCount for _ in range(stationCount)]


class Station:
    """A station on one line, its neighbours on that line and its flow counters.

    ``-1`` for the previous/next station means "no neighbour on this line";
    ``transfer`` is 0 for an ordinary station, otherwise a station ID to
    transfer to.  ``_passengerFlow`` maps day -> timeType -> count and
    ``_timeFlow`` maps each 5-minute slot from 05:00 to 23:50 (a ``timedelta``)
    to a count.
    """

    def __init__(self, lineID, stationId, name, transfer):
        self._stationID = stationId
        self._lineID = lineID
        self._name = name
        self._transfer = transfer
        self._passengerFlow = {}
        self._next = -1
        self._pre = -1
        # One zeroed counter per 5-minute slot, 05:00 .. 23:50 inclusive.
        slot_minutes = range(5 * 60, 23 * 60 + 50 + 1, 5)
        self._timeFlow = dict.fromkeys((timedelta(minutes=m) for m in slot_minutes), 0)

    # --- neighbours -------------------------------------------------------
    def setPre(self, pre):
        self._pre = pre

    def setNext(self, next):
        self._next = next

    def getPre(self):
        return self._pre

    def getNext(self):
        return self._next

    # --- counters ---------------------------------------------------------
    def resetPassengerFlow(self):
        """Drop all per-day counters (rebinds a fresh dict)."""
        self._passengerFlow = {}

    def resetTimeFlow(self):
        """Zero every 5-minute counter in place, keeping the same dict object."""
        self._timeFlow.update(dict.fromkeys(self._timeFlow, 0))

    def addTimeFlow(self, time):
        """Count one passenger in slot ``time``; raises KeyError for unknown slots."""
        self._timeFlow[time] += 1

    def addPassengerFlow(self, day, timeType):
        """Count one passenger for ``day`` and time bucket ``timeType``."""
        per_day = self._passengerFlow.setdefault(day, {})
        per_day[timeType] = per_day.get(timeType, 0) + 1

    # --- plain accessors --------------------------------------------------
    def getName(self):
        return self._name

    def getTransfer(self):
        return self._transfer

    def getLineID(self):
        return self._lineID

    def getPassengerFlow(self):
        return self._passengerFlow

    def getTimeFlow(self):
        return self._timeFlow

    def show(self):
        print(self._name, self._stationID, self._lineID, self._pre, self._next, self._transfer)


class Line:
    """A metro line with aggregated flow counters.

    ``_passengerFlow`` maps day -> timeType -> count and ``_timeFlow`` maps
    each 5-minute slot from 05:00 to 23:50 (a ``timedelta``) to a count.
    """

    def __init__(self, lineID):
        self._lineID = lineID
        self._passengerFlow = {}
        # One zeroed counter per 5-minute slot, 05:00 .. 23:50 inclusive.
        slot_minutes = range(5 * 60, 23 * 60 + 50 + 1, 5)
        self._timeFlow = dict.fromkeys((timedelta(minutes=m) for m in slot_minutes), 0)

    def resetPassengerFlow(self):
        """Drop all per-day counters (rebinds a fresh dict)."""
        self._passengerFlow = {}

    def addPassengerFlow(self, day, timeType):
        """Count one passenger for ``day`` and time bucket ``timeType``."""
        counts = self._passengerFlow.setdefault(day, {})
        counts[timeType] = counts.get(timeType, 0) + 1

    def resetTimeFlow(self):
        """Zero every 5-minute counter in place, keeping the same dict object."""
        self._timeFlow.update(dict.fromkeys(self._timeFlow, 0))

    def addTimeFlow(self, time):
        """Count one passenger in slot ``time``; raises KeyError for unknown slots."""
        self._timeFlow[time] += 1

    def getPassengerFlow(self):
        return self._passengerFlow

    def getLineID(self):
        return self._lineID

    def getTimeFlow(self):
        return self._timeFlow


def Floyd():
    """Run Floyd-Warshall in place on the global ``graphMatrix``/``pathMatrix``.

    After the call ``graphMatrix[i][j]`` is the shortest hop count from i to j
    and ``pathMatrix[i][j]`` holds the intermediate node through which the
    last improvement was found (MinPath splits the route at that node).

    :returns: ``(graphMatrix, pathMatrix)`` — the same list objects, mutated.
    """
    size = len(graphMatrix)
    for mid in range(size):
        mid_row = graphMatrix[mid]
        for src in range(size):
            src_row = graphMatrix[src]
            # src_row[mid] cannot change inside the dst loop (that would need
            # a negative diagonal), so it is safe to read it once.
            to_mid = src_row[mid]
            for dst in range(size):
                candidate = to_mid + mid_row[dst]
                if candidate < src_row[dst]:
                    src_row[dst] = candidate
                    pathMatrix[src][dst] = mid
    return graphMatrix, pathMatrix


def MinPath(inID, toID, pathQueue, lineIDSet):
    """Append the stations after ``inID`` up to and including ``toID`` on the shortest path.

    Relies on the globals filled by CreatGraph()/Floyd(): ``map`` (station ID
    -> logical index), ``stationMatrix`` (logical index -> station ID) and
    ``pathMatrix`` (intermediate-node matrix, ``j`` for a direct hop, ``-1``
    when unreachable). The route is split recursively at the stored
    intermediate node until only direct hops remain. Each appended station's
    line ID is added to ``lineIDSet``.

    :returns: ``pathQueue`` (mutated in place).
    :raises ValueError: if ``toID`` cannot be reached from ``inID``.
    """
    fromIdx = map[inID]
    toIdx = map[toID]
    via = pathMatrix[fromIdx][toIdx]
    if via == -1:
        # -1 would otherwise be used as a list index (stationMatrix[-1]) and the
        # recursion would never terminate; fail with a clear message instead.
        raise ValueError("no path from station %r to station %r" % (inID, toID))
    if via == toIdx:  # direct hop (also the inID == toID case)
        pathQueue.put(toID)
        lineIDSet.add(StationID_to_Station[toID].getLineID())
    else:
        midID = stationMatrix[via]
        MinPath(inID, midID, pathQueue, lineIDSet)
        MinPath(midID, toID, pathQueue, lineIDSet)
    return pathQueue


def CreatGraph(filepath):
    """Build the metro network model from the station sheet at ``filepath``.

    Reads the first worksheet, skipping its first (header) row. Columns used:
    [1] line ID, [2] station ID, [3] station name, [6] transfer (0 for an
    ordinary station, otherwise a station ID to transfer to). Consecutive rows
    with the same line ID are chained as previous/next stations on that line.

    Fills the module globals StationID_to_Station, LineID_to_Line,
    stationMatrix, map, graphMatrix and pathMatrix, then runs Floyd().

    :returns: ``(graphMatrix, pathMatrix, StationID_to_Station, map, stationMatrix)``
    """
    global graphMatrix, pathMatrix
    sheet = xlrd.open_workbook(filepath).sheet_by_index(0)

    # Pass 1: create the stations and chain consecutive stations of a line.
    line_ids = set()
    prev_station_id, prev_line_id = -1, 0
    for row_no, row in enumerate(sheet):
        if row_no == 0:  # header row
            continue
        line_id = row[1].value
        line_ids.add(line_id)
        station_id = row[2].value
        StationID_to_Station[station_id] = Station(line_id, station_id, row[3].value, row[6].value)
        if prev_line_id == line_id:
            StationID_to_Station[station_id].setPre(prev_station_id)
            StationID_to_Station[prev_station_id].setNext(station_id)
        prev_station_id, prev_line_id = station_id, line_id

    for line_id in line_ids:
        LineID_to_Line[line_id] = Line(line_id)

    # Pass 2: logical (matrix) index <-> physical station ID, in insertion order.
    for logical_idx, station_id in enumerate(StationID_to_Station):
        stationMatrix[logical_idx] = station_id
        map[station_id] = logical_idx

    def link(src_idx, neighbour_id):
        # One-hop edge: distance 1, and the path entry is the target itself.
        dst_idx = map[neighbour_id]
        graphMatrix[src_idx][dst_idx] = 1
        pathMatrix[src_idx][dst_idx] = dst_idx

    # Pass 3: adjacency matrix (1 = one hop, sys.maxsize = not adjacent) and
    # the initial path matrix.
    for station_id, station in StationID_to_Station.items():
        idx = map[station_id]
        graphMatrix[idx][idx] = 0
        pathMatrix[idx][idx] = idx
        for neighbour_id in (station.getPre(), station.getNext()):
            if neighbour_id != -1:
                link(idx, neighbour_id)
        if station.getTransfer() != 0:
            # A transfer station is also adjacent to the neighbours of the
            # station it transfers to.
            twin = StationID_to_Station[station.getTransfer()]
            for neighbour_id in (twin.getPre(), twin.getNext()):
                if neighbour_id != -1:
                    link(idx, neighbour_id)

    graphMatrix, pathMatrix = Floyd()
    return graphMatrix, pathMatrix, StationID_to_Station, map, stationMatrix


def FindPth(inID, toID):
    """Return the stations and lines a trip from ``inID`` to ``toID`` passes through.

    :returns: ``(pathQueue, lineIDSet)`` — a ``queue.Queue`` of station IDs in
        travel order starting with ``inID``, and the set of line IDs of those
        stations.
    """
    pathQueue = queue.Queue()
    pathQueue.put(inID)
    lineIDSet = {StationID_to_Station[inID].getLineID()}
    if inID != toID:
        # MinPath appends toID even for a zero-length trip, which counted a
        # station twice when a passenger entered and exited at the same place.
        MinPath(inID, toID, pathQueue, lineIDSet)
    return pathQueue, lineIDSet


def findTimeType(time):
    """Map an exit time of day to its ``TimeType`` bucket key.

    Bucket ``k`` covers ``(TimeType[k-1], TimeType[k]]``; the first bucket
    also takes everything up to its own bound. Times after the last bound now
    fall into the last bucket (the original loop fell through to
    ``return 1`` and counted them in the first, early-morning bucket).

    :param time: time of day comparable with the ``timedelta`` bounds
        (presumably a MySQL TIME value returned as ``timedelta`` — confirm).
    :returns: the bucket key (1..9 for the module's TimeType).
    """
    from bisect import bisect_left

    keys = sorted(TimeType)
    bounds = [TimeType[key] for key in keys]
    return keys[min(bisect_left(bounds, time), len(keys) - 1)]


def getPassengerFlow():
    """Aggregate station/line passenger flow from table ``month6`` and store it.

    Swipe records are read in id ranges of ``gap`` rows. For each record,
    FindPth gives the stations and lines between the entry and exit stations.
    Every station on the path and every line it touches gets +1 for the
    record's (day, exit-time bucket). The results are written to the
    ``lineFlow``, ``station`` and ``stationFlow`` tables. Write errors are
    printed and the writes are not committed. The connection is always closed.
    """
    # NOTE(review): credentials are hard-coded; consider reading them from config.
    conn = pymysql.connect(host='localhost', user='root', passwd='q19723011',
                           port=3306, db='month6', charset='utf8')
    cur = conn.cursor()
    try:
        cur.execute("SELECT max(id) id from month6")
        maxID = cur.fetchall()[0][0] or 0  # max() is NULL on an empty table
        gap = 100000
        print("开始查询数据!")
        # Iterate while the chunk *start* is within range; the old loop stopped
        # once the upper bound exceeded maxID and silently skipped the last
        # partial chunk (and every row when maxID < gap).
        start = 1
        while start <= maxID:
            end = start + gap - 1
            print("开始处理数据!", start, end)
            cur.execute("select * from month6 where id BETWEEN %s AND %s", (start, end))
            for data in cur.fetchall():
                timeType = findTimeType(data[6])  # bucket of the exit time
                dayType = data[10]  # day of week
                pathQueue, lineIDSet = FindPth(data[4], data[7])  # entry / exit station IDs
                while not pathQueue.empty():
                    StationID_to_Station[pathQueue.get()].addPassengerFlow(dayType, timeType)
                for lineID in lineIDSet:
                    LineID_to_Line[lineID].addPassengerFlow(dayType, timeType)
            print("已经处理完", end, "条数据")
            start = end + 1

        try:
            # Parameterised statements: a station name containing a quote
            # used to break the string-built SQL.
            print("开始写表lineFlow")
            cur.executemany(
                "insert into lineFlow (lineID, day, timeType, passengerFlow) values (%s, %s, %s, %s)",
                [(int(lineID), int(day), int(timeType), int(flow))
                 for lineID, line in LineID_to_Line.items()
                 for day, dayFlow in line.getPassengerFlow().items()
                 for timeType, flow in dayFlow.items()])
            print("表lineFlow完毕")

            print("开始写表station")
            cur.executemany(
                "insert into station (stationID, stationName, lineID, preStationID, nextStationID, transfer) "
                "values (%s, %s, %s, %s, %s, %s)",
                [(int(stationID), station.getName(), int(station.getLineID()),
                  int(station.getPre()), int(station.getNext()), int(station.getTransfer()))
                 for stationID, station in StationID_to_Station.items()])
            print("表station完毕")

            print("开始写表stationFlow")
            cur.executemany(
                "insert into stationFlow (stationID, day, timeType, passengerFlow) values (%s, %s, %s, %s)",
                [(int(stationID), int(day), int(timeType), int(flow))
                 for stationID, station in StationID_to_Station.items()
                 for day, dayFlow in station.getPassengerFlow().items()
                 for timeType, flow in dayFlow.items()])
            print("表stationFlow完毕")

            conn.commit()
            print("代码执行完毕!")
        except pymysql.Error as err:
            print(err)

        print("完成")
    finally:
        cur.close()
        conn.close()


# def modify(dddd):
#     conn = pymysql.connect(host='localhost',  # 连接名称
#                            user='root',  # 用户名
#                            passwd='q19723011',  # 密码
#                            port=3306,  # 端口,默认为3306
#                            db='month6',  # 数据库
#                            charset='utf8',  # 字符编码
#                            )
#     cur = conn.cursor()  # 生成游标对象
#     for days in range(dddd, dddd + 1):
#         if days < 10:
#             day_str = '0' + str(days)
#             print(day_str)
#         else:
#             day_str = str(days)
#             print(day_str)
#         for lineID, line in LineID_to_Line.items():
#             for day, dayFlow in line.getPassengerFlow().items():
#                 for timeType, passengerFlow in dayFlow.items():
#                     sql_lineDayFlow = """insert into linedayflow (lineID, date,day,timeType,flow)
#                                                        values(%d,%s,%d,%d,%d)""" \
#                                       % (int(lineID), '201806' + day_str, int(day), int(timeType), int(passengerFlow))
#                     # 执行sql语句
#                     cur.execute(sql_lineDayFlow)
#         sql_select = "SELECT * from month6 where  DATE_FORMAT(outDate,'%Y%m%d')=" + '201806' + day_str
#         # sql_select = "SELECT * from month6 where id <=10000"
#         print("开始查询", day_str, "的数据")
#         cur.execute(sql_select)
#         datas = cur.fetchall()
#         for data in datas:  # 对每一条刷卡数据进行处理
#             timeType = findTimeType(data[6])  # 出站时间的类型,从5:30到23:30分 每两个小时作为一个类型
#             dayType = data[10]  # 星期几
#             inStationID = data[4]  # 进站ID
#             toStationID = data[7]  # 出站ID
#             pathQueue, lineIDSet = FindPth(inStationID, toStationID)
#             while not pathQueue.empty():
#                 stationID = pathQueue.get()
#                 StationID_to_Station[stationID].addPassengerFlow(dayType, timeType)  # 站点客流量
#             for lineID in lineIDSet:
#                 LineID_to_Line[lineID].addPassengerFlow(dayType, timeType)  # 线路客流量
#         print(day_str, "数据处理完毕,开始写数据库!")
#         for stationID, station in StationID_to_Station.items():
#             for day, dayFlow in station.getPassengerFlow().items():
#                 for timeType, passengerFlow in dayFlow.items():
#                     # 执行sql语句
#                     sql_stationDayFlow = """insert into stationDayFlow (stationID, Date,day,timeType,flow)
#                                    values(%d,%s,%d,%d,%d)""" \
#                                          % (int(stationID), '201806' + day_str, int(day), int(timeType),
#                                             int(passengerFlow))
#                     cur.execute(sql_stationDayFlow)
#
#         for lineID, line in LineID_to_Line.items():
#             for day, dayFlow in line.getPassengerFlow().items():
#                 for timeType, passengerFlow in dayFlow.items():
#                     sql_lineDayFlow = """insert into linedayflow (lineID, date,day,timeType,flow)
#                                                        values(%d,%s,%d,%d,%d)""" \
#                                       % (int(lineID), '201806' + day_str, int(day), int(timeType), int(passengerFlow))
#                     # 执行sql语句
#                     cur.execute(sql_lineDayFlow)
#     conn.commit()
#     cur.close()  # 关闭游标
#     conn.close()  # 关闭连接

def findTime(time):
    """Floor a time of day to the start of its 5-minute slot.

    ``time`` may be anything whose ``str()`` has the form ``H:MM:SS...``
    (e.g. a ``timedelta`` shorter than one day); seconds are discarded.

    :returns: ``timedelta`` of the hour and the minute rounded down to a multiple of 5.
    """
    hour_text, minute_text, _seconds = str(time).split(':')[:3]
    hour = int(hour_text)
    minute = int(minute_text)
    whole_tens = int(minute / 10) * 10
    half = 5 if minute % 10 >= 5 else 0
    return timedelta(hours=hour, minutes=whole_tens + half)


def _accumulate_time_flow(datas):
    """Add each swipe record in ``datas`` to the global 5-minute station/line counters.

    The trip duration (exit - entry time) is split evenly over the stations of
    the shortest path to estimate when each station is passed; that time,
    floored to its slot by findTime, is counted. A line counter is bumped at
    most once per 5 minutes of a trip.
    """
    # Progress is printed about 10 times; max(..., 1) avoids "% 0" for < 10 rows.
    step = max(len(datas) // 10, 1)
    for i, data in enumerate(datas):
        if i % step == 0:
            print(i // step)
        inTime = data[3]
        outTime = data[6]
        pathQueue, _ = FindPth(data[4], data[7])  # entry / exit station IDs
        timeStep = (outTime - inTime) / pathQueue.qsize()
        time_temp = inTime
        lineTime = time_temp - timedelta(minutes=5)
        while not pathQueue.empty():
            station = StationID_to_Station[pathQueue.get()]
            try:
                station.addTimeFlow(findTime(time_temp))
            except (KeyError, ValueError) as e:  # slot outside 05:00-23:50 / unparsable time
                print("except:", e)
            time_temp = time_temp + timeStep
            if time_temp >= lineTime + timedelta(minutes=5):
                try:
                    LineID_to_Line[station.getLineID()].addTimeFlow(findTime(time_temp))
                except (KeyError, ValueError) as e:
                    print("except:", e)
                lineTime = time_temp


def _insert_time_flow(cur, table, id_column, objects, date_str):
    """Insert one ``(id, date, time, flow)`` row per object and 5-minute slot into ``table``.

    Date and time are sent as integers (``YYYYMMDD`` and ``H[H]MMSS``), exactly
    as the old string-built statements embedded them. Failed inserts are
    printed and skipped.
    """
    sql = "insert into %s (%s, date, time, flow) values (%%s, %%s, %%s, %%s)" % (table, id_column)
    for key, obj in objects.items():
        for time, flow in obj.getTimeFlow().items():
            params = (int(key), int(date_str), int("".join(str(time).split(':')[:3])), flow)
            try:
                cur.execute(sql, params)
            except pymysql.Error as e:
                print("except:", e)
                print(sql, params)


def m_flow(day):
    """Compute 5-minute station and line flows for 2018-06-``day`` and store them.

    Reads every swipe record from ``month6`` whose ``outDate`` falls on that
    day. The records are accumulated into the global Station/Line time
    counters. One row per station/line and 5-minute slot is then inserted into
    ``station_time_flow`` / ``line_time_flow``.

    Counters are added to, not cleared: call reset() before each day.

    :param day: day of the month (1-30 for June).
    """
    conn = pymysql.connect(host='localhost', user='root', passwd='q19723011',
                           port=3306, db='month6', charset='utf8')
    cur = conn.cursor()
    try:
        day_str = '%02d' % day
        print(day_str)
        date_str = '201806' + day_str
        sql_select = "SELECT * from month6 where  DATE_FORMAT(outDate,'%Y%m%d')=" + date_str
        print("开始查询", day_str, "的数据")
        cur.execute(sql_select)
        datas = cur.fetchall()
        print("开始处理", day_str, "的数据")
        _accumulate_time_flow(datas)
        print(day_str, "数据处理完毕,开始写数据库!")
        # Rows are labelled with the queried date. The old code used the entry
        # date of the last record, which was undefined when the day had no data.
        _insert_time_flow(cur, "station_time_flow", "stationID", StationID_to_Station, date_str)
        _insert_time_flow(cur, "line_time_flow", "lineID", LineID_to_Line, date_str)
        print(day_str, "结束!!!!!!!")
        conn.commit()
    finally:
        cur.close()
        conn.close()


def reset():
    """Zero the 5-minute flow counters of every station, then of every line."""
    every_entity = list(StationID_to_Station.values()) + list(LineID_to_Line.values())
    for entity in every_entity:
        entity.resetTimeFlow()


if __name__ == "__main__":
    # Build the network model once, then compute and store the 5-minute flows
    # for each day of June 2018 (days 1-30), clearing the counters before each day.
    filename = "D:/PythonProject/graduation-project/DATA/stopinfo.xls"
    CreatGraph(filename)
    for day in range(1, 30 + 1):
        reset()
        m_flow(day)

第二部分:BPNN 预测实现(其中 rf_ 开头的函数实际使用 SVR 模型):

import datetime
import math
import random
from datetime import timedelta
import joblib
import numpy
import numpy as np
import pymysql
from matplotlib import pyplot as plt, pyplot
from numpy import float16, float32
from sklearn import metrics
from sklearn.metrics import mean_absolute_error
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from sklearn import ensemble
from sklearn.svm import SVR
import pickle

# Trained-model cache keyed by kind; not populated by the code shown here.
rf_model_dict = {"station": {}, "line": {}}


def get_data():
    """Return every row of the ``stationflow`` table as a list of tuples.

    The cursor and connection are closed even if the query fails (the old
    code leaked them on error and issued a needless commit after a SELECT).
    """
    conn = pymysql.connect(host='localhost', user='root', passwd='q19723011',
                           port=3306, db='month6', charset='utf8')
    try:
        with conn.cursor() as cur:
            cur.execute("select * from stationflow  ")
            return list(cur.fetchall())
    finally:
        conn.close()


def train(datas):
    """Train a BP network (MLP) and predict the held-out first row.

    Each row of ``datas`` is ``[f1, ..., f6, target]``: the flows of six
    consecutive 5-minute slots and of the slot to predict. Row 0 is the test
    sample (callers are expected to put the day to predict first); all other
    rows are training data. Features and target are min-max scaled with
    scalers fitted on the training rows only, so the held-out target no
    longer leaks into the normalisation.

    :returns: ``[actual, predicted]`` target of row 0, both as ``float``.
        The old code returned the prediction for the first *training* row.
    :raises ValueError: if fewer than two rows are given.
    """
    datas = np.array(datas, dtype=float)
    if datas.ndim != 2 or len(datas) < 2:
        raise ValueError("train() needs one test row and at least one training row")
    data_case = datas[:, 0:6]  # features
    data_label = datas[:, 6:7]  # target
    test_case, test_label = data_case[:1], data_label[:1]
    train_case, train_label = data_case[1:], data_label[1:]

    mm = MinMaxScaler()
    train_label_scaled = mm.fit_transform(train_label)
    mm_case = MinMaxScaler()
    train_case_scaled = mm_case.fit_transform(train_case)
    test_case_scaled = mm_case.transform(test_case)

    model = MLPRegressor(hidden_layer_sizes=(7, 8, 8), activation='tanh', solver='adam', max_iter=2000,
                         learning_rate='adaptive', learning_rate_init=0.02)  # BP neural-network regressor
    model.fit(train_case_scaled, train_label_scaled.ravel())
    pre_test = mm.inverse_transform(model.predict(test_case_scaled).reshape(-1, 1))  # undo scaling
    return [float(test_label[0][0]), float(pre_test[0][0])]


def bpnn_data(time, stationID):
    """Build weekday/weekend samples for ``stationID`` at slot ``time`` and train on each.

    For every date in ``station_time_flow``, the seven 5-minute flows from
    ``time - 30 min`` to ``time`` form one sample (six features + target).
    Monday-Friday dates go to set 0 and Saturday/Sunday to set 1, in date
    order. Dates without all seven slots are skipped. The old code grouped
    unordered rows blindly in sevens, which misaligned samples or raised
    IndexError.

    :param time: ``timedelta`` time of day of the slot to predict.
    :returns: ``[train(weekday_rows), train(weekend_rows)]``.
    """
    window = 7  # slots from time - 30 min to time, inclusive
    start_time = int("".join(str(time - timedelta(minutes=30)).split(':')[:3]))
    end_time = int("".join(str(time).split(':')[:3]))
    conn = pymysql.connect(host='localhost', user='root', passwd='q19723011',
                           port=3306, db='month6', charset='utf8')
    try:
        with conn.cursor() as cur:
            # ORDER BY is required: samples are assembled per date in time order.
            cur.execute("select * from station_time_flow where time >= %s and time <= %s "
                        "and stationID = %s order by date, time",
                        (start_time, end_time, int(stationID)))
            datas = cur.fetchall()
    finally:
        conn.close()

    flows_by_date = {}
    for row in datas:
        flows_by_date.setdefault(row[1], []).append(row[3])  # row[1] = date, row[3] = flow
    dataProcess = [[], []]
    for date in sorted(flows_by_date):
        flows = flows_by_date[date]
        if len(flows) != window:
            continue  # incomplete window for this date: cannot form a sample
        dayType = 0 if date.weekday() < 5 else 1
        dataProcess[dayType].append(flows)
    return [train(dataProcess[0]), train(dataProcess[1])]


def bpnn(stationID):
    """Evaluate the BP-network predictor for ``stationID`` over 07:00-23:00 and store the results.

    For each 5-minute slot, bpnn_data returns ``[actual, predicted]`` for the
    weekday set (index 0) and the weekend set (index 1). These are written to
    ``stationprebpnn``, labelled with ``datelist[0]`` / ``datelist[1]``. MSE,
    RMSE and MAE per set then go to ``station_evaluating_indicator_bpnn``.
    Failed inserts are printed and skipped.
    """
    datelist = ["20180601", "20180602"]
    res_list = [[[], []], [[], []]]  # per set: [actual flows, predicted flows]
    slot_results = []  # (time as H[H]MMSS int, res) per slot

    # Predict first: the DB connection is opened only for writing, so it is
    # not left idle (and possibly timed out) during the long training loop.
    m = timedelta(hours=7)
    while m <= timedelta(hours=23):
        res = bpnn_data(m, stationID)
        slot_results.append((int("".join(str(m).split(':')[:3])), res))
        for i in range(len(datelist)):
            res_list[i][0].append(res[i][0])
            res_list[i][1].append(res[i][1])
        m = m + timedelta(minutes=5)

    conn = pymysql.connect(host='localhost', user='root', passwd='q19723011',
                           port=3306, db='month6', charset='utf8')
    try:
        with conn.cursor() as cur:
            sql_pre = ("insert into stationprebpnn(stationID,date,time,oriflow,preflow) "
                       "values (%s,%s,%s,%s,%s)")
            for time_value, res in slot_results:
                for i, date in enumerate(datelist):
                    # int() truncates exactly like the old "%d" formatting did.
                    params = (int(stationID), int(date), time_value, int(res[i][0]), int(res[i][1]))
                    try:
                        cur.execute(sql_pre, params)
                    except pymysql.Error as e:
                        print("except:", e)
                        print(sql_pre, params)

            sql_eval = ("insert into station_evaluating_indicator_bpnn(stationID,date,mse,rmse,mae) "
                        "values (%s,%s,%s,%s,%s)")
            for i, (actual, predicted) in enumerate(res_list):
                mse = metrics.mean_squared_error(actual, predicted)
                rmse = np.sqrt(mse)
                mae = metrics.mean_absolute_error(actual, predicted)
                # MAPE used to be computed here but was never stored, and it
                # divides by zero for slots with no passengers.
                params = (int(stationID), int(datelist[i]), float(mse), float(rmse), float(mae))
                try:
                    cur.execute(sql_eval, params)
                except pymysql.Error as e:
                    print("except:", e)
                    print(sql_eval, params)
        conn.commit()
    finally:
        conn.close()


def show(ori_data, pre_data):
    """Plot actual vs. predicted flow as two line series and display the figure.

    The x axis is the sample index ``0 .. len(ori_data) - 1``.
    """
    x_axis = range(len(ori_data))
    plt.title('bpnn')
    plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render the Chinese labels
    plt.xlabel('时间')
    plt.ylabel('客流量')
    # Same styling for both series: small point markers on each sample.
    for series in (ori_data, pre_data):
        plt.plot(x_axis, series, marker='o', markersize=1)
    plt.legend(['实际', '预测'])  # "actual", "predicted"
    plt.show()


def day_type(day):
    """Return 0 for a weekday and 1 for a weekend day.

    ``day`` is 1..7 for Monday..Sunday (callers pass ``date.weekday() + 1``);
    values 1-5 are weekdays, anything else counts as weekend. The old
    trailing ``return 1`` was unreachable and has been removed.
    """
    return 0 if 1 <= day <= 5 else 1


def bpnnDateBase():
    """Run bpnn() for every station in the ``station`` table (first column = station ID).

    IDs are taken directly from the rows. The old ``np.array(rows)[:, 0:1]``
    built a string array, because the rows also contain the station name. It
    therefore passed one-element string arrays to bpnn().
    """
    conn = pymysql.connect(host='localhost', user='root', passwd='q19723011',
                           port=3306, db='month6', charset='utf8')
    try:
        with conn.cursor() as cur:
            cur.execute("""select *from station""")
            datas = cur.fetchall()
    finally:
        conn.close()
    for stationID in [row[0] for row in datas]:
        print(stationID)
        bpnn(stationID)



def rf_station_data(time, stationID, date):
    """Train a model for ``stationID`` at slot ``time`` and predict that slot on ``date``.

    Each date in ``station_time_flow`` with all seven 5-minute flows from
    ``time - 30 min`` to ``time`` is one sample. Samples of other dates with
    the same day type as ``date`` (weekday/weekend, see day_type) are the
    training rows, in date order. The sample of ``date`` itself is appended
    last as the test row for rf_train.

    :param time: ``datetime.timedelta`` time of day of the slot to predict.
    :param date: target date as ``"YYYY-MM-DD"``.
    :returns: rf_train's result ``[actual, predicted, model]``.
    :raises ValueError: if ``date`` has no complete window for this slot (the
        old code appended an empty row and failed inside numpy/sklearn).
    """
    window = 7  # slots from time - 30 min to time, inclusive
    target_date = datetime.datetime.strptime(date, "%Y-%m-%d").date()
    target_type = day_type(target_date.weekday() + 1)
    start_time = int("".join(str(time - datetime.timedelta(minutes=30)).split(':')[:3]))
    end_time = int("".join(str(time).split(':')[:3]))
    conn = pymysql.connect(host='localhost', user='root', passwd='q19723011',
                           port=3306, db='month6', charset='utf8')
    try:
        with conn.cursor() as cur:
            # ORDER BY is required: samples are assembled per date in time order.
            cur.execute("select * from station_time_flow where time >= %s and time <= %s "
                        "and stationID = %s order by date, time",
                        (start_time, end_time, int(stationID)))
            datas = cur.fetchall()
    finally:
        conn.close()

    flows_by_date = {}
    for row in datas:
        flows_by_date.setdefault(row[1], []).append(row[3])  # row[1] = date, row[3] = flow
    target_flows = flows_by_date.pop(target_date, None)
    if target_flows is None or len(target_flows) != window:
        raise ValueError("no complete %d-slot window for station %s on %s at %s"
                         % (window, stationID, date, time))
    dataProcess = [flows for day, flows in sorted(flows_by_date.items())
                   if len(flows) == window and day_type(day.weekday() + 1) == target_type]
    dataProcess.append(target_flows)  # rf_train treats the last row as the test sample
    return rf_train(dataProcess)



def rf_train(datas):
    """Fit an RBF-kernel SVR on all rows but the last and predict the last row.

    Despite the ``rf_`` prefix this uses ``sklearn.svm.SVR``, not a random
    forest. Each row is ``[f1, ..., f6, target]``. Features and target are
    min-max scaled with scalers fitted on the training rows only, so the
    test target does not leak into the scaling.

    :returns: ``[actual, predicted, model]``. ``actual`` is the last row's
        target as a float, and ``predicted`` its prediction truncated to an
        integer value (as a float). ``model`` is the fitted SVR.
        NOTE(review): the model expects features scaled by a scaler that is
        not returned, so it cannot be reused on raw flows as-is.
    :raises ValueError: if fewer than two rows are given.
    """
    datas = np.array(datas, dtype=float)
    if datas.ndim != 2 or len(datas) < 2:
        raise ValueError("rf_train() needs at least one training row and one test row")
    train_case, train_label = datas[:-1, 0:6], datas[:-1, 6:7]
    test_case, test_label = datas[-1:, 0:6], datas[-1:, 6:7]

    mm_case = MinMaxScaler()
    mm = MinMaxScaler()
    model = SVR(kernel='rbf')
    model.fit(mm_case.fit_transform(train_case), mm.fit_transform(train_label).ravel())
    pre = mm.inverse_transform(model.predict(mm_case.transform(test_case)).reshape(-1, 1))  # undo scaling
    return [float(test_label[0][0]), float(int(pre[0][0])), model]


def rf_predict(stationID='103', date='2018-06-03',
               model_path=r'D:\PythonProject\graduationProject\saveModels\RF\station\clf.pkl'):
    """Predict every 5-minute slot from 07:00 to 23:00 of ``date`` for ``stationID``.

    One model per slot is trained via rf_station_data. The dict
    ``{slot timedelta: model}`` is pickled to ``model_path``. Actual vs.
    predicted flow is then plotted, and MAE, RMSE and MSE are printed.

    The defaults reproduce the previously hard-coded station, date and path.
    The path is now a raw string; the old literal relied on invalid escape
    sequences such as ``\\P`` and ``\\s``, which newer Pythons warn about.
    """
    station_test_flow = []
    station_pre_flow = []
    model_dict = {}
    m = datetime.timedelta(hours=7)
    while m <= datetime.timedelta(hours=23):
        print(m)
        actual, predicted, model = rf_station_data(m, stationID, date)
        model_dict[m] = model
        station_test_flow.append(actual)
        station_pre_flow.append(predicted)
        m = m + datetime.timedelta(minutes=5)
    with open(model_path, 'wb') as f:
        pickle.dump(model_dict, f)

    station_mse = metrics.mean_squared_error(station_test_flow, station_pre_flow)
    station_rmse = np.sqrt(station_mse)
    station_mae = metrics.mean_absolute_error(station_test_flow, station_pre_flow)
    show(station_test_flow, station_pre_flow)
    print(station_mae, station_rmse, station_mse)


if __name__ == '__main__':
    # Other entry points: bpnnDateBase() for all stations, or bpnn(<stationID>).
    rf_predict()
    # Reload the pickled per-slot models to check the saved file.
    # NOTE: pickle.load can execute arbitrary code; only load files you created.
    # Raw string: the old literal relied on invalid escape sequences (\P, \s, ...).
    with open(r'D:\PythonProject\graduationProject\saveModels\RF\station\clf.pkl', 'rb') as f:
        clf2 = pickle.load(f)
    print(clf2)

完整源码下载:https://download.csdn.net/download/FL1768317420/89207090

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

海神之光.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值