day1
线性DP
例题1
t = int(input())
inf=0x3f3f3f3f3f3f3f3f
for _ in range(t):
n, a, b, k = map(int, input().split())
A = [0] + list(map(int, input().split()))
#dp[i][j]表示跳跃了i次后当前处在j位置能吃到的最大昆虫数量
dp = [[-inf] * 105 for i in range(105)]
#初始化状态,未跳跃前的状态
dp[0][0]=0
ans = 0
for i in range(1, k + 1):#条件,最多跳k次
for j in range(1, n + 1):#枚举位置
for p in range(a, b + 1):#求解步骤,每次跳T(a<=T<=b)个单位
if j - p < 0:
break
else:
#上一状态为dp[i-1][j']跳跃了i-1次,位置为j-p
#要想转移必须满足a<=j-j'<=b
dp[i][j] = max(dp[i][j], dp[i - 1][j - p] + A[j])
ans = max(ans, dp[i][j])
print(ans)
例题2
思路:
dp[i][j]表示处理到第i个数字,分出j个区间的最大值
dp[0][0]=0未划分区间时的最大值为0
将划分作为转移,那么上一次划分就是dp[i'][j-1]表示划分j-1个区间时最后的元素是i的情况
如何转移呢?
状态可以从“划分为j个区间,最后一个元素为i,元素i融入第j个区间”和“划分为j个区间,最后一个元素为i,元素i单独一个区间”
如下式所示:简而言之就是将a[i]新开一个区间还是并入旧区间
dp[i][j]=max(dp[i-1][j],dp[i-1][j-1])+a[i]*p[j]
inf=0x3f3f3f3f3f3f3f3f
n,k=map(int,input().split())
a=[0]+list(map(int,input().split()))
p=[0]+list(map(int,input().split()))
#dp[i][j]表示处理到第i个数字,分出j个区间的最大值
#对于dp[i][j],它只与dp[i-1][j']有关,使用滚动数组节省空间
dp=[[-inf] * (k+1) for i in range(2)]
dp[0][0]=0
now=0
pre=1
for i in range(1,n+1):
now,pre=pre,now
dp[now]=[-inf] * (k+1)#每次更新前,初始化要更新的数组
for j in range(1,k+1):#枚举区间个数
dp[now][j]=max(dp[pre][j-1],dp[pre][j])+a[i]*p[j]
print(dp[now][k])
例题3
n,m=map(int,input().split())
mod=1000000007
#dp[i][j][k]表示遇到i次酒店,j次花,剩余k酒,的方案数
#dp[n][m-1][1]表示终止
#dp[n][m-1][1]表示遇到n次酒店,m-1次花,剩余1酒最后一次必定为花
dp=[[[0 for k in range(m+1)] for j in range(m)] for i in range(n+1)]
dp[0][0][2]=1
for i in range(n+1):
for j in range(m):
if i==0 and j==0:
continue
for k in range(m+1):
if i-1>=0 and k%2==0:#遇到的是店,遇店加一倍,剩余酒必为偶数,相比上一状态店数加1,酒数加倍
dp[i][j][k]+=dp[i-1][j][k//2]
dp[i][j][k]%=mod
if j-1>=0 and k+1<=m:#遇到的是花,遇花喝一斗,遇到m次花,最多剩余m斗,相比上一状态花数加1,酒数减1
dp[i][j][k]+=dp[i][j-1][k+1]
dp[i][j][k]%=mod
print(dp[n][m-1][1])
例题4
n=int(input())
a=[0]+list(map(int,list(input())))[::-1]#高位操作不影响低位,反转一下先处理高位
b=[0]+list(map(int,list(input())))[::-1]
dp=[[0,0,0]for i in range(n+1)]#从a到b的三种状态
dp[1][0]=abs(a[1]-b[1])#不进位不退位
dp[1][1]=10-a[1]+b[1]#进位
dp[1][2]=10-b[1]+a[1]#退位
for i in range(2,n+1):
#当前位通过不进位得到,分三种情况:上一位通过不进位得到、上一位通过进位得到、上一位通过退位得到
dp[i][0]=min(dp[i-1][0]+abs(a[i]-b[i]),dp[i-1][1]+abs(a[i]+1-b[i]),dp[i-1][2]+abs(a[i]-1-b[i]))
#当前位通过进位得到,分三种情况:上一位通过不进位得到、上一位通过进位得到、上一位通过退位得到
dp[i][1]=min(dp[i-1][0]+10-a[i]+b[i],dp[i-1][1]+9-a[i]+b[i],dp[i-1][2]+11-a[i]+b[i])
#当前位通过退位得到,分三种情况:上一位通过不进位得到、上一位通过进位得到、上一位通过退位得到
dp[i][2]=min(dp[i-1][0]+10-b[i]+a[i],dp[i-1][1]+11-b[i]+a[i],dp[i-1][2]+9-b[i]+a[i])
print(min(dp[n][0],dp[n][1],dp[n][2]))
背包问题
例题1(多重背包问题)
在博客 中直接将有s[i]个的物品i看作i种物品,然后基于0/1背包思路求解,这样复杂度为m*sum(s[i])
还有另外的一种方式拆分,将s[i]按照二进制拆分,分为{1,2,4,8...}个,这样可以划分出所有数量的物品,假设划分的最大数量时2**k,那么最后一个时s[i]-2**(k+1)+1。复杂度降低为m*log2(s[i])
N,V=map(int,input().split())
w=[0]*20003
v=[0]*20003
dp=[0]*2003
index=0#新物品的编号
for i in range(1,N+1):
wi,vi,si=map(int,input().split())
if si==0:
si=V//wi
cnt=1
while si>cnt:#二进制拆分物品数量
index+=1
si-=cnt#原数量减少
w[index]=cnt*wi#存储一个拆分
v[index]=cnt*vi
cnt*=2#二进制拆分下一个物品数量
index+=1
w[index]=si*wi
v[index]=si*vi
for i in range(1,index+1):
for j in range(V,w[i]-1,-1):
dp[j]=max(dp[j],dp[j-w[i]]+v[i])
print(dp[V])
例题2(分组背包问题)
分组背包问题描述:
n个物品,第i组有c[i]个物品,第i组中第j个物品的价值为p[i][j],体积为v[i][j],每组物品最多只能选择一个。背包大小为m,问装载的最大价值。
思考:
定义dp[i][j]为处理到第i组物品,装载体积为j的最大价值
dp[i][j]=max(dp[i-1][j],max(dp[i-1][j-v[i][j]]+p[i][j]))
如果物品A有两个附属物品B·、C·,那么对于这三件物品有{},{A},{A,B},{A,C},{A,B,C} 对于这几种情况只能选择一种,对与有一个或者零个附属品的情况可以类比。
每组物品只能选择一种组合,显然是一个分组背包问题。
N,m=map(int,input().split())
v,w,q=[0],[0],[0]
G=[[]for i in range(m+1)]
dp=[0]*(N+1)
for i in range(1,m+1):
vi,pi,qi=map(int,input().split())
if qi!=0:
G[qi].append(i)
v.append(vi)
w.append(pi)
q.append(qi)
for i in range(1,m+1):
if q[i]==0:#遍历组
for j in range(N,v[i]-1,-1):
#每组仅有一个主件物品
dp[j]=max(dp[j],dp[j-v[i]]+v[i]*w[i])
if G[i]:#遍历附属物品
for k in range(len(G[i])):
if j-v[i]-v[G[i][k]]>0:
#选择一个附属物品
dp[j]=max(dp[j],dp[j-v[i]-v[G[i][k]]]+v[i]*w[i]+v[G[i][k]]*w[G[i][k]])
if len(G[i])>1 and j-v[i]-v[G[i][0]]-v[G[i][1]]>=0:
#选择两个附属物品
dp[j]=max(dp[j],dp[j-v[i]-v[G[i][0]]-v[G[i][1]]]+v[i]*w[i]+v[G[i][0]]*w[G[i][0]]
+v[G[i][1]]*w[G[i][1]])
print(dp[N])
例题3(完全背包问题)
思考:
将每一个有空的时间段长度看作一个背包的容量,将健身的持续时间看作物品的体积,收益看作物品的价值,然后就可以看作完全背包问题。
n,m,q=map(int,input().split())
t=list(map(int,input().split()))
k=[]#存储每个训练计划需要的天数
s=[]#存储每个训练计划获得的增益
bag=[]#存储有空的时间段长度
day=[i for i in range(1+n)]
sum_=0
for i in range(m):
ki,si=map(int,input().split())
k.append(2**ki)
s.append(si)
for i in range(1,len(t)):
bagi=day[t[i]]-day[t[i-1]]-1#计算相邻有其他安排时间点之间的时间
bag.append(bagi)
bag.insert(0,t[0]-1)#第一次没空的时间点之前的有空时间
bag.append(n-t[-1])#最后一次没空的时间点之后的有空时间
for size in bag:
dp=[0] * (size+5)
for i in range(len(k)):
for j in range(k[i],size+1):
dp[j]=max(dp[j],dp[j-k[i]]+s[i])
sum_+=dp[size]
print(sum_)
例题4
N=int(input())
A=[0]
dp=[0]*(10001)
dp[0]=1
for i in range(1,N+1):
Ai=int(input())
A.append(Ai)
def gcd(a,b):
if b==0:
return a
else:
return gcd(b,a%b)
g=A[1]
for i in range(2,N+1):
g=gcd(g,A[i])
if g!=1:#最大公约数不为1,则只能凑数最大公约数的倍数,显然有无数种情况凑不出来
print("INF")
else:
for i in range(1,N+1):
for j in range(A[i],10001):
dp[j]=max(dp[j],dp[j-A[i]])#标记容量为A[i]倍数的背包
print(10001-sum(dp))
day2
状压DP
例题1
思考:
枚举出所有的路径(18!),再比较,这样时间复杂度显然太高
换一种思路:
假设有五个星球,我们已经攻击了第1,2,3个星球,那么状态可以表示为{11100};如果说下一次要攻击4,那么状态就会变为{11110};如果不攻击4而是攻击5,那么状态就变为{11101};
注意:二进制右边为低位,我们将第一个元素放在右边。
再添加一个描述i,dp[s][i]就表示攻击了s的状态下我们最后攻击的星球是i的最小值
import math
n = int(input())
N=20
M = 1<<20
inf = 0x3f3f3f3f3f3f3f3f
#dp[s][i]表示s状态走到星球i时最小能量
#状态(即那些星球走过了,有2**n个状态)
dp = [[inf] * N for i in range(M)]
dis = [[0] * N for i in range(N)]
w = [0] * N
nodes = [0] * N
class Node:
def __init__(self,x,y,z):
self.x=x
self.y=y
self.z=z
def dist(a,b):
return math.sqrt((a.x-b.x)**2+(a.y-b.y)**2+(a.z-b.z)**2)
for i in range(n):
x,y,z,w[i] = map(int,input().split())
nodes[i] = Node(x,y,z)
for i in range(n):
for j in range(n):
dis[i][j] = dist(nodes[i],nodes[j])
#初始化状态
for i in range(n):
dp[1 <<i][i] = 0 #状态只有起点,刚刚走,显然所需能量为0
for i in range(1<<n):#遍历2**n种状态
for j in range(n):#遍历n个点
if (i & (1<<j)):#状态包含点j
for k in range(n):
if not (i & (1<<k)) :#状态不包含点k
dp[i | (1<<k)][k] = min(dp[i | (1<<k) ][k],dp[i][j] + dis[j][k] * w[k])
res = 0x3f3f3f3f3f3f3f3f
pos = (1<<n)-1
for i in range(n):
res = min(res,dp[pos][i])
print("{:.2f}".format(res))
例题2
MOD = int(1e9) + 7
N = int(1e7) + 100
dp = [[0 for i in range(3)] for j in range(3)]
dp[2][0] = 1
n=int(input())
for i in range(1,n+1):
#dp[i][0] 处理完第 i 列,最后一列都填满。
#dp[i][1] 处理完第 i 列,上面有空隙。
#dp[i][2] 处理完第 i 列,下面有空隙。
#将第i列填满
#第i-1列填满了,第i列填I型,否则第i列填L型补全第i-1列
dp[0] = dp[1]
dp[1] = dp[2]
dp[2] = [0 for j in range(3)]
dp[2][0] = (dp[2-1][0] + dp[2-1][1] + dp[2-1][2])%MOD
#还有一种情况从第i-2列传来,横着填两个I型
dp[2][0] = (dp[2][0] + dp[2-2][0])%MOD
#不要将第i列填满
#第i-1列填满了,第i列填L型
dp[2][1] = dp[2-2][0]
dp[2][2] = dp[2-2][0]
#第i-1列没填满,横着放一个I型补全第i-1列
dp[2][1] = (dp[2][1] + dp[2-1][2])%MOD#补全第i-1列下面
dp[2][2] = (dp[2][2] + dp[2-1][1])%MOD#补全第i-1列上面
print(dp[2][0])
例题3
思考:
建立状态为dp[i][s],表示处理到第i行,并且第i行的宠物放置状态为s时的最多的放置数量
N,M=map(int,input().split())
mp=[0]*(N+1)#存储每一行的食物状态
#dp[i][j]表示处理到第i行,第i行的状态为j时的最多宠物数
dp=[[0]*(1<<M) for i in range(1+N)]
ans=0
def check(s):
return (s&(s>>1))==0
def check_p(a,b):
return (a&b)==0
def sum_(s):
cnt=0
while s>0:
cnt+=s&1
s>>=1
return cnt
for i in range(1,N+1):
input_=list(map(int,input().split()))
for j in range(M):
mp[i] |= input_[j]<<j#换位
for k in range(1<<M):#枚举当前行的状态
if check(k) and check_p(k,mp[i]):#判断状态是否合法
total=sum_(k)
for o in range(1<<M):#枚举上一行的状态
if check_p(k,o):#判断是否与上一行冲突
dp[i][k]=max(dp[i][k],dp[i-1][o]+total)
ans=max(ans,dp[i][k])
print(ans)
例题4
m=int(input())
P=list(map(int,input().split()))
MOD=998244353
ans=0
num=[0]*16
#dp[i][j][k]表示处理到第i列时,合法的分布数量,这一维度可以被优化
#j表示第一行的地雷分布,k表示第三行的地雷分布
#第i+1列的第一行如果有雷 则j 表示为 100
#第i 列的第一行如果有雷 则j 表示为 010
#第i-1列的第一行如果有雷 则j 表示为 001
dp=[[0]*8 for i in range(8)]#地雷分布状态有2**3种
#初始状态可能有4种情况,{000,000}、{100,000}、{000,100}、{100,100}
#虚构前面两行(i-1、i)第1行(i+1)状态有四种
dp[0][0]=dp[4][0]=dp[0][4]=dp[4][4]=1
for i in range(8):#计算每一种状态的地雷数
for j in range(3):
if (1<<j) &i:
num[i]+=1
for i in range(1,m+1):
p=P[i-1]
dpp=[[0]*8 for i in range(8)]#滚动数组,初始化要更新的下一列
for a in range(8):
for b in range(8):#枚举第i列的状态
if dp[a][b]==0:
continue
for s in range(2):
for t in range(2):#枚举第i+1列的四种状态
a_s=(a>>1) | (s<<2)
b_t=(b>>1) | (t<<2)
if num[a_s]+num[b_t]==p:#找到满足条件的状态转移
dpp[a_s][b_t]+=dp[a][b]
dpp[a_s][b_t]%=MOD
dp,dpp=dpp,dp
#对于最后一列的第一行的状态有四种011、010、001、000,因为第i+1列已经没有格子了,肯定为0
#16种组合的和就是答案
for i in range(4):
for j in range(4):
ans+=dp[i][j]
ans%=MOD
print(ans)
day3
树形DP
例题1
题目的意思其实就是,给定一棵树,选出一个非空集合,使得对于任意的两个元素a,b都存在一个序列a,v1,v2...,vk,b是这个集合种的元素,并且相邻两点间有一条边,也就时求一个连通块,让这个连通块的和值最大
sys.setrecursionlimit(100000)
n=int(input())
score=[0]+list(map(int,input().split()))
tree=[[]for i in range(n+1)]
ans=0
#dp[i]表示对于结点为i的子树,以i为根的连通块和的最大值
dp=[0 for i in range(n+1)]
for i in range(n-1):
a,b=map(int,input().split())
tree[a].append(b)
tree[b].append(a)
def dfs(u,f):
global ans
dp[u]=score[u]
for i in tree[u]:
if i!=f:#向下搜索子结点
dp[i]=dfs(i,u)
if dp[i]>0:
dp[u]+=dp[i]
ans=max(ans,dp[u])
return dp[u]
dfs(1,0)
print(ans)
树上背包问题
例题1
题目中的物品有依赖关系,并且依赖关系构成了一棵树
在普通背包中,dp[j]表示使用了空间j的情况下的最大价值。但是在树形背包问题中,第一维度时节点的编号,因此dp[i][j]就可以表示对于子树i来说,使用了j空间且满足依赖关系的最大价值,因此只需保证每一个dp[i][j]都选了i节点空间即可
n,V=map(int,input().split())
G=[[] for i in range(n+1)]
v=[0] * (n+1)
w=[0] * (n+1)
#dp[i][j]表示对于子树i来说,使用了j空间且满足依赖关系的最大价值
dp=[[0] * (V+1) for i in range(n+1)]
for i in range(1,n+1):
v[i],w[i],s=map(int,input().split())
G[s].append(i)
def dfs(u,dp,G,v,w,V):
for i in range(v[u],V+1):
dp[u][i]=w[u]#初始化仅仅装根节点时的背包价值
for child in G[u]:#枚举物品组
dfs(child,dp,G,v,w,V)
for j in range(V,v[u]+v[child]-1,-1):#枚举体积,分给子节点的体积,首先已经装了父节点,其次可以装下子节点
for k in range(v[child],j-v[u]+1):#枚举决策,要往背包里放多大体积的东西
dp[u][j]=max(dp[u][j-k]+dp[child][k],dp[u][j])
dfs(0,dp,G,v,w,V)
print(dp[0][V])
换根DP问题
对于一棵树,它的根不一定是1号节点,可能是任意某个点,在某些问题中,需要尝试计算每种情况,最后维护出最大值。如果每次选择一个点进行处理,那么时间复杂度达到了O(n**2),我们可以利用性质将换根的复杂度降低为O(1),步骤如下:
- 以1为根进行一边扫描,并且处理出必要的信息,例如深度,DP值等
- 开始以1为根进行换根,并且向下递归,在递归之前,需要将自己变成子节点的身份
- 进入新的根后,按照根的身份,重新进行转移,并且维护答案
例题1
from collections import defaultdict
import sys
sys.setrecursionlimit(100000)
N=100010
G=defaultdict(list)
depth=[0]*N#存储原始状态的结点深度,就是移动结点所需的代价
Maxdepth=[0]*N#存储每个子树的的最大深度
ans=0
def dfs(u,f,dt):#dt当前结点的深度
global depth,Maxdepth
depth[u]=dt
Maxdepth[u]=0
for v in G[u]:
if v==f:
continue
dfs(v,u,dt+1)#儿子结点深度为父节点深度+1
#原始状态所有子树的最大深度
Maxdepth[u]=max(Maxdepth[v]+1,Maxdepth[u])
def dfs2(u,f):
global ans,depth,Maxdepth
tmpf=0
Mx1=0#以u为根的最大深度
Mx2=0#以u为根的次大深度
for v in G[u]:
#找到当前root与p的距离,此距离最大
tmpf=max(tmpf,Maxdepth[v]+1)
#对比当前状态的最大盈利,和答案,来维护答案,
ans=max(ans,tmpf*k-depth[u]*c)
#换根步骤
pre=Maxdepth[u]#存储原来根的最大深度
for v in G[u]:#遍历子节点,找到以u为根的最大和次大深度
if Maxdepth[v]+1>Mx1:
Mx2=Mx1
Mx1=Maxdepth[v]+1
elif Maxdepth[v]+1>Mx2:
Mx2=Maxdepth[v]+1
for v in G[u]:
if v==f:
continue
if Maxdepth[v]+1==Mx1:
Maxdepth[u]=Mx2#改变原来的值,当此孩子就在最大深度路径上,需要找次大深度
else:
Maxdepth[u]=Mx1#当此孩子不在最大深度路径上,那就变为最大的深度,相当于没变
dfs2(v,u)
Maxdepth[u]=pre#还原
t=int(input())
for _ in range(t):
n,k,c=map(int,input().split())
G.clear()
ans=0
for i in range(1,n):
u,v=map(int,input().split())
G[u].append(v)
G[v].append(u)
dfs(1,0,0)
dfs2(1,0)
print(ans)
子树节点类书上背包
例题1
from collections import defaultdict
import sys
inf=0x3f3f3f3f3f3f3f3f
sys.setrecursionlimit(100000)
G=defaultdict(list)#存储树结构
n,W=map(int,input().split())
a=[0]+list(map(int,input().split()))
w=[inf]*(n+1)
size=[0]*(n+1)#存储每个子树的结点个数
temp=[0]*(n*2)#
#dp[i][j]表示对于子树i来说,使用了j力气的最多气球数,这样复杂度太高
#dp[i][j]表示对于子树i来说,得到j个气球的最低力气(子树结点类树上背包)
dp=[[inf]*(n+1) for i in range(n+1)]
for i in range(1,n):
ui,w[i+1]=map(int,input().split())
G[ui].append(i+1)
def dfs_num(u,f):#得到每个子树的结点个数
size[u]=1
for child in G[u]:
dfs_num(child,u)
size[u]+=size[child]
def dfs(u,dp,G,a,w,W):
dp[u][0]=0#当前子树使用0力气,所得气球必定为0
for child in G[u]:
dfs(child,dp,G,a,w,W)
for i in range(1,size[u]+size[child]+1):
temp[i]=inf
for j in range(size[u]+1):#枚举父亲可以得到的气球,父亲选的气球可以来自任意一个孩子
for k in range(size[child]+1):#枚举儿子可以得到的气球
temp[j+k]=min(temp[j+k],dp[u][j]+dp[child][k])#组合可以得到的总数
for j in range(1,size[u]+1):
dp[u][j]=temp[j]
dp[u][size[u]-1]=min(dp[u][size[u]-1],a[u])#戳破气球
dp[u][size[u]]=min(dp[u][size[u]],w[u])#剪短绳子
dfs_num(1,0)
dfs(1,dp,G,a,w,W)
for i in range(n,-1,-1):
if dp[1][i]<=W:
print(i)
break
day4
数位DP
例题1
from functools import lru_cache
n=int(input())
s = str(n)
'''定义 f(i,mask,is_limit,is_num)表示构造第 i 位及其之后数位的
合法方案数
mask 表示前面选过的数字集合,换句话说,第 i 位要选的数字
不能在 mask中
is_limit 表示当前位是否受到了 n 的约束。若为真,则第i 位
填入的数字至多为 s[i],否则可以是 9。如果在受到约束的
情况下填了 s[i],那么后续填入的数字仍会受到 n 的约束。
is_num 表示 i 前面的数位是否填了数字。若为假,则当前位
可以跳过(不填数字),或者要填入的数字至少为 1;
若为真,则要填入的数字可以从 0 开始。
'''
@lru_cache(None) # 记忆化搜索
def f(i,mask,is_limit,is_num):
#当 i 等于 s 长度时,如果 is_num为真,则表示得到了一个
#合法数字,返回 1,否则返回 0
if i == len(s):
return int(is_num) # is_num 为 True 表示得到了一个合法数字
res = 0
#如果 is_num为假,说明前面没有填数字,
#那么当前也可以不填数字。一旦从这里递归下去,
#is_limit 就可以置为 false 了,因为 最高位不填数字,
#后面无论怎么填都比 n 小。
if not is_num:
res += f(i + 1, mask, False, False)
# 如果前面没有填数字,必须从 1 开始(因为不能有前导零)
low = 0 if is_num else 1
# 如果前面填的数字都和 n 的一样,
#那么这一位至多填 s[i](否则就超过 n 啦)
up = int(s[i]) if is_limit else 9
for d in range(low, up + 1): # 枚举要填入的数字 d
if (mask >> d & 1) == 0: # d 不在 mask 中
#如果 is_num为真,那么当前必须填一个数字。
#枚举填入的数字,根据 is_num 和 is_limit
#来决定填入数字的范围。
res += f(i + 1, mask | (1 << d), is_limit and d == up, True)
return res
#从s[0]开始枚举;开始时mask没有数字;一开始就要受到约束
#因为第一位不能随意填;一开始没填数字
print(f(0, 0, True, False))
例题2
from functools import lru_cache
l=int(input())-1
r=int(input())
s1=str(l)
s2=str(r)
MOD=998244353
def cal(s):
cnt=0
for last in range(1,10):#先指定最后一位
@lru_cache(None)
def f(i,mask,is_limit,is_num):#mask 标记前i为数字的和
if i == len(s)-1:
if not is_num or (is_limit and last>int(s[-1])):#不是数字,数字大于给定数字返回0
return 0
return mask%last==0
res=0
if not is_num:
res+=f(i+1,mask,False,False)
up=int(s[i]) if is_limit else 9
down=0 if is_num else 1
for d in range(down,up+1):
res+=f(i+1,(mask+d)%last,is_limit and d==up,True)
return res
cnt+=f(0,0,1,0)#将指定的每一位最后一位的情况求和
cnt%=MOD
del f#每次枚举最后一位并计算后删除f
return cnt
ans=(cal(s2)-cal(s1))%MOD
print(ans)
day5
对拍
- 对拍是一种验证代码正确性的手段,分为以下几个阶段:
- 生成数据
- 运行暴力代码,得到结果
- 运行正确代码,得到结果
- 对比两个结果
例题
定义如下文件:
chk.py(验证代码)
import os
import sys
while True:
# 运行data.exe, force.exe 和 std.exe
os.system("python data.py > input.txt")
os.system("python force.py < input.txt > fout.txt")
os.system("python std.py < input.txt > sout.txt")
# 对比fout.txt 和 sout.txt
result = os.system("fc fout.txt sout.txt")
if result == 0:
print("AC")
else:
print("WA")
break
data.py(数据生成代码)
import random
# 生成两个随机数
a = random.randint(10e4, 10e5)
b = random.randint(10e4, 10e5)
# 确保a <= b
if a > b:
a, b = b, a # 交换a和b的值
# 输出结果
print(a, b)
force.py(暴力代码)
l,r=map(int,input().split())
ans=0
for i in range(l,r+1):
if '2022' in str(i):
ans+=i
# 计算结果
print(ans)
std.py(需要验证的代码)
from functools import lru_cache
import sys
sys.setrecursionlimit(100000)
l, r = map(int, input().split())
s1 = str(l - 1)
s2 = str(r)
def cal(s):
@lru_cache(None)
def f(i, mask, is_limit, is_num):
if i == len(s):
if is_num:
if '2022' in mask:
return int(mask)
else:
return 0
else:
return 0
res = 0
if not is_num:
res += f(i + 1, mask, False, False)
up = int(s[i]) if is_limit else 9
down = 0 if is_num else 1
for d in range(down, up + 1):
pre = mask
mask = mask + str(d)
res += f(i + 1, mask, is_limit and d == up, True)
mask = pre
return res
ans = f(0, '0', 1, 0)
del f
return ans
total = cal(s2) - cal(s1)
print(total)
常见的数据生成
随机数
import random
# random.randint(a, b) 生成 [a, b] 之间的整数
生成一棵树
import random
def rd(l,r):
return random.randint(l,r)
def tree(n):#生成一颗n个节点的数
for i in range(2,n+1):
print(i,end=" ")
print(rd(1,i-1))
tree(4)
生成一个图
import random
def rd(l,r):
return random.randint(l,r)
def tree(n):#生成一颗n个节点的树
for i in range(2,n+1):
print(i,end=" ")
print(rd(1,i-1))
def tu(n,m):#生成一个n点m边的图
tree(n)
for i in range(n,m+1):
print(rd(1,n),end=" ")
print(rd(1,n))
tu(4,5)
生成一个排列
import random
def pl(n):
a=[i for i in range(1,n+1)]
random.shuffle(a)
for i in range(n):
print(a[i],end=" ")
pl(5)
生成n个不同的数:可以先定义一个集合存储已经生成的数,每生成一个数,检查是否在集合中,如果存在就重新生成,不存在就加入到集合中去
让某个数以概率生成:rd(1,100)小于概率百分比时生成
day6
数论应用
数论定理
欧拉函数
欧拉函数:φ(n)为在1~n中与n互质的数的个数,特别的φ(1)=1
通用定义:对于一个正整数n,可以进行质因数分解,从而达到:
p1,p2,p3...pr都是质数
欧拉函数的计算公式如下:
prime=[]
is_prime=[1]*2050
is_prime[0]=0
is_prime[1]=0
def get_prime(n):
global prime,is_prime
for i in range(2,n+1):
if is_prime[i]:
prime.append(i)
for j in range(i*i,n+1,i):
is_prime[j]=0
get_prime(n)
def phi(x):
res=1
for i in prime:
if i*i>x:
break
if x%i==0:
temp=i-1
x//=i
while x%i==0:
temp*=i
x//=i
res*=temp
if x>1:
res*=(x-1)#最后剩余一个大素数
return res
欧拉定理
费马小定理:如果p是质数,那么对于任意的整数a,满足
欧拉定理:如果正整数a,n互质,那么满足
欧拉定理推论:如果a,n不互质,并且b>φ(n)满足
例题1
import math
sys.setrecursionlimit(20000)
def gcd(a,b):
if b==0:
return a
else:
return gcd(b,a%b)
def get_phi(x):
cnt=0
for i in range(1,x+1):
if gcd(i,x)==1:
cnt+=1
return cnt
def g(x,f):
if x==2023:
if x<=f:
return x
else:
return x%f+f
p=g(x+1,get_phi(f))#求指数
res=1
for i in range(1,p+1):
res=(res*x)%f
if p*math.log(x)>math.log(f):
res+=f
return res
print(g(2,2023)%2023)
裴蜀定理
设a,b为不全为0的整数,那么对于a,b的线性组合xa+yb,xa+yb一定是gcd(a,b)的倍数。
扩展:
{a1,a2,a3...,an},最大约数d=gcd(a1,a2,a3...,an),对于n个数的线性组合,一定都是d的倍数
如果正整数a,b互质,那么对于任意的x,y,满足不能被xa+yb表示的最大整数为a*b-a-b
逆元
如果两个整数满足(a*b)%n=1,那么我们称之为a,b在模n意义下互为逆元
例题2
首先10**9+7是一个质数,根据欧拉定理,即,对于一个质数n,φ(n)=n-1,那么:
那么:
t=int(input())
def fast(n,m,mod):
res=1
n%=mod
while m>0:
if m&1:
res=res*n%mod
n=n*n%mod
m>>=1
return res
for _ in range(t):
n=int(input())
print(fast(n,int(1e9)+5,int(1e9)+7))
组合数学
在很多问题中需要求组合数对10**+7取模的解过,当n和m太大时,利用杨辉三角形的求法时间和空间复杂度都不够,有一种更好的方法。
由于具有模的性质,可以利用公式计算,出发可以用乘以逆元代替
MOD = 10**9 + 7
lim = 2*(10**6) + 10
fc = [0] * lim #表示n!对模数取模的结果
rfc = [0] * lim #表示阶乘的逆元
def qp(x, p=MOD - 2): #快速幂求逆元
res = 1
while p:
if p & 1:
res = res * x % MOD
p >>= 1
x = x * x % MOD
return res
def init():
fc[0] = 1
for i in range(1, lim):
fc[i] = fc[i - 1] * i % MOD
rfc[lim - 1] = qp(fc[lim - 1])
for i in range(lim - 1, 0, -1):
rfc[i - 1] = rfc[i] * i % MOD
def C(n, m):
if n < 0 or m > n or m < 0:
return 0
return fc[n] * rfc[m] % MOD * rfc[n - m] % MOD
def A(n, m):
if m > n or n < 0:
return 0
return fc[n] * rfc[n - m] % MOD
init()
print(A(10,2))
print(C(10,2))
day7
字符串哈希法
对于字符串S=s0s1s2...sn
选择一个小整数作为种子,一般是素数,seed=233
选择一个大整数作为模数,一般是素数,MOD=998244353
对于S的哈希值计算:
维护一个哈希前缀和:
如果要获取S’=s4s5s6的哈希值:
可以得到算式:
例题1
mod=998244353
seed=233
N=int(1e6)+5
p=[0]*N
h=[0]*N
th=[0]*N
def hash(s):
p[0]=1
for i in range(1,len(s)+1):
h[i]=((h[i-1]*seed)%mod+ord(s[i-1]))%mod
p[i]=(p[i-1]*seed)%mod
def get_hash(l,r):
return (h[r]-(h[l-1]*p[r-l+1])%mod+mod)%mod
s=input()
t=input()
hash(s)
for i in range(1,len(t)+1):
th[i]=((th[i-1]*seed)%mod+ord(t[i-1]))%mod
ans=0
for i in range(1,len(s)+1):#枚举左端点
l,r=1,min(len(t),len(s)-i+1)
while l<r:#满足条件的最大值,二分枚举长度
mid=(l+r+1)//2
if get_hash(i,i+mid-1)==th[mid]:
ans=max(ans,mid)
l=mid
else:
r=mid-1
print(ans)
双哈希
双哈希主要是解决冲突问题,如果选择的种子和模数不好,会导致不同的字符串存在相同的哈希值。可以选择两个种子seed1,seed2和两个模数MOD1,MOD2,利用这两个分别计算h1,h2,如果对比两个字符串,就需要比对两个哈希值分别相等。
day8
字符串KMP算法
给定字符串S,T,请在主串S中寻找子串T,S串称之为匹配串,T串称之为模式串
前缀函数:用next数组表示,其中next[i]表示在模式串T中,以i结尾的真子串,与T的前缀匹配的最长长度。
如果在匹配过程中失败,就将T串跳跃到上一个相同的前缀位置,因为本身T维护的next数组是和后缀相同的前缀,既然后缀匹配上了,那么前缀一定可以。
如何得到next数组?利用next数组的性质,在匹配中进行优化,自己匹配自己
例题1
题目要求大小写匹配,那么可以将某个字符串大小写替换,这样就变为完全匹配;对于循环移动的问题,一个常见的转换做法就是将某个字符串扩展为两倍,破环成链。对于S(abcdef)扩展为S‘(abcdefabcdef),对于串T(bcdefa)对应S’第2位开始长度为6的子串,实际上就是S逆时针旋转1位,同时也是顺时针旋转5位的情况。
def get_next(t,tn):#获取next数组
#next[i]表示以i结尾的真子串匹配的最长前缀位置
next=[0]*(tn+1)
j=0
for i in range(2,tn+1):#模式串自己与自己匹配
while j>0 and t[i]!=t[j+1]:
j=next[j]
if t[i]==t[j+1]:#前缀与后缀相同,位置+1
j+=1
next[i]=j#更新最长前缀的位置
return next
def match(s,t,next,sn,tn):
j=0
ans=0x3f3f3f3f3f3f3f3f
for i in range(1,sn+1):#枚举匹配串的每一位
while j>0 and s[i]!=t[j+1]:#不匹配时找到最长前缀的位置
j=next[j]
if s[i]==t[j+1]:
j+=1
if j==tn:
j=next[j]#寻找下一次匹配
start=i-tn+1#契合的左端点
ans=min(ans,start-1)#逆时针方向旋转位数
ans=min(ans,tn-start+1)#顺时针旋转位数
return ans
n=int(input())
s=input()
s=s+s
s=' '+s
t=input()
t=' '+t.swapcase()
next=get_next(t,n)
ans=match(s,t,next,2*n,n)
if ans<=n:
print("Yes")
print(ans)
else:
print("No")
例题2
本题用到的就是next性质,n-next就是T串的最小循环节
def get_next(t,tn):
next=[0]*(tn+1)
j=0
for i in range(2,tn+1):
while j>0 and t[i]!=t[j+1]:
j=next[j]
if t[i]==t[j+1]:
j+=1
next[i]=j
return next
n=int(input())
s=input()
s=' '+s
next=get_next(s,n)
print(n-next[n])
day9
分块思想
分块常用来解决一些区间问题,例如区间加法,区间求和,区间最小值等
例题1
采用分块的思想,将n个元素分为B组,每组元素n/B个,最后一组可能空缺一些元素,但是无关紧要
分完组后,我们维护每个组的一个和值,sum[i]表示第i组的和值,对于每个元素如果它被修改了,那么我们找到对应的组x,对sum[x]进行相应的操作即可
至于查询,如果询问的左右端点在同一个组,直接在区间内暴力循环;如果不在同一组,对应于区间内完整的组,直接查询sum,对于不完整的组暴力循环求解
#bl代表每一个块的左端点
#br代表每一个块的右端点
#Gnum代表组的数量
#Enum代表组内的元素数量
#gid每个元素的组编号
import math
N=int(1e5+100)
bl=[0]*N
br=[0]*N
gid=[0]*N
sum=[0]*N
def block(n):
global Gnum,Enum
Gnum=0
Enum=int(math.sqrt(n))+1#开方求组内元素的个数,+1防0
i=1
while i<=n:
Gnum+=1
bl[Gnum]=i
br[Gnum]=min(n,i+Enum-1)
for j in range(bl[Gnum],br[Gnum]+1):
gid[j]=Gnum
i+=Enum
def get_sum(l,r):
res=0
for i in range(l,r+1):
res+=a[i]
return res
n=int(input())
block(n)
a=[0]+list(map(int,input().split()))
for i in range(1,n+1):#遍历每一个元素,找到对应的组号,初始化组的和值
sum[gid[i]]+=a[i]
m=int(input())
for _ in range(m):
op,x,y=map(int,input().split())
if op==1:
a[x]+=y#修改元素
sum[gid[x]]+=y#修改元素所在组的和值
elif op==2:
if gid[x]==gid[y]:
print(get_sum(x,y))
else:
res=0
for i in range(gid[x]+1,gid[y]):#遍历完整的组sum
res+=sum[i]
res+=get_sum(x,br[gid[x]])#加上左边不完整的组的元素和
res+=get_sum(bl[gid[y]],y)#加上右边不完整的组的元素和
print(res)
例题2
注意到区间范围是10**12,可以将区间分为10**6块,每一块都是从[a*10**6,(a+1)*10**6),每一块的最后六位都是从[000000,999999],这一部分信息可以暴力求出;我们用cnt[i]表示[000000,999999]内封闭图形为i的数量,f(a)表示a的封闭图形的数量,那么对于[a*10**6,(a+1)*10**6)这个区间,我们要求区间内封闭图形数量为k时,满足条件的数量就是cnt[k-f(a)]
import math
onum=[1,0,0,0,1,0,1,0,2,1]
def get_o(x):#计算x包含的封闭图形个数
res=0
res+=onum[x%10]
x//=10
while x>0:
res+=onum[x%10]
x//=10
return res
def get_sum(l,r,k):#暴力枚举
res=0
for i in range(l,r+1):
if get_o(i)==k:
res+=1
return res
cnt_s=[0]*40
for i in range(1,1000001):#预处理区间[000000,999999]
x=i
cnt=0
for j in range(6):
cnt+=onum[x%10]
x//=10
cnt_s[cnt]+=1
l,r,k=map(int,input().split())
st=1000000
if r-l<st:
print(get_sum(l,r,k))
else:
res=0
res+=get_sum(l,(l//st+1)*st-1,k)#左端不完整块
res+=get_sum((r//st)*st,r,k)#右端不完整块
a=l//st+1
b=r//st
for i in range(a,b):
x=get_o(i)
if x<=k:
res+=cnt_s[k-x]
print(res)
例题3
import math
N=int(1e5+100)
bl=[0]*N
br=[0]*N
gid=[0]*N
sum=[0]*N
tag=[0]*N#对于区间修改,需要维护一个tag数组,tag[i]表示第i组的每个元素都需要改变多少
def block(n):
global Gnum,Enum
Gnum=0
Enum=int(math.sqrt(n))+1
i=1
while i<=n:
Gnum+=1
bl[Gnum]=i
br[Gnum]=min(n,i+Enum-1)
for j in range(bl[Gnum],br[Gnum]+1):
gid[j]=Gnum
i+=Enum
def get_sum(l,r):
res=0
for i in range(l,r+1):
res+=a[i]+tag[gid[i]]#暴力求和时需要加上tag
return res
n,Q=map(int,input().split())
block(n)
a=[0]+list(map(int,input().split()))
for i in range(1,n+1):
sum[gid[i]]+=a[i]
for _ in range(Q):
input_=list(map(int,input().split()))
op,x,y=input_[0:3]
if op==1:
A=input_[3]
if gid[x]==gid[y]:#如果修改的区间在一个块中,暴力循环修改
for i in range(x,y+1):
a[i]+=A
sum[gid[i]]+=A
else:
for i in range(gid[x]+1,gid[y]):#修改完整的块
sum[i]+=(br[i]-bl[i]+1)*A
tag[i]+=A
for i in range(x,br[gid[x]]+1):#暴力循环修改左端不完整的块
sum[gid[i]]+=A
a[i]+=A
for i in range(bl[gid[y]],y+1):#暴力循环修改右端不完整的块
sum[gid[i]]+=A
a[i]+=A
elif op==2:
if gid[x]==gid[y]:
print(get_sum(x,y))
else:
res=0
for i in range(gid[x]+1,gid[y]):
res+=sum[i]
res+=get_sum(x,br[gid[x]])
res+=get_sum(bl[gid[y]],y)
print(res)
day10
倍增思想
ST表
st表用于解决可重复贡献问题的数据结构
可重复贡献问题:指对于运算opt,满足x opt x = x,则对应的区间询问就是一个可重复贡献问题。例如,最大值有max(x,x)=x,gcd有gcd(x,x)=x,所以区间最大值、最小值和gcd就是一个可重复贡献问题。像区间和就不具有这个性质,如果求区间和的时候采用预处理的区间重叠了,会导致被计算两次,这是我们所不愿意看到的。另外,opt还必须满足结合律才能使用st表求解。
所谓的预处理,即需要预先处理出对于每个位置i,以i作为左端点,i+2**j-1作为右端点的每个区间的最大值,由于log2(n)较小,可以直接枚举每个j。
f[i][j]表示以i为左端点,区间长度为2**j,也就是是右端点为i+2**j-1
查询的时候选择两个最大的但不超过询问区间的二进制区间即可,虽然有重复,但是不影响答案
例题1
def st(n,a):
for i in range(1,n+1):
f[i][0]=a[i]
for j in range(1,20):#枚举长度,大的指数由小的指数推导
for i in range(1,n-(1<<j)+1+1):#枚举左端点
f[i][j]=max(f[i][j-1],f[i+(1<<(j-1))][j-1])
def query(l,r):
k=int(math.log2(r-l+1))
return max(f[l][k],f[r-(1<<k)+1][k])
n,q=map(int,input().split())
a=[0]+list(map(int,input().split()))
f=[[0]*20 for i in range(int(5e5)+100)]
st(n,a)
for _ in range(q):
l,r=map(int,input().split())
print(query(l,r))
LAC问题
最近公共祖先简称LCA,相关性质:满足结合律,LCA(v1,v2,v3...,vn)=LCA(v1,LCA(v2,v3...,vn))=...
from collections import defaultdict
G=defaultdict(list)
def dfs(u,fa):#倍增思想,先处理出每个点向上跳2**j次的父亲是谁
for i in range(1,20):
#u向上跳2**j次的父亲就是,u向上跳2**(j-1)次的父亲再向上跳2**(j-1)次的父亲
f[u][i]=f[f[u][i-1]][i-1]
for v in G[u]:
if v!=fa:
depth[v]=depth[u]+1#递归前拿到深度
f[v][0]=u#递归前初始化向上跳一次的父亲
dfs(v,u)
def lca(a,b):
if depth[a]<depth[b]:#保证a的深度一直是大于b的深度
a,b=b,a
for i in range(19,-1,-1):#枚举向上跳的高度,使a,b等高
#只有a向上跳之后的深度仍然大于等于b的深度,才转移到a的父亲
if depth[a]-(1<<i)>=depth[b]:
a=f[a][i]
if a==b:#a,b在同一条路径上
return a
for i in range(19,-1,-1):#枚举向上跳的高度,寻找公共祖先
if f[a][i]!=f[b][i]:
a=f[a][i]
b=f[b][i]
return f[a][0]
N=int(1e5)+100
f=[[0]*20 for i in range(N)]
depth=[0]*N
n=int(input())
for i in range(n-1):
p,q=map(int,input().split())
G[p].append(q)
G[q].append(p)
dfs(1,0)
Q=int(input())
for i in range(Q):
A,B=map(int,input().split())
print(lca(A,B))
day11
树状数组
树状数组通常用来解决动态前缀和、逆序对等问题
求和原理:树状数组的原理是基于二进制,每个整数可以用二进制来表示,例如7=2**2+2**1+2**0,按照倍增的思量,将7按照2进制分块,那么计算区间和值的次数不超过log2(n)次
引入lowbit(x):
- lowbit(x)表示在二进制表示下,x的最低位的1所代表的数值
- 7=0b111 => lowbit(7)=0b1=1
- 根据二进制的补码与原码的关系可得出lowbit(x)=x&(-x)
可以发现lowbit(x)的值就是前缀区间[1,x]的最后一块的长度
修改原理:某个数如果对应于某个位置x,对它加上lowbit(x),那么就会迭代到更大的一块区间,并且这段区间能覆盖原来的区间。如果修改了某个节点就需要修改覆盖节点的全部区间。
例题1
def lowbit(x):#得到以x结尾的区间长度
return x&(-x)
def get(pos):#得到[1,pos]的和
res=0
while pos>0:
res+=f[pos]
pos-=lowbit(pos)#找到下一个区间的结尾元素
return res
def update(pos,v):#将包含位置pos的区间都加上v
while pos<=n:
f[pos]+=v
pos+=lowbit(pos)
n=int(input())
f=[0]*(int(1e5)+100)#f[i]表示用i结尾的lowbit(i)长度区间的和值
input_=[0]+list(map(int,input().split()))
for i in range(1,n+1):
update(i,input_[i])
m=int(input())
for _ in range(m):
op,a,b=map(int,input().split())
if op==1:
update(a,b)
else:
print(get(b)-get(a-1))
例题2
该题较上题从单点修改改为了区间修改,对于一个数组{a1,a2,a3,...,an},做出差分数组{b1,b2,b3,...,bn}满足b[i]=a[i]-a[i-1],规定a[0]=0,可以得到a[p]=b[1]+b[2]+...+b[p],修改区间[l,r]的值,增加x,只需要对b[l]加x,b[r+1]减x
如何求和?可以利用二阶前缀和:
def lowbit(x):
return x&(-x)
#fi表示用i结尾的lowbit(i)长度区间的和值
def get(pos,f):#得到[1,pos]的和
res=0
while pos>0:
res+=f[pos]
pos-=lowbit(pos)
return res
def update(pos,v,f):#将包含位置pos的区间都加上v
while pos<=n:
f[pos]+=v
pos+=lowbit(pos)
n,m=map(int,input().split())
fb=[0]*(int(1e5)+100)
fc=[0]*(int(1e5)+100)
input_=[0]+list(map(int,input().split()))
for i in range(1,n+1):
update(i,input_[i]-input_[i-1],fb)
update(i,i*(input_[i]-input_[i-1]),fc)
for _ in range(m):
app=list(map(int,input().split()))
op,a,b=app[0:3]
if op==1:
k=app[3]
update(a,k,fb)
update(b+1,-k,fb)
update(a,a*k,fc)
update(b+1,-(b+1)*k,fc)
else:
l,r=a,b
r=(r+1)*get(r,fb)-get(r,fc)
l=l*get(l-1,fb)-get(l-1,fc)
print(r-l)
逆序对
例题1
题目的意思:给定一个数组{2,3,4,1,2}
对于位置4,值为1,在该位置左边有三个位置比它大。要求出所有的逆序对,枚举每个位置,找到该位置左边有多少个值比当前值大,然后求和即可。
树状数组维护的是前缀和,可以将值转化为下标,也就是变为权值梳妆数组。按照下标从左往右枚举,每次枚举到一个位置x,先查询树状数组中有多少个值比它大,计入答案,然后将x的值放入树状数组即可。
def lowbit(x):
return x&(-x)
def get(pos):#得到[1,pos]有几个数比自己小
res=0
while pos>0:
res+=f[pos]
pos-=lowbit(pos)
return res
def update(pos,v):#将包含位置pos的区间都加上v
while pos<=n:
f[pos]+=v
pos+=lowbit(pos)
n=int(input())
f=[0]*(int(1e5)+100)
input_=[0]+list(map(int,input().split()))
ans=0
for i in range(1,n+1):
ans+=(i-1)-get(input_[i])#i-1表示假设位置左边的数都比它大,减去比自己小的数就是逆序对的个数
update(input_[i],1)
print(ans)
day12
线段树
线段树主要用于解决区间问题,对区间的操作,和区间的查询都可以通过线段树解决。
线段树广义上被归类为二叉搜索树,适用于满足区间加法的问题。
定义:
- 线段树每个节点都代表一个区间。
- 线段树具有唯一的根节点,代表统计的区间是[1,n]。
- 线段树的每个叶子节点代表的都是一个长度为1的区间。
- 如果当前节点统计的区间为[l,r],定义mid=(l+r)/2,那么左儿子区间就是[l,mid],右儿子区间就是[mid+1,r]。
区间结构如下图:
可以将其看作一棵树:
习惯上,根节点编号为x,左右儿子编号分别为2*x,2*x+1
建树后,树上有时会出现一些空节点,如下图:
满二叉树的节点数量为2n-1,但是由于最后可能存在一排空节点,所以节点数量一般定义为4n
例题1
建树:线段树的每个节点维护的是区间信息,一般来说问什么就是维护什么。上题询问的是区间和,那么每个节点都需要维护相对应的区间和值。
例如给定一个数组A={2,4,4,5,3}
其树的信息为:
dfs建树流程:
- 维护信息,id 节点编号,l 当前节点左边界,r 当前节点右边界
- 判断是否为叶子节点
- 左右儿子递归
- 合并左右儿子的信息
节点更新:
假设2号位置发生了更新,需要加3
需要从跟节点,信息不停更新。
dfs更新步骤:
- 判断是否叶子节点
- 判断修改的位置在左节点还是右节点,进入递归
- 合并左右节点的信息
对于需要查询的区间[3,5]:
- 如果目标区间与当前区间无关联,那返回一个无贡献值
- 如果目标区间全包含当前区间,返回当前节点的信息
- 如果只包含部分区间,则进行递归
#dfs建树
def create(id,l,r,a):
if l==r:#到达叶节点,赋值
val[id]=a[l]
return
mid=(l+r)//2
#递归合并
create(id*2,l,mid,a)
create(id*2+1,mid+1,r,a)
val[id]=val[id*2]+val[id*2+1]
#dfs修改
def update(pos,v,id,l,r):
if l==r:#到达叶节点,更新
val[id]+=v
return
mid=(l+r)//2
if pos<=mid:#判断修改位置,分别递归
update(pos,v,id*2,l,mid)
else:
update(pos,v,id*2+1,mid+1,r)
#合并
val[id]=val[id*2]+val[id*2+1]
#dfs查询区间[L,R]
def get(L,R,id,l,r):
if L>r or R<l:
return 0
if L<=l and R>=r:
return val[id]
mid=(l+r)//2
left=get(L,R,id*2,l,mid)
right=get(L,R,id*2+1,mid+1,r)
return left+right
def init(n,a):
create(1,1,n,a)
def add(pos,v):
update(pos,v,1,1,n)
def query(L,R):
return get(L,R,1,1,n)
N=int(1e5)+100
val=[0]*(N*4)
n=int(input())
a = [0] + list(map(int, input().split()))
init(n,a)
m=int(input())
for _ in range(m):
op,a,b=map(int,input().split())
if op==1:
add(a,b)
else:
print(query(a,b))
例题2
为了解决区间加法的问题,采用一种技巧:懒标记,也叫延迟标记
对于每个节点,维护的信息添加一个变量tag,表示对于当前区间,所有的元素均加上了tag。当修改的区间完全包含当前区间时,将标记增加,将维护的区间和值修改后就返回。例如[1,3]所有元素加3:
此时2号节点的所有子孙都没有更改,当遇到修改或者拆分区间的时候。例如查询[1,2],到达2号节点后仍需递归,但是它的子孙没有加上标记,所有需要下传:
每次下传,都需要将当前节点的标记清空
#dfs建树
def create(id,l,r,a):
tag[id]=0
if l==r:#到达叶节点,赋值
val[id]=a[l]
return
mid=(l+r)//2
#递归合并
create(id*2,l,mid,a)
create(id*2+1,mid+1,r,a)
val[id]=val[id*2]+val[id*2+1]
#dfs修改
def update(L,R,v,id,l,r):
if L>r or R<l:
return
if L<=l and R>=r:
tag[id]+=v#完全包含,tag表示区间内所有元素加v
val[id]+=v*(r-l+1)
return
mid=(l+r)//2
pushdown(id,(mid-l+1),(r-(mid+1)+1))#修改时标记下传
update(L,R,v,id*2,l,mid)
update(L,R,v,id*2+1,mid+1,r)
#合并
val[id]=val[id*2]+val[id*2+1]
#dfs查询区间[L,R]
def get(L,R,id,l,r):
if L>r or R<l:
return 0
if L<=l and R>=r:
return val[id]
mid=(l+r)//2
pushdown(id,(mid-l+1),(r-(mid+1)+1))#差分时标记下传
left=get(L,R,id*2,l,mid)
right=get(L,R,id*2+1,mid+1,r)
return left+right
#懒标记
#tag下传
def pushdown(id,l,r):
tag[2*id]+=tag[id]
tag[2*id+1]+=tag[id]
val[2*id]+=l*tag[id]
val[2*id+1]+=r*tag[id]
tag[id]=0
def init(n,a):
create(1,1,n,a)
def add(L,R,v):
update(L,R,v,1,1,n)
def query(L,R):
return get(L,R,1,1,n)
N=int(4e5)+100
val=[0]*(N*4)
tag=[0]*(N*4)
n,q=map(int,input().split())
a = [0] + list(map(int, input().split()))
init(n,a)
for _ in range(q):
app=list(map(int,input().split()))
op,x,y=app[0:3]
if op==1:
k=app[3]
add(x,y,k)
else:
print(query(x,y))
day13
种类并查集
例题1
def find(x):
if x==fa[x]:
return x
else:
fa[x]=find(fa[x])
return fa[x]
n,m=map(int,input().split())
fa=[i for i in range(n*2+1)]
mm=[]
for i in range(m):
a,b,c=map(int,input().split())
mm.append((a,b,c))
mm.sort(key=lambda x:x[2],reverse=True)
for i in range(m):
a=mm[i][0]
b=mm[i][1]
if find(a)==find(b):#a和b在1号房,冲突
ans=mm[i][2]
break
if find(a+n)==find(b+n):##a和b在2号房,冲突
ans=mm[i][2]
xa=find(a)#条件1:将a放进1号房
xb=find(b+n)#条件2:将b放进2号房
fa[xa]=xb#条件1可以推出条件2
#相同的
xa=find(a+n)#条件1:将a放进2号房
xb=find(b)#条件2:将b放进1号房
fa[xa]=xb#条件2可以推出条件1
print(ans)
day14
计数排序
def sort(a,n):
if n<2:
return a
maxa=0
for i in range(n):
maxa=max(maxa,a[i])
count=[0]*(maxa+1)
for i in range(n):
count[a[i]]+=1
for i in range(1,maxa+1):
count[i]+=count[i-1]
out=[0]*n
for i in range(0,n):
out[count[a[i]]-1]=a[i]
#out[(n-1)-(count[a[i]]-1)]=a[i]
count[a[i]]-=1
return out
a=[2,2,1,5,3]
a=sort(a,5)
print(a)