python新手,代码不规范之处敬请见谅,第一次发帖排版也不太懂,各位将就看。
参考:
SIR模型和Python实现_阿丢是丢心心的博客-CSDN博客_sir模型
【保姆级教程】使用python实现SIR模型(包含数据集的制作与导入及最终结果的可视化)_自然卷的悲伤的博客-CSDN博客_sir模型python
本文在参考文章的基础上增加了多次模拟取平均,并且传播过程的实现也与参考文章有所不同。代码注释可以参考上述文章;SIR模型的原理在参考文章中已经介绍得很清楚,这里不再赘述,直接上代码:
1.用到的包
import random
import networkx as nx
import re
import os
import matplotlib.pyplot as plt
import pandas as pd
import time
2. 读取连边文件并建立Graph(每行一条边,两个节点编号用逗号或空格分隔,推荐使用csv文件)
def readEdgeslistToGraph(edges_filename):
    """Read an edge-list file and build an undirected networkx graph.

    Each line describes one edge as two node ids separated by commas and/or
    whitespace (spaces or tabs), e.g. ``1,2``, ``1 2`` or ``1\t2``. Node ids
    are kept as strings. Columns after the first two are ignored, and lines
    with fewer than two ids (such as blank lines) are skipped.

    Args:
        edges_filename: path of the edge-list file (UTF-8; a leading BOM is
            tolerated).

    Returns:
        nx.Graph containing every edge read from the file.
    """
    edge = []
    # utf-8-sig also strips a leading byte-order mark if the file has one.
    with open(edges_filename, 'r', encoding='utf-8-sig') as f:
        for line in f:
            # Treat any run of commas/whitespace as one separator. Tabs must
            # separate ids, not be deleted (which would glue "1\t2" into "12"),
            # and "1, 2" must not yield an empty-string node.
            tokens = [tok for tok in re.split(r'[,\s]+', line.strip()) if tok]
            if len(tokens) < 2:
                continue
            edge.append((tokens[0], tokens[1]))
    G = nx.Graph()
    G.add_edges_from(edge)
    return G
注意:输入的是连边文件(即边集),每行表示一条边,由两个节点编号组成,比如 `1,2` 或 `1 2`。
3. 初始化网络并随机选择初始“I”态节点
def randomChoiseInfectedNodes(G, initial_infected_nodes_num):
    """Initialise the epidemic: mark a random sample of nodes "I", the rest "S".

    Every node's "state" attribute is (re)written in place, so any state left
    over from a previous simulation on the same graph is cleared.

    Args:
        G: networkx graph to initialise.
        initial_infected_nodes_num: number of distinct nodes that start
            infected; random.sample raises ValueError if it exceeds the node
            count.
    """
    chosen = set(random.sample(list(G.nodes), k=initial_infected_nodes_num))
    for v in G.nodes():
        G.nodes[v]["state"] = "I" if v in chosen else "S"
4.更新节点状态
##更新单个节点状态
def updateNodeState(G, node, infected_rate, recover_rate, old_G):
    """Advance a single node by one synchronous SIR step.

    The decision is based on ``old_G`` (the snapshot taken before this round)
    and the outcome is written into ``G``:

    * "I" becomes "R" with probability ``recover_rate``;
    * "S" becomes "I" with probability ``1 - (1 - infected_rate) ** k``, where
      ``k`` is the number of neighbours that were "I" in ``old_G``;
    * any other state is left unchanged (no random number is drawn).
    """
    before = old_G.nodes[node]["state"]
    if before == "I":
        if random.random() < recover_rate:
            G.nodes[node]["state"] = "R"
        return
    if before != "S":
        return
    infected_neighbours = sum(
        1 for nb in old_G.adj[node] if old_G.nodes[nb]["state"] == "I"
    )
    # Each infected neighbour independently fails to transmit with
    # probability (1 - infected_rate); the node is infected unless all fail.
    escape = (1 - infected_rate) ** infected_neighbours
    if random.random() < 1 - escape:
        G.nodes[node]["state"] = "I"
##更新整个网络节点状态
def updateNetworkState(G, infected_rate, recover_rate):
    """Run one synchronous SIR round over every node of ``G``, in place.

    A copy of ``G`` is taken before any node changes, and every node decides
    its transition from that copy. A node that recovers earlier in the round
    therefore still counts as infected for its neighbours in this round.
    """
    snapshot = G.copy()
    for v in G:
        updateNodeState(G, v, infected_rate, recover_rate, snapshot)
注意:这里与参考文章的做法不一致。参考文章在一轮传播中没有保存传播前的状态:假如某个节点在本轮中先从“I”态转为“R”态,那么在计算它的“S”态邻居被感染的概率时,该节点已被当作“R”态而非“I”态,导致感染概率偏小。因此这里引入old_G来保存每轮传播前的状态,本轮中所有节点都依据old_G来判断是否转变状态。
5.计算各个态的节点数
def countSIRnum(G):
    """Return the ``(S, I, R)`` node counts of ``G``.

    Every node must carry a "state" attribute. Nodes whose state is neither
    "S" nor "I" are counted in R.
    """
    states = [G.nodes[v]["state"] for v in G]
    n_s = states.count("S")
    n_i = states.count("I")
    return n_s, n_i, len(G.nodes) - n_s - n_i
def eachIterateSIRnum(days, *, G=None, infected_rate=None, recover_rate=None, verbose=True):
    """Simulate ``days`` SIR rounds and return the daily ``[S, I, R]`` counts.

    Args:
        days: number of rounds to simulate; the returned list has one row per
            round, and row 0 corresponds to day 1.
        G: graph whose nodes carry a "state" attribute; it is updated in place.
            Defaults to the module-level global ``G``.
        infected_rate: per-neighbour infection probability. Defaults to the
            module-level global ``infected_rate``.
        recover_rate: per-round recovery probability. Defaults to the
            module-level global ``recover_rate``.
        verbose: if true, print each day's counts.

    Returns:
        list of ``[S, I, R]`` lists, one per day.

    Raises:
        NameError: an argument was omitted and no module-level global of the
            same name exists.
    """
    module_globals = globals()

    def _resolve(value, name):
        # Omitted arguments fall back to the module-level global of the same
        # name, so the original call form eachIterateSIRnum(days) still works.
        if value is not None:
            return value
        if name not in module_globals:
            raise NameError(
                f"name '{name}' is not defined; pass {name}=... explicitly"
            )
        return module_globals[name]

    G = _resolve(G, "G")
    infected_rate = _resolve(infected_rate, "infected_rate")
    recover_rate = _resolve(recover_rate, "recover_rate")

    eachiterate_SIR_list = []
    for day in range(1, days + 1):
        updateNetworkState(G, infected_rate, recover_rate)
        tuple_sir = countSIRnum(G)
        if verbose:
            print("day%s:\tS:%s\tI:%s\tR:%s" % (day, tuple_sir[0], tuple_sir[1], tuple_sir[2]))
        eachiterate_SIR_list.append(list(tuple_sir))
    return eachiterate_SIR_list
6.画图
def plotSIR(SIR_num_list):
    """Plot the S, I and R curves against the day number and show the figure.

    Args:
        SIR_num_list: sequence of ``[S, I, R]`` rows where row ``i`` is day
            ``i + 1``. This is the layout eachIterateSIRnum produces, since
            its loop starts at day 1.
    """
    color_dict = {"S": "blue", "I": "red", "R": "green"}
    df = pd.DataFrame(SIR_num_list, columns=["S", "I", "R"])
    # The first row is day 1, so replace pandas' default 0-based index;
    # otherwise every point is drawn one day early on the "day" axis.
    df.index = range(1, len(df) + 1)
    df.plot(figsize=(9, 6), color=[color_dict.get(x) for x in df.columns])
    plt.ylabel("number")
    plt.xlabel("day")
    plt.show()
7.完整代码
import random
import networkx as nx
import re
import os
import matplotlib.pyplot as plt
import pandas as pd
import time
def readEdgeslistToGraph(edges_filename):
    """Read an edge-list file and build an undirected networkx graph.

    Each line describes one edge as two node ids separated by commas and/or
    whitespace (spaces or tabs), e.g. ``1,2``, ``1 2`` or ``1\t2``. Node ids
    are kept as strings. Columns after the first two are ignored, and lines
    with fewer than two ids (such as blank lines) are skipped.

    Args:
        edges_filename: path of the edge-list file (UTF-8; a leading BOM is
            tolerated).

    Returns:
        nx.Graph containing every edge read from the file.
    """
    edge = []
    # utf-8-sig also strips a leading byte-order mark if the file has one.
    with open(edges_filename, 'r', encoding='utf-8-sig') as f:
        for line in f:
            # Treat any run of commas/whitespace as one separator. Tabs must
            # separate ids, not be deleted (which would glue "1\t2" into "12"),
            # and "1, 2" must not yield an empty-string node.
            tokens = [tok for tok in re.split(r'[,\s]+', line.strip()) if tok]
            if len(tokens) < 2:
                continue
            edge.append((tokens[0], tokens[1]))
    G = nx.Graph()
    G.add_edges_from(edge)
    return G
def randomChoiseInfectedNodes(G, initial_infected_nodes_num):
    """Initialise the epidemic: mark a random sample of nodes "I", the rest "S".

    Every node's "state" attribute is (re)written in place, so any state left
    over from a previous simulation on the same graph is cleared.

    Args:
        G: networkx graph to initialise.
        initial_infected_nodes_num: number of distinct nodes that start
            infected; random.sample raises ValueError if it exceeds the node
            count.
    """
    chosen = set(random.sample(list(G.nodes), k=initial_infected_nodes_num))
    for v in G.nodes():
        G.nodes[v]["state"] = "I" if v in chosen else "S"
def updateNodeState(G, node, infected_rate, recover_rate, old_G):
    """Advance a single node by one synchronous SIR step.

    The decision is based on ``old_G`` (the snapshot taken before this round)
    and the outcome is written into ``G``:

    * "I" becomes "R" with probability ``recover_rate``;
    * "S" becomes "I" with probability ``1 - (1 - infected_rate) ** k``, where
      ``k`` is the number of neighbours that were "I" in ``old_G``;
    * any other state is left unchanged (no random number is drawn).
    """
    before = old_G.nodes[node]["state"]
    if before == "I":
        if random.random() < recover_rate:
            G.nodes[node]["state"] = "R"
        return
    if before != "S":
        return
    infected_neighbours = sum(
        1 for nb in old_G.adj[node] if old_G.nodes[nb]["state"] == "I"
    )
    # Each infected neighbour independently fails to transmit with
    # probability (1 - infected_rate); the node is infected unless all fail.
    escape = (1 - infected_rate) ** infected_neighbours
    if random.random() < 1 - escape:
        G.nodes[node]["state"] = "I"
def updateNetworkState(G, infected_rate, recover_rate):
    """Run one synchronous SIR round over every node of ``G``, in place.

    A copy of ``G`` is taken before any node changes, and every node decides
    its transition from that copy. A node that recovers earlier in the round
    therefore still counts as infected for its neighbours in this round.
    """
    snapshot = G.copy()
    for v in G:
        updateNodeState(G, v, infected_rate, recover_rate, snapshot)
def countSIRnum(G):
    """Return the ``(S, I, R)`` node counts of ``G``.

    Every node must carry a "state" attribute. Nodes whose state is neither
    "S" nor "I" are counted in R.
    """
    states = [G.nodes[v]["state"] for v in G]
    n_s = states.count("S")
    n_i = states.count("I")
    return n_s, n_i, len(G.nodes) - n_s - n_i
def eachIterateSIRnum(days, *, G=None, infected_rate=None, recover_rate=None, verbose=False):
    """Simulate ``days`` SIR rounds and return the daily ``[S, I, R]`` counts.

    Args:
        days: number of rounds to simulate; the returned list has one row per
            round, and row 0 corresponds to day 1.
        G: graph whose nodes carry a "state" attribute; it is updated in place.
            Defaults to the module-level global ``G``.
        infected_rate: per-neighbour infection probability. Defaults to the
            module-level global ``infected_rate``.
        recover_rate: per-round recovery probability. Defaults to the
            module-level global ``recover_rate``.
        verbose: if true, print each day's counts (off by default so that
            many repeated runs stay quiet).

    Returns:
        list of ``[S, I, R]`` lists, one per day.

    Raises:
        NameError: an argument was omitted and no module-level global of the
            same name exists.
    """
    module_globals = globals()

    def _resolve(value, name):
        # Omitted arguments fall back to the module-level global of the same
        # name, so the original call form eachIterateSIRnum(days) still works.
        if value is not None:
            return value
        if name not in module_globals:
            raise NameError(
                f"name '{name}' is not defined; pass {name}=... explicitly"
            )
        return module_globals[name]

    G = _resolve(G, "G")
    infected_rate = _resolve(infected_rate, "infected_rate")
    recover_rate = _resolve(recover_rate, "recover_rate")

    eachiterate_SIR_list = []
    for day in range(1, days + 1):
        updateNetworkState(G, infected_rate, recover_rate)
        tuple_sir = countSIRnum(G)
        if verbose:
            print("day%s:\tS:%s\tI:%s\tR:%s" % (day, tuple_sir[0], tuple_sir[1], tuple_sir[2]))
        eachiterate_SIR_list.append(list(tuple_sir))
    return eachiterate_SIR_list
def plotSIR(SIR_num_list):
    """Plot the S, I and R curves against the day number and show the figure.

    Args:
        SIR_num_list: sequence of ``[S, I, R]`` rows where row ``i`` is day
            ``i + 1``. This is the layout eachIterateSIRnum produces, since
            its loop starts at day 1.
    """
    color_dict = {"S": "blue", "I": "red", "R": "green"}
    df = pd.DataFrame(SIR_num_list, columns=["S", "I", "R"])
    # The first row is day 1, so replace pandas' default 0-based index;
    # otherwise every point is drawn one day early on the "day" axis.
    df.index = range(1, len(df) + 1)
    df.plot(figsize=(9, 6), color=[color_dict.get(x) for x in df.columns])
    plt.ylabel("number")
    plt.xlabel("day")
    plt.show()
if __name__ == '__main__':
    # Script entry point: average the daily S/I/R counts over many independent
    # simulations on the same network, report the runtime, then plot.
    start = time.perf_counter()  # monotonic clock meant for measuring durations
    edges_filename = 'l_1.txt'
    initial_infected_nodes_num = 10
    infected_rate = 0.1
    recover_rate = 0.02
    days = 100
    iterate_num = 100

    # The topology is the same in every run, so parse the file only once.
    # randomChoiseInfectedNodes rewrites every node's state at the start of
    # each run, which discards the previous run's results.
    G = readEdgeslistToGraph(edges_filename)

    total_SIR_list = [[0, 0, 0] for _ in range(days)]
    for _ in range(iterate_num):
        randomChoiseInfectedNodes(G, initial_infected_nodes_num)
        eachiterate_SIR_list = eachIterateSIRnum(days)  # reads the globals G / rates set above
        for day_totals, day_counts in zip(total_SIR_list, eachiterate_SIR_list):
            for state in range(3):
                day_totals[state] += day_counts[state]
    average_SIR_list = [
        [count / iterate_num for count in day_totals] for day_totals in total_SIR_list
    ]

    # Stop the clock before plotting: plt.show() usually blocks until the
    # window is closed, which would inflate the reported time.
    end = time.perf_counter()
    # "%.3f" formats the number; the old "%.5s" cut its string form to 5
    # characters, which could drop digits of magnitude (e.g. 123456.7 -> "12345").
    print("running time: %.3f s" % (end - start))
    plotSIR(average_SIR_list)
8.运行结果
9.数据集
https://wwxd.lanzoue.com/iBdSw0i5uo6j
密码:h6s4