HDU 4913 Least common multiple
复杂线段树
传送门:HustOJ
题意
给你一些数,每个数都是 xi=2ai∗3bi 的形式。问你每个这些数的每个子集的最小公倍数之和。
思路
先膜%%%
线段树。
等价于求所有子集里面最大的a和b。
如果只有a,那么做法是按a排序,然后递推计算。因为算到a[i]时,它肯定是目前最大的。最大为a[i]的子集有
2i
个。递推计算即可。
现在有了a和b,依然按a排序,将b离散化后插入线段树。从前往后枚举,到i时,前面的数的a全都小于a[i],用线段树查询有多少小于b[i]的,记为x。那么答案应该加上 2x∗2ai∗3bi 。现在线段树中维护的第一个值是每个b(离散化后)目前出现的次数。所以更新时,当前b[i]的次数要+1。
对于当前b[i]后面的b,我们怎么统计答案呢?线段树每个叶子中再维护一个值,意义为:有多少个集合,他们的b的最大值是该叶子代表的b(实际维护集合个数* 3b )。pushup时对此值求和。那么将当前数加入考虑范围和内后,他不会影响包含比他b值大的那些集合的最小公倍数,而且这些集合最小公倍数的a值依然为当前a值。我们只需要求出:比当前b大的、已经出现过的b的:集合个数* 3b 。答案加上2^x*sum。
综上:每扫到一个i:查询出现过的比它小的个数;查询出现过的比他大的集合个数* 3b 。
结合查询说更新。每扫到一个i:首先当前的b出现过了。cnt++,并向上更新。比他大的b的sum值(表示:集合个数* 3b )要乘2。因为当前这个b选不选对于包含比它大的b的集合的最小公倍数是没有影响的。换句话说,假设有一个出现过的比他大的b,记为bx。以bx为最大b值的集合数目=原来以bx为最大b值的集合数目*(添加bi与不添加bi)。另外这个bi的sum值(目前一定是0)要变成:以他为最大b值的集合个数* 3b 。
综上:更新分为单点与区间。单点就是把这个b的cnt++,sum变成 2x∗3bi ;区间更新就是把这个b后面的sum值全乘以2。
再说一些细节。线段树再开一个成员,只对叶子有效。表示他所代表的b值的 3b ,这样在更新时比较方便。这个跟离散化的方式有关,弄懂了注意别写错了。我在半懂不懂时瞎写错了几次。
取模与溢出。注意先强制转型,将中间结果用longlong存在取模。比如a=b*c%mod。这样b*c也是可能溢出为负的。要用:a=(LL)b*c%d。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<queue>
#include<list>
#include<stack>
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
using namespace std;
typedef long long LL;
const int MAXN=100007;
const int mod=1000000007;
struct Num
{
int i;
int a; int b;
int c;
}num[MAXN];
int lsh[MAXN];
bool cmp1(Num a, Num b)
{
if(a.b==b.b) return a.i<b.i;
else return a.b<b.b;
}
bool cmp2(Num a, Num b)
{
return a.a<b.a;
}
LL mypow(int a, int b)
{
LL tmp=1;
LL aa=a;
while(b)
{
if(b&1)
tmp=(tmp*aa)%mod;
b>>=1;
aa=(aa*aa)%mod;
}
return tmp%mod;
}
struct STree
{
int sum;
int cnt;
int lazy;
int val;
}stree[MAXN*4];
void pushup(int rt)
{
stree[rt].sum=(stree[rt<<1].sum+stree[rt<<1|1].sum)%mod;
stree[rt].cnt=stree[rt<<1].cnt+stree[rt<<1|1].cnt;
}
void build(int l, int r, int rt)
{
stree[rt].sum=0;
stree[rt].lazy=0;
stree[rt].cnt=0;
if(l==r)
{
stree[rt].val=mypow(3, lsh[l]);
return;
}
int m=(l+r)>>1;
build(lson);
build(rson);
pushup(rt);
}
void pushdown(int rt)
{
if(stree[rt].lazy)
{
stree[rt<<1].sum=(stree[rt<<1].sum*(mypow(2, stree[rt].lazy)))%mod;
stree[rt<<1|1].sum=(stree[rt<<1|1].sum*(mypow(2, stree[rt].lazy)))%mod;
stree[rt<<1].lazy+=stree[rt].lazy;
stree[rt<<1|1].lazy+=stree[rt].lazy;
stree[rt].lazy=0;
}
}
void update1(int L, int R, int l, int r, int rt)
{
if(L<=l&&r<=R)
{
stree[rt].sum=(stree[rt].sum*2)%mod;
stree[rt].lazy++;
return;
}
pushdown(rt);
int m=(l+r)>>1;
if(L<=m) update1(L, R, lson);
if(R>m) update1(L, R, rson);
pushup(rt);
}
void update2(int pos, int v, int l, int r, int rt)
{
if(l==r)
{
v=((LL)v*stree[rt].val)%mod;
stree[rt].cnt++;
stree[rt].sum=(stree[rt].sum+v)%mod;
return;
}
pushdown(rt);
int m=(l+r)>>1;
if(pos<=m) update2(pos, v, lson);
else update2(pos, v, rson);
pushup(rt);
}
LL query1(int L, int R, int l, int r, int rt)
{
if(L<=l&&r<=R)
{
return stree[rt].sum%mod;
}
pushdown(rt);
int m=(l+r)>>1;
LL res=0;
if(L<=m) res=(res+query1(L, R, lson))%mod;
if(m<R) res=(res+query1(L, R, rson))%mod;
return res;
}
int query2(int L, int R, int l, int r, int rt)
{
if(L<=l&&r<=R)
return stree[rt].cnt;
pushdown(rt);
int m=(l+r)>>1;LL res=0;
if(L<=m) res=(res+query2(L, R, lson))%mod;
if(m<R) res=(res+query2(L, R, rson))%mod;
return res;
}
int main()
{
freopen("in.txt", "r", stdin);
freopen("ou1.txt", "w", stdout);
int n;
while(scanf("%d", &n)==1)
{
memset(lsh, 0, sizeof(lsh));
memset(num, 0, sizeof(num));
for(int i=1;i<=n;i++)
{
num[i].i=i;
scanf("%d%d", &num[i].a, &num[i].b);
}
sort(num+1, num+1+n, cmp1);
int maxb;
int now=1;num[1].c=now;lsh[now]=num[1].b;
for(int i=2;i<=n;i++)
{
if(num[i].b==num[i-1].b) num[i].c=now;
else
{
num[i].c=++now;
lsh[now]=num[i].b;
}
}
maxb=num[n].c;
build(1, maxb, 1);
sort(num+1, num+n+1, cmp2);
LL res=0;
for(int i=1;i<=n;i++)
{
LL tmp=0;LL cnt=0;
cnt=(mypow(2, num[i].a)%mod*mypow(3, num[i].b)%mod)%mod;
if(num[i].c>1)
cnt=(cnt*(mypow(2, tmp=query2(1, num[i].c-1, 1, maxb, 1))%mod))%mod;
cnt=(cnt+(LL)(mypow(2, num[i].a)%mod*query1(num[i].c, maxb, 1, maxb, 1))%mod)%mod;
update1(num[i].c, maxb, 1, maxb, 1);
update2(num[i].c, mypow(2, tmp), 1, maxb, 1);
res+=cnt;
res%=mod;
}
printf("%lld\n", res);
}
return 0;
}