SIR模型python实现

python新手,代码不规范之处敬请见谅,第一次发帖排版也不太懂,各位将就看。

参考:

SIR模型和Python实现_阿丢是丢心心的博客-CSDN博客_sir模型

【保姆级教程】使用python实现SIR模型(包含数据集的制作与导入及最终结果的可视化)_自然卷的悲伤的博客-CSDN博客_sir模型python

这篇文章增加了平均次数,并且传播过程与参考文章有所不同,代码注释大家看参考文章的就行,SIR模型在上述参考文章中已经介绍很清楚了,咱就不赘述,直接上代码:

1.用到的包

import random
import networkx as nx
import re
import os
import matplotlib.pyplot as plt
import pandas as pd
import time

2. 将连边文件输入,建立Graph(最好是csv文件)

def readEdgeslistToGraph(edges_filename):
    edge = []
    with open(edges_filename, 'r', encoding='utf-8-sig') as f:
        data = f.readlines()
        for line in data:
            line_str = line.replace('\r','').replace('\n','').replace('\t','')
            line = re.split(',| ',line_str)
            single_edge = tuple([line[0],line[1]])
            edge.append(single_edge)
    G = nx.Graph()
    G.add_edges_from(edge)
    return G

 注意:是连边文件,或者说边集,比如

3. 初始化网络并随机选择初始“I”态节点

def randomChoiseInfectedNodes(G, initial_infected_nodes_num):
    infected_nodes = random.sample(list(G.nodes), k=initial_infected_nodes_num)
    for node in G.nodes():
        G.nodes[node]["state"] = "S"
    for node in infected_nodes:
        G.nodes[node]["state"] = "I"

4.更新节点状态

##更新单个节点状态
def updateNodeState(G, node, infected_rate, recover_rate, old_G):
    if old_G.nodes[node]["state"] == "I":
        p = random.random()
        if p < recover_rate:
            G.nodes[node]["state"] = "R"
    elif old_G.nodes[node]["state"] == "S":
        k = 0
        for neibor in old_G.adj[node]:
            if old_G.nodes[neibor]["state"] == "I":
                k += 1
        p = random.random()
        if p < (1 - (1 - infected_rate)**k):
            G.nodes[node]["state"] = "I"

##更新整个网络节点状态      
def updateNetworkState(G, infected_rate, recover_rate):
    old_G = G.copy()
    for node in G:
        updateNodeState(G, node, infected_rate, recover_rate, old_G)

注意: 这里与参考文章不一致,参考文章中没有记录传播前的状态,假如某个节点从“I”态转为“R”态,那么这个节点的“S”态邻居转为“I”态的概率就会偏小(参考文章算这个概率的时候,此时已经把这个节点当成“R”态而非“I”态),因此这里引入old_G来记录每次传播前的状态。

5.计算各个态的节点数

def countSIRnum(G):
    S = 0
    I = 0
    for node in G:
        if G.nodes[node]["state"] == "S":
            S += 1
        if G.nodes[node]["state"] == "I":
            I += 1 
    R = len(G.nodes) - S - I
    return S, I, R 


def eachIterateSIRnum(days):
    eachiterate_SIR_list = []
    for day in range(1, days+1):
        updateNetworkState(G, infected_rate, recover_rate)
        tuple_sir = countSIRnum(G)
        print("day%s:\tS:%s\tI:%s\tR:%s"%(day,tuple_sir[0],tuple_sir[1],tuple_sir[2]))
        eachiterate_SIR_list.append(list(tuple_sir))
    return eachiterate_SIR_list

6.画图

def plotSIR(SIR_num_list):
    color_dict = {"S": "blue", "I": "red", "R": "green"}    
    df = pd.DataFrame(SIR_num_list,columns=["S","I","R"]) 
    df.plot(figsize=(9,6),color=[color_dict.get(x) for x in df.columns])
    plt.ylabel("number")
    plt.xlabel("day")
    plt.show()

7.完整代码

import random
import networkx as nx
import re
import os
import matplotlib.pyplot as plt
import pandas as pd
import time


def readEdgeslistToGraph(edges_filename):
    edge = []
    with open(edges_filename, 'r', encoding='utf-8-sig') as f:
        data = f.readlines()
        for line in data:
            line_str = line.replace('\r','').replace('\n','').replace('\t','')
            line = re.split(',| ',line_str)
            single_edge = tuple([line[0],line[1]])
            edge.append(single_edge)
    G = nx.Graph()
    G.add_edges_from(edge)
    return G


def randomChoiseInfectedNodes(G, initial_infected_nodes_num):
    infected_nodes = random.sample(list(G.nodes), k=initial_infected_nodes_num)
    for node in G.nodes():
        G.nodes[node]["state"] = "S"
    for node in infected_nodes:
        G.nodes[node]["state"] = "I"
    

def updateNodeState(G, node, infected_rate, recover_rate, old_G):
    if old_G.nodes[node]["state"] == "I":
        p = random.random()
        if p < recover_rate:
            G.nodes[node]["state"] = "R"
    elif old_G.nodes[node]["state"] == "S":
        k = 0
        for neibor in old_G.adj[node]:
            if old_G.nodes[neibor]["state"] == "I":
                k += 1
        p = random.random()
        if p < (1 - (1 - infected_rate)**k):
            G.nodes[node]["state"] = "I"

            
def updateNetworkState(G, infected_rate, recover_rate):
    old_G = G.copy()
    for node in G:
        updateNodeState(G, node, infected_rate, recover_rate, old_G)
      
        
def countSIRnum(G):
    S = 0
    I = 0
    for node in G:
        if G.nodes[node]["state"] == "S":
            S += 1
        if G.nodes[node]["state"] == "I":
            I += 1 
    R = len(G.nodes) - S - I
    return S, I, R

def eachIterateSIRnum(days):
    eachiterate_SIR_list = []
    for day in range(1, days+1):
        updateNetworkState(G, infected_rate, recover_rate)
        tuple_sir = countSIRnum(G)
        # print("day%s:\tS:%s\tI:%s\tR:%s"%(day,tuple_sir[0],tuple_sir[1],tuple_sir[2]))
        eachiterate_SIR_list.append(list(tuple_sir))
    return eachiterate_SIR_list
    
def plotSIR(SIR_num_list):
    color_dict = {"S": "blue", "I": "red", "R": "green"}    
    df = pd.DataFrame(SIR_num_list,columns=["S","I","R"]) 
    df.plot(figsize=(9,6),color=[color_dict.get(x) for x in df.columns])
    plt.ylabel("number")
    plt.xlabel("day")
    plt.show()
      
    
    
if __name__ =='__main__':
    
    start = time.time()
    
    edges_filename = 'l_1.txt'
    initial_infected_nodes_num = 10
    infected_rate = 0.1
    recover_rate = 0.02
    days = 100
    iterate_num = 100
    
    average_SIR_list = [[0,0,0] for i in range(days)]
    for i in range(iterate_num):
        G = readEdgeslistToGraph(edges_filename)
        randomChoiseInfectedNodes(G, initial_infected_nodes_num)
        eachiterate_SIR_list = eachIterateSIRnum(days)
        for day in range(days):
            for state in range(3):
                average_SIR_list[day][state] +=  eachiterate_SIR_list[day][state] 
    for day in range(days):
        for state in range(3):
            average_SIR_list[day][state] =  average_SIR_list[day][state]/iterate_num
    
    plotSIR(average_SIR_list)
    
    end = time.time()
    print("running time: %.5s s"%(end - start))
    

 8.运行结果

 9.数据集

https://wwxd.lanzoue.com/iBdSw0i5uo6j

密码:h6s4

  • 4
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值