题目
https://gmoj.net/senior/#main/show/6857
https://www.luogu.com.cn/problem/P7077
题解
这题的思路比较清奇,不是像许多数据结构题一样在原数列上建一个数据结构,而是在操作上建一个数据结构。
如果直接在原数列上建数据结构维护修改的标记,就会发现加法标记互不相容。比如当前节点上有标记 a 1 + x a_1+x a1+x,现在传来了一个标记 a 2 + y a_2+y a2+y,那么是无法让它们共存的。
但是看到乘法是全局操作的,可以全部累积到加法上面,且操作形成一个 D A G DAG DAG,这是否暗示我们要换一种思路?
现在不妨考虑在操作上建一个数据结构,并尝试把乘法全部累积到加法上面去(一个加法操作受它之后所有乘法操作的影响)。
为了方便处理,可以把最后那个操作序列也看成第3类函数。现在把
D
A
G
DAG
DAG建出来(以样例2为例,边
x
→
y
x\to y
x→y上的标号表示作为父亲的函数
x
x
x的所有儿子中
y
y
y的顺序):
在每个节点上,维护两个标记:一个表示它调用的乘法操作的乘积之和(相同的乘法操作可能被计算了多次),一个表示它被调用的次数(注意,这里的调用次数包括了它以后的乘法操作的乘积)。
把所有节点按照 出度 拓扑排序一下,那么第一个标记要按照拓扑序从小到大维护,第二个标记从大到小维护, t a g 2 g k , i tag2_{g_{k,i}} tag2gk,i每次都加上 t a g 2 k ⋅ ∏ j = i + 1 c k t a g 1 g k , i tag2_{k}\cdot\prod_{j=i+1}^{c_k} tag1_{g_{k,i}} tag2k⋅∏j=i+1cktag1gk,i就行了。
CODE
#include<vector>
#include<cstdio>
using namespace std;
#define P 998244353
#define M 1100005
#define N 100005
int a[N],fir[N],nex[M],to[M],_fir[N],_nex[M],_to[M],deg[N];
int data[N],type[N],p[N],val[N],times[N],tag[N],m,s;
inline void add(int &x,int y){x+=y;if(x>=P) x-=P;}
inline void inc(int x,int y)
{
++deg[x],to[++s]=y,nex[s]=fir[x],fir[x]=s;
_to[s]=x,_nex[s]=_fir[y],_fir[y]=s;
}
inline void bfs()
{
int head=0,tail=0,u;
for(int i=1;i<=m;++i) if(!deg[i]) data[++tail]=i;
while(head<tail)
{
u=data[++head];
for(int i=_fir[u];i;i=_nex[i])
if(!--deg[_to[i]]) data[++tail]=_to[i];
}
}
int main()
{
freopen("call.in","r",stdin);
freopen("call.out","w",stdout);
int n,q,x,y;
scanf("%d",&n);
for(int i=1;i<=n;++i) scanf("%d",a+i);
scanf("%d",&m);
for(int i=1;i<=m;++i)
{
scanf("%d",type+i);
if(type[i]==1) scanf("%d%d",p+i,val+i);
else if(type[i]==2) scanf("%d",times+i);
else
{
scanf("%d",&x);
for(int j=1;j<=x;++j) scanf("%d",&y),inc(i,y);
}
}
scanf("%d",&q),++m;
for(int i=1;i<=q;++i) scanf("%d",&x),inc(m,x);
bfs();
for(int i=1,u;i<=m;++i)
{
u=data[i];
if(type[u]!=2) times[u]=1;
for(int j=fir[u];j;j=nex[j])
times[u]=1LL*times[u]*times[to[j]]%P;
}
tag[m]=1;
for(int i=m,u;i;--i)
{
u=data[i];
for(int j=fir[u],prod=tag[u];j;j=nex[j])
{
add(tag[to[j]],prod);
prod=1LL*prod*times[to[j]]%P;
}
}
for(int i=1;i<=n;++i) a[i]=1LL*a[i]*times[m]%P;
for(int i=1;i<=m;++i) if(type[i]==1)
add(a[p[i]],1LL*tag[i]*val[i]%P);
for(int i=1;i<=n;++i) printf("%d ",a[i]);
puts("");
return 0;
}