给定一个集合,集合内的每个元素都形如2^{a_i}·3^{b_i},求该集合所有非空子集的最小公倍数之和(结果对10^9+7取模)。
集合大小不超过10^5,ai和bi不超过10^9。
首先将集合内元素按照b_i升序排序。这样,若固定子集中(排序后)下标最大的被选元素i,该子集最小公倍数中3的指数就恰好是b_i,于是DP时只需考虑a_i,不用再考虑b_i。
定义状态dp[i][j]为前i个元素的所有子集(包括空集)中,最小公倍数的2的幂部分(不考虑3)恰为2^j的子集个数。
则如果j<a[i],dp[i][j]=dp[i-1][j]。如果j>a[i],dp[i][j]=2*dp[i-1][j]。如果j==a[i],dp[i][j]=dp[i-1][j]+dp[i-1][0]+dp[i-1][1]+...+dp[i-1][j]。
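每处理一个元素i,就求出所有以i为(排序后)最后一个被选元素的子集的最小公倍数之和,累加到答案中即可。该和等于3^{b_i}·(2^{a_i}·Σ_{j≤a_i}dp[i-1][j] + Σ_{j>a_i}2^j·dp[i-1][j])。
又因为dp[i]与dp[i-1]之间的关系是:前一部分(j<a_i)不变,中间有一个值(j=a_i)增加,后一部分(j>a_i)整体乘2;并且我们每次只关心某一段的方案数之和,以及这一段的方案数与对应2^j乘积之和。所以可以用一棵支持区间乘法的线段树来代替整个dp数组。线段树的下标范围是0~10^9,即dp[i][j]中的j。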
又因为集合最多只有10^5个元素,线段树上实际只有不超过10^5个位置会被用到,所以可以离散化(另外还要把指数0加入离散化,用来存放空集对应的初值dp[0][0]=1)。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
using namespace std;
const int mod=1000000007;
// Modular exponentiation: returns x^y mod `mod` using square-and-multiply.
// x may be any int (it is normalised into [0, mod) first).
// Precondition: y >= 0; a negative y is unsupported and simply yields 1.
// No exponent reduction is performed. Reducing y by (mod-1)/2 is only valid
// when x is a quadratic residue mod `mod`. Reducing by mod-1 is unnecessary,
// because y < 2^31 already bounds the loop to 31 iterations.
inline int pow(int x,int y) {
    x%=mod;
    if (x<0) x+=mod;   // keep every intermediate product non-negative
    int ans=1;
    while (y>0) {
        if (y&1) ans=(long long)ans*x%mod;
        x=(long long)x*x%mod;
        y>>=1;
    }
    return ans;
}
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it (a '-' sign is skipped, not honoured).
// Returns 0 if end of input is reached before any digit is found, instead of
// looping forever.
inline int in() {
    int c=getchar();   // int, not char: EOF must stay distinguishable from data bytes
    while (c!=EOF&&(c<'0'||c>'9')) c=getchar();
    int ans=0;
    while (c>='0'&&c<='9') {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans;
}
// Discretisation maps, cleared and rebuilt for every test case in main():
// c: exponent of 2 (one of the a_i values, or 0) -> compressed index in [0, m).
map<int,int>c;
// e: compressed index -> original exponent of 2 (the inverse of c); read by set().
map<int,int>e;
// d: scratch copy of the a_i values, sorted in main() to build c and e.
int d[100000];
// One set element 2^a * 3^b, stored as its two exponents.
struct Number {
    int a,b;
    // Reads the exponents from stdin: a first, then b.
    void read() {
        a=in();
        b=in();
    }
    // Compares elements by their exponent of 3 only; `a` is ignored.
    friend bool operator < (const Number &lhs,const Number &rhs) {
        return rhs.b>lhs.b;
    }
};
// Elements of the current test case; the array size caps n at 100000.
Number a[100000];
// n: number of elements; m: number of distinct compressed exponents (tree leaves).
int n,m;
// Node of a segment tree over compressed exponents of 2, with a lazy multiplier.
//   sum    : sum of dp counts in the node's range (mod `mod`)
//   mulSum : sum of dp count * 2^exponent in the node's range (mod `mod`)
//   mul    : factor already applied here but not yet pushed to the children
struct SeqTreeNode {
    int sum,mulSum;
    int mul;
    SeqTreeNode *ls,*rs;
    // Resets to an empty node with no pending factor; returns this for chaining.
    SeqTreeNode *clear() {
        mul=1;
        sum=0;
        mulSum=0;
        return this;
    }
    // Pushes the pending factor into both children. Only valid on internal
    // nodes: ls and rs are dereferenced without a null check.
    void down() {
        SeqTreeNode *kids[2]={ls,rs};
        for (int k=0;k<2;k++) kids[k]->multipy(mul);
        mul=1;
    }
    // Scales every value in the node's range by x (descendants lazily).
    // The name is a misspelling of "multiply", kept because callers use it.
    void multipy(int x) {
        const long long f=x;
        sum=(int)(sum*f%mod);
        mulSum=(int)(mulSum*f%mod);
        mul=(int)(mul*f%mod);
    }
    // Recomputes this node's aggregates from its two children.
    void repair() {
        sum=(ls->sum+rs->sum)%mod;
        mulSum=(ls->mulSum+rs->mulSum)%mod;
    }
};
// Node pool for the static segment tree (bp = next free node, root = tree root).
// maketree(0,m-1) allocates exactly 2*m-1 nodes, and m <= n+1 <= 100001
// (the distinct a_i plus exponent 0). So 2*100001 nodes always suffice.
SeqTreeNode b[2*100001],*bp,*root;
// Builds a segment tree covering [l, r] from the pool (taking nodes via bp++),
// with every node cleared, and returns its root. Leaves get null children.
SeqTreeNode *maketree(int l,int r) {
    SeqTreeNode *node=bp++;
    node->clear();
    if (l!=r) {
        int mid=(l+r)>>1;
        node->ls=maketree(l,mid);
        node->rs=maketree(mid+1,r);
    } else {
        node->ls=NULL;
        node->rs=NULL;
    }
    return node;
}
// Adds x to the dp count at compressed index i, and x*2^e[i] to the weighted
// sum. The ancestors' aggregates are refreshed on the way back up.
// `from` covers [l, r].
void set(SeqTreeNode *from,int l,int r,int i,int x) {
    if (l!=r) {
        from->down();
        int mid=(l+r)/2;
        if (i>mid) set(from->rs,mid+1,r,i,x);
        else set(from->ls,l,mid,i,x);
        from->repair();
        return;
    }
    int weight=pow(2,e[l]);
    from->sum=(from->sum+x)%mod;
    from->mulSum=(from->mulSum+(long long)weight*x)%mod;
}
// Doubles every dp count (and weighted sum) at compressed indices [ll, rr].
// These are the exponents above a_i, whose subsets may or may not include the
// new element without changing the LCM. `from` covers [l, r], and [ll, rr]
// must lie inside it. An empty range (ll > rr) is a no-op.
void setMul2(SeqTreeNode *from,int l,int r,int ll,int rr) {
    // Without this guard an empty range descends to a leaf and dereferences
    // its null children in down().
    if (ll>rr) return;
    if (l==ll&&r==rr) {
        from->multipy(2);
        return;
    }
    int t=(l+r)/2;
    from->down();
    if (rr<=t) setMul2(from->ls,l,t,ll,rr);
    else if (ll>t) setMul2(from->rs,t+1,r,ll,rr);
    else {
        setMul2(from->ls,l,t,ll,t);
        setMul2(from->rs,t+1,r,t+1,rr);
    }
    from->repair();
}
// Returns the sum (mod `mod`) of dp counts at compressed indices [ll, rr].
// `from` covers [l, r], and [ll, rr] must lie inside it. An empty range yields 0.
// There is deliberately no "sum == 0 means empty" shortcut. Counts are stored
// modulo `mod`, so a node whose total is 0 (mod) can still have non-zero sums
// in its sub-ranges.
int getSum(SeqTreeNode *from,int l,int r,int ll,int rr) {
    if (ll>rr) return 0;
    if (l==ll&&r==rr) return from->sum;
    int t=(l+r)/2;
    from->down();
    if (rr<=t) return getSum(from->ls,l,t,ll,rr);
    if (ll>t) return getSum(from->rs,t+1,r,ll,rr);
    return (getSum(from->ls,l,t,ll,t)+getSum(from->rs,t+1,r,t+1,rr))%mod;
}
// Returns the sum (mod `mod`) of dp count * 2^e[j] over compressed indices
// j in [ll, rr]. `from` covers [l, r], and [ll, rr] must lie inside it.
// An empty range yields 0.
// There is deliberately no "sum == 0 means empty" shortcut. Values are kept
// modulo `mod`, so a node whose count total is 0 (mod) may still hold non-zero
// counts, and a non-zero mulSum, in its sub-ranges.
int getMulSum(SeqTreeNode *from,int l,int r,int ll,int rr) {
    if (ll>rr) return 0;
    if (l==ll&&r==rr) return from->mulSum;
    int t=(l+r)/2;
    from->down();
    if (rr<=t) return getMulSum(from->ls,l,t,ll,rr);
    if (ll>t) return getMulSum(from->rs,t+1,r,ll,rr);
    return (getMulSum(from->ls,l,t,ll,t)+getMulSum(from->rs,t+1,r,t+1,rr))%mod;
}
// Reads test cases until EOF. Each case is n followed by n pairs (a_i, b_i).
// Prints the sum, modulo `mod`, of lcm over all non-empty subsets of
// {2^a_i * 3^b_i}.
int main() {
int i;
while (scanf("%d",&n)!=EOF) {
for (i=0;i<n;i++) {
a[i].read();
d[i]=a[i].a;
}
// d: exponents of 2 sorted for discretisation. a: elements sorted by exponent
// of 3, so when element i is the last one chosen, the LCM's power of 3 is 3^b_i.
sort(d,d+n);
sort(a,a+n);
c.clear();
e.clear();
m=0;
// Exponent 0 always gets compressed index 0; it holds the empty subset (LCM 1).
// NOTE(review): assumes n >= 1. With n == 0, d[0] is left over from an
// earlier test case (or is 0).
if (d[0]!=0) {
e[m]=0;
c[0]=m++;
}
// Give the distinct exponents consecutive indices in increasing order.
for (i=0;i<n;i++) {
if (i==0||d[i]!=d[i-1]) {
e[m]=d[i];
c[d[i]]=m++;
}
}
int ans=0;
// Rebuild the tree from the start of the pool and seed dp[0][0] = 1 (empty subset).
bp=b;
root=maketree(0,m-1);
set(root,0,m-1,0,1);
for (i=0;i<n;i++) {
// lsum: number of subsets of earlier elements whose 2-exponent is <= a_i.
// rsum: sum of count * 2^j over those whose 2-exponent j is > a_i.
int tmp=c[a[i].a],lsum,rsum;
lsum=getSum(root,0,m-1,0,tmp);
if (tmp+1<m) rsum=getMulSum(root,0,m-1,tmp+1,m-1);
else rsum=0;
// Subsets whose last chosen element is i have LCM 3^b_i * 2^max(a_i, j).
ans=(ans+
(long long)pow(3,a[i].b)*(rsum+(long long)pow(2,a[i].a)*lsum%mod)%mod
)%mod;
// Move to dp[i]: index a_i gains lsum (subsets that choose i), indices
// above a_i double, and indices below a_i are unchanged.
set(root,0,m-1,tmp,lsum);
if (tmp+1<m) setMul2(root,0,m-1,tmp+1,m-1);
}
printf("%d\n",ans);
}
return 0;
}
之前的动态开点线段树版本TLE了,改成离散化+静态建树后顺利通过。
动态开点版本如下。以后再也不敢偷懒不写离散化了——实测虽然两者只是常数上的差距,但差距挺大的。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int mod=1000000007;
// Bounds of the exponent range covered by the dynamic segment tree
// (the problem statement bounds a_i by 10^9).
const int L=0;
const int R=1000000000;
// Returns x^y mod `mod` by square-and-multiply, consuming the bits of y from
// the least significant end. y is expected to be non-negative.
inline int pow(int x,int y) {
    int result=1;
    int base=x;
    for (int rest=y;rest!=0;rest>>=1) {
        if (rest&1) result=(long long)result*base%mod;
        base=(long long)base*base%mod;
    }
    return result;
}
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it (a '-' sign is skipped, not honoured).
// Returns 0 if end of input is reached before any digit is found, instead of
// looping forever.
inline int in() {
    int c=getchar();   // int, not char: EOF must stay distinguishable from data bytes
    while (c!=EOF&&(c<'0'||c>'9')) c=getchar();
    int ans=0;
    while (c>='0'&&c<='9') {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans;
}
// One set element 2^a * 3^b, stored as its two exponents.
struct Number {
    int a,b;
    // Reads the exponents from stdin: a first, then b.
    void read() {
        a=in();
        b=in();
    }
    // Compares elements by their exponent of 3 only; `a` is ignored.
    friend bool operator < (const Number &lhs,const Number &rhs) {
        return rhs.b>lhs.b;
    }
};
// Elements of the current test case; the array size caps n at 100000.
Number a[100000];
// Number of elements in the current test case.
int n;
// Node of a dynamically allocated segment tree over exponents [L, R], with a
// lazy multiplier.
//   sum    : sum of dp counts in the node's range (mod `mod`)
//   mulSum : sum of dp count * 2^exponent in the node's range (mod `mod`)
//   mul    : factor already applied here but not yet pushed to the children
// Children start out null; down() creates them on demand.
struct SeqTreeNode {
    int sum,mulSum;
    int mul;
    SeqTreeNode *ls,*rs;
    // Resets to an empty, childless node with no pending factor; returns this.
    SeqTreeNode *clear() {
        ls=NULL;
        rs=NULL;
        mul=1;
        sum=0;
        mulSum=0;
        return this;
    }
    // Defined out of line because it needs the node pool declared below.
    void down();
    // Scales every value in the node's range by x (descendants lazily).
    // The name is a misspelling of "multiply", kept because callers use it.
    void multipy(int x) {
        const long long f=x;
        sum=(int)(sum*f%mod);
        mulSum=(int)(mulSum*f%mod);
        mul=(int)(mul*f%mod);
    }
    // Recomputes this node's aggregates from its two children.
    void repair() {
        sum=(ls->sum+rs->sum)%mod;
        mulSum=(ls->mulSum+rs->mulSum)%mod;
    }
};
// Node pool for the dynamic tree: bp is the next free node, root the tree root.
// NOTE(review): nodes are taken with bp++ and never bounds-checked against the
// pool size. Confirm that 4,000,000 nodes suffice for worst-case inputs.
SeqTreeNode b[4000000],*bp,*root;
void SeqTreeNode::down() {
if (ls==NULL) ls=(bp++)->clear();
if (rs==NULL) rs=(bp++)->clear();
ls->multipy(mul);
rs->multipy(mul);
mul=1;
}
// Adds x to the dp count at exponent i, and x*2^i to the weighted sum.
// Nodes on the path are created via down() as needed, and the aggregates are
// refreshed on the way back up. `from` covers [l, r].
void set(SeqTreeNode *from,int l,int r,int i,int x) {
    if (l!=r) {
        from->down();
        int mid=(l+r)/2;
        if (i>mid) set(from->rs,mid+1,r,i,x);
        else set(from->ls,l,mid,i,x);
        from->repair();
        return;
    }
    int weight=pow(2,l);
    from->sum=(from->sum+x)%mod;
    from->mulSum=(from->mulSum+(long long)weight*x)%mod;
}
// Doubles every dp count (and weighted sum) at exponents [ll, rr].
// `from` covers [l, r], and [ll, rr] must lie inside it. An empty range
// (ll > rr) is a no-op; this happens for the exponents above a_i when a_i == R.
void setMul2(SeqTreeNode *from,int l,int r,int ll,int rr) {
    // Without this guard an empty range descends to a leaf, and down() then
    // keeps creating children below it forever.
    if (ll>rr) return;
    if (l==ll&&r==rr) {
        from->multipy(2);
        return;
    }
    int t=(l+r)/2;
    from->down();
    if (rr<=t) setMul2(from->ls,l,t,ll,rr);
    else if (ll>t) setMul2(from->rs,t+1,r,ll,rr);
    else {
        setMul2(from->ls,l,t,ll,t);
        setMul2(from->rs,t+1,r,t+1,rr);
    }
    from->repair();
}
// Returns the sum (mod `mod`) of dp counts at exponents [ll, rr].
// `from` covers [l, r], and [ll, rr] must lie inside it. An empty range yields 0.
// Pruning uses tree structure rather than "sum == 0". Counts are stored
// modulo `mod`, so a zero total does not mean an empty range. A childless
// internal node, however, was never on a set() path (set() creates children on
// every internal node it passes), so its whole range truly holds zeros.
// Stopping there also avoids allocating children just to read zeros.
int getSum(SeqTreeNode *from,int l,int r,int ll,int rr) {
    if (ll>rr) return 0;
    if (l==ll&&r==rr) return from->sum;
    if (from->ls==NULL) return 0;
    int t=(l+r)/2;
    from->down();
    if (rr<=t) return getSum(from->ls,l,t,ll,rr);
    if (ll>t) return getSum(from->rs,t+1,r,ll,rr);
    return (getSum(from->ls,l,t,ll,t)+getSum(from->rs,t+1,r,t+1,rr))%mod;
}
// Returns the sum (mod `mod`) of dp count * 2^j over exponents j in [ll, rr].
// `from` covers [l, r], and [ll, rr] must lie inside it. An empty range yields 0.
// Pruning uses tree structure rather than "sum == 0". Values are kept
// modulo `mod`, so a zero count total says nothing about mulSum or about the
// sub-ranges. A childless internal node, however, was never on a set() path,
// so its whole range truly holds zeros, and no children need to be allocated.
int getMulSum(SeqTreeNode *from,int l,int r,int ll,int rr) {
    if (ll>rr) return 0;
    if (l==ll&&r==rr) return from->mulSum;
    if (from->ls==NULL) return 0;
    int t=(l+r)/2;
    from->down();
    if (rr<=t) return getMulSum(from->ls,l,t,ll,rr);
    if (ll>t) return getMulSum(from->rs,t+1,r,ll,rr);
    return (getMulSum(from->ls,l,t,ll,t)+getMulSum(from->rs,t+1,r,t+1,rr))%mod;
}
// Debug helper: prints the dp counts stored at exponents 0..4, space-separated,
// followed by a newline.
void print() {
    int exponent=0;
    while (exponent<5) {
        printf("%d ",getSum(root,L,R,exponent,exponent));
        ++exponent;
    }
    printf("\n");
}
// Reads test cases until EOF. Each case is n followed by n pairs (a_i, b_i).
// Prints the sum, modulo `mod`, of lcm over all non-empty subsets of
// {2^a_i * 3^b_i}.
int main() {
    int i;
    while (scanf("%d",&n)!=EOF) {
        for (i=0;i<n;i++) a[i].read();
        // Sort by exponent of 3: when element i is the last one chosen, the
        // LCM's power of 3 is exactly 3^b_i.
        sort(a,a+n);
        int ans=0;
        // Reset the pool and seed dp[0][0] = 1 (the empty subset, LCM 1).
        bp=b;
        root=(bp++)->clear();
        set(root,L,R,0,1);
        for (i=0;i<n;i++) {
            // Exponents above a_i exist only when a_i < R. For a_i == R (allowed,
            // since a_i <= 10^9), [a_i+1, R] is empty and must not be queried or
            // updated: the tree routines would otherwise recurse without end.
            const bool hasAbove=a[i].a<R;
            // lsum: subsets of earlier elements whose 2-exponent is <= a_i.
            // rsum: sum of count * 2^j over those whose 2-exponent j is > a_i.
            int lsum=getSum(root,L,R,L,a[i].a);
            int rsum=hasAbove?getMulSum(root,L,R,a[i].a+1,R):0;
            // Subsets whose last chosen element is i have LCM 3^b_i * 2^max(a_i, j).
            ans=(ans+
                (long long)pow(3,a[i].b)*(rsum+(long long)pow(2,a[i].a)*lsum%mod)%mod
                )%mod;
            // Move to dp[i]: exponent a_i gains lsum, and exponents above a_i double.
            set(root,L,R,a[i].a,lsum);
            if (hasAbove) setMul2(root,L,R,a[i].a+1,R);
        }
        printf("%d\n",ans);
    }
    return 0;
}