Description
一个长度为n的序列a[i]和m个集合Si,每个集合都用一些编号表示,这些编号对应的是序列中的元素,有两种操作:
1.查询第k个集合中元素之和
2.把第k个集合中对应a序列中的元素全部加上x
Input
第一行三个整数n,m,q分别表示序列长度,集合数和操作数,之后n个整数a[i]表示该序列,之后m行第i行首先输入一个整数num表示第i个集合的元素个数,然后输入num个编号对应a序列中的num个元素,最后q行每行表示一个操作,? k表示查询第k个集合中元素之和,+ k x表示把第k个集合中对应a序列中的元素全部加上x
(1<=n,m,q<=1e5,|a[i]|<=1e8,|x|<=1e8,保证m个集合中元素之和不超过1e5)
Output
对于每次查询操作输出一个答案
Sample Input
5 3 5
5 -5 5 1 -4
2 1 2
4 2 1 4 5
2 2 5
? 2
+ 3 4
? 1
+ 2 1
? 2
Sample Output
-3
4
9
Solution
把集合分类,称元素个数超过sqrt(n)的集合为重集合,不超过sqrt(n)的集合为轻集合,轻集合直接更新,重集合记录更新值,记录第i个集合与第j个重集合的交集元素个数cnt[i][j],第i个重集合中元素之和sum[i],第i个重集合中元素的更新值add[i],对轻重集合的两种操作:
第i个轻集合更新:直接更新该集合对应的a序列元素,然后更新重集合的sum值,第j个重集合的sum[j]+=x*cnt[i][j],时间复杂度O(sqrt(n))
第i个重集合更新:累加更新值,add[i]+=x,时间复杂度O(1)
第i个轻集合查询:直接累加该集合对应的a序列元素,然后把重集合延迟的更新累加,即累加上cnt[i][j]*add[j],时间复杂度O(sqrt(n))
第i个重集合查询:本身的值累加上重集合延迟的更新,即sum[i]+cnt[i][j]*add[j],时间复杂度O(sqrt(n))
这样q次操作时间复杂度为O(qsqrt(n))
Code
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef long long ll;
#define INF 0x3f3f3f3f
#define maxn 111111
int n,m,q;
ll a[maxn];
int h[maxn];//h[i]=1:第i个集合是重集合,h[i]=0:第i个集合是轻集合
int cnt[maxn][356];//cnt[i][j]是第i个集合和第j个重集合的交集元素个数
int res,id[maxn];//res表示重集合的数量,id[i]表示第i个重集合的编号
ll sum[maxn];//sum[i]表示第i个集合(重)中元素总和
ll add[maxn];//add[i]表示第i个集合(重)中元素更新值
vector<int>g[maxn],gg[maxn];//g[i]表示第i个集合,gg[i]表示第i个元素所在集合
int main()
{
while(~scanf("%d%d%d",&n,&m,&q))
{
for(int i=1;i<=n;i++)scanf("%I64d",&a[i]);
memset(cnt,0,sizeof(cnt));
memset(sum,0,sizeof(sum));
memset(add,0,sizeof(add));
for(int i=1;i<=m;i++)g[i].clear();
for(int i=1;i<=n;i++)gg[i].clear();
int nn=sqrt(n+0.5),res=0;
for(int i=1;i<=m;i++)
{
int num,temp;
scanf("%d",&num);
if(num>=nn)h[i]=1,id[++res]=i;
else h[i]=0;
for(int j=0;j<num;j++)
{
scanf("%d",&temp);
g[i].push_back(temp);
if(h[i])sum[i]+=a[temp],gg[temp].push_back(res);
}
}
for(int i=1;i<=m;i++)
for(int j=0;j<g[i].size();j++)
{
int u=g[i][j];
for(int k=0;k<gg[u].size();k++)
{
int v=gg[u][k];
cnt[i][v]++;
}
}
while(q--)
{
int k,x;
char op[3];
scanf("%s%d",op,&k);
if(op[0]=='?')
{
ll ans=0;
if(h[k])
{
ans=sum[k];
for(int i=1;i<=res;i++)ans+=add[id[i]]*cnt[k][i];
}
else
{
for(int i=0;i<g[k].size();i++)ans+=a[g[k][i]];
for(int i=1;i<=res;i++)ans+=add[id[i]]*cnt[k][i];
}
printf("%I64d\n",ans);
}
else
{
scanf("%d",&x);
if(h[k])add[k]+=x;
else
{
for(int i=0;i<g[k].size();i++)a[g[k][i]]+=x;
for(int i=1;i<=res;i++)
sum[id[i]]+=x*cnt[k][i];
}
}
}
}
return 0;
}