【FFT】【NTT】codeforces1096G Lucky Tickets

G. Lucky Tickets
time limit per test5 seconds
memory limit per test256 megabytes
inputstandard input
outputstandard output
All bus tickets in Berland have their numbers. A number consists of n digits (n is even). Only k decimal digits d1,d2,…,dk can be used to form ticket numbers. If 0 is among these digits, then numbers may have leading zeroes. For example, if n=4 and only digits 0 and 4 can be used, then 0000, 4004, 4440 are valid ticket numbers, and 0002, 00, 44443 are not.

A ticket is lucky if the sum of first n/2 digits is equal to the sum of remaining n/2 digits.

Calculate the number of different lucky tickets in Berland. Since the answer may be big, print it modulo 998244353.

Input
The first line contains two integers n and k (2≤n≤2⋅105,1≤k≤10) — the number of digits in each ticket number, and the number of different decimal digits that may be used. n is even.

The second line contains a sequence of pairwise distinct integers d1,d2,…,dk (0≤di≤9) — the digits that may be used in ticket numbers. The digits are given in arbitrary order.

Output
Print the number of lucky ticket numbers, taken modulo 998244353.

Examples
inputCopy
4 2
1 8
outputCopy
6
inputCopy
20 1
6
outputCopy
1
inputCopy
10 5
6 1 4 0 3
outputCopy
569725
inputCopy
1000 7
5 4 0 1 8 3 2
outputCopy
460571165
Note
In the first example there are 6 lucky ticket numbers: 1111, 1818, 1881, 8118, 8181 and 8888.

There is only one ticket number in the second example, it consists of 20 digits 6. This ticket number is lucky, so the answer is 1.


 最近因为在考期,所以写题有些懈怠,考后继续加油。另外,最近计划参加ACM的新人大佬非常多,不得不说有些担心明年的选拔了。
 此题出自昨天CF的EDU Round,前面四题都1A,大概还有一个多小时,写完这题就rating就可以上黄色了,然而显然上天不愿意给我这个机会QwQ。
 题目很容易理解,也很容易想到多项式乘法。因为问题核心是要求n/2位数来表示某个数字的方法数,dp的方式就是不断乘以一个不超过十位的,由0和1组成的数。那么用快速幂或者分治的形式分解问题。两个大数相乘时采用FFT,那么效率就是nlognlogn,似乎勉强能过。
 很开心地写了一发。因为比赛时间问题,我写了更清晰但是低效的递归分治法,并且采用vector传参(反正CF不用担心爆栈,随便乱玩)。

#include<cstdio>
#include<vector>
#include<complex>
#define mo 998244353
using namespace std;
using db=long double;
using LL=long long;
using cp=complex<db>;

const int maxn=(1<<21)+5;
const db PI=acos(-1.0L); 
int rev[maxn],s,sum;
int n,k;
vector<int> h(10);

void get_rev(int bit)
{
	for(int i=0;i<(1<<bit);i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
}

void FFT(vector<cp> &a, int n, int dft)
{
	cp x,y;
	for(int i=0;i<n;i++)
		if(i<rev[i])
			swap(a[i],a[rev[i]]);
	for(int stp=1;stp<n;stp<<=1)
	{
		cp wn=exp(cp(0,dft*PI/stp));
		for(int j=0;j<n;j+=stp<<1)
		{
			cp wnk(1,0);
			for(int k=j;k<j+stp;k++)
			{
				x=a[k];
				y=wnk*a[k+stp];
				a[k]=x+y;
				a[k+stp]=x-y;
				wnk*=wn;
			}
		}
	}
	if(dft==-1)
		for(int i=0;i<n;i++)
			a[i]/=n;
}

vector<int> mul(vector<int> x, vector<int> y)
{
	int s=2,l1=x.size(),l2=y.size(),bit;
	for(bit=1;(1<<bit)<l1+l2-1;bit++)
		s<<=1;
	get_rev(bit);
	vector<int> op(s);
	vector<cp> a(s),b(s);
	for(int i=0;i<l1;i++)
		a[i]=(db)x[i];
	for(int i=0;i<l2;i++)
		b[i]=(db)y[i];
	FFT(a,s,1);
	FFT(b,s,1);
	for(int i=0;i<s;i++)
		a[i]*=b[i];
	FFT(a,s,-1);
	for(int i=0;i<s;i++)
		op[i]=(LL)(a[i].real()+0.5)%mo;
	while(!op.empty()&&op.back()==0)
		op.pop_back();
	return op;
}

vector<int> solve(int x)
{
	if(x==0)
		return vector<int>(1,1);
	if(x==1)
		return h;
	vector<int> a=solve(x/2);
	if(x&1)
		return mul(mul(a,a),h);
	else
		return mul(a,a);
}

int main()
{
	scanf("%d%d",&n,&k);
	for(int i=1,tmp;i<=k;i++)
		scanf("%d",&tmp),h[tmp]=1;
	vector<int> ans=solve(n/2);
	for(int i:ans)
		sum=(sum+(LL)i*i%mo)%mo;
	printf("%d",sum);
	return 0;
}

 然而出现了两个问题,第一是最后一个样例过不了,这是因为浮点数精度的问题,因为FFT过程中,答案每一位最大值是(1E91E9位数)级别的,运算过程中a[i]最大值是1E91E9位数*位数级别的,位数最多1E6,那么总共可以达到1E30,即使用long double也力不从心。(妈耶,一开始还以为自己模板写错了,查了半天错)
通过printf("%.20Lf",3.14159265358979323846264338327950288L);输出可知long double大概18-19位精度。
 第二个问题就是,FFT效率不高,极限数据会TLE。
 第二天看了看大佬们的代码,查了查,才知道有一种NTT的东西(快速数论变换)通过整数代替复数,原根代替FFT中的单位根,逆元代替除法得到答案。这里的原根指的就是数论的原根,和复数单位根有相同的性质。
 这样的话只要在变换时不断取模就不会出现精度问题,而且效率快得不是一星半点。
 参考文章:
https://blog.csdn.net/enjoy_pascal/article/details/81771910
基础NTT(求FFT取模,避免了精度问题)
https://www.cnblogs.com/fenghaoran/p/7107608.html
https://blog.csdn.net/qq_35950004/article/details/79477797 任意模数NTT(利用中国剩余定理,可以计算出真值,效果完全等价于FFT,同时避免了精度问题)
 任意模数的NTT还可以用另一种把多项式的每一项系数拆成AM+B的取巧方法,参见
https://www.cnblogs.com/xzyxzy/p/9263480.html

 模数998244353是素数,而且是NTT素数(为啥必须是NTT素数,看过FFT的都知道了)即(P-1)有超过序列长度的2的正整数幂因子的质数,其中一个原根是3,如果是其他NTT素数的话,可以暴力法求原根。如果不是NTT素数,但是能分解出NTT素数因子,求每一个素因子的模数的答案,然后综合即可,除此之外只能用任意模数NTT的方法来做了。

#include<cstdio>
#include<vector>
#include<complex>
#define mo 998244353
#define root 3
using namespace std;
using LL=long long;

const int maxn=(1<<21)+5;
int rev[maxn],s,sum;
int n,k;
vector<int> h(10);

int quick_power(int a, int b)
{
	int res=1,base=a;
	while(b)
	{
		if(b&1)
			res=(LL)res*base%mo;
		base=(LL)base*base%mo;
		b>>=1;
	}
	return res;
}

void get_rev(int bit)
{
	for(int i=0;i<(1<<bit);i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
}

void FFT(vector<int> &a, int n, int dft)
{
	int x,y;
	for(int i=0;i<n;i++)
		if(i<rev[i])
			swap(a[i],a[rev[i]]);
	for(int stp=1;stp<n;stp<<=1)
	{
		int wn=quick_power(root,(mo-1)/(stp*2));
		if(dft==-1)
			wn=quick_power(wn,mo-2);
		for(int j=0;j<n;j+=stp<<1)
		{
			int wnk=1;
			for(int k=j;k<j+stp;k++)
			{
				x=a[k];
				y=(LL)wnk*a[k+stp]%mo;
				a[k]=(x+y)%mo;
				a[k+stp]=(x-y+mo)%mo;
				wnk=(LL)wnk*wn%mo;
			}
		}
	}
	if(dft==-1)
	{
		int t=quick_power(n,mo-2);
		for(int i=0;i<n;i++)
			a[i]=(LL)a[i]*t%mo;
	}
}

vector<int> mul(vector<int> x, vector<int> y)
{
	int s=2,l1=x.size(),l2=y.size(),bit;
	for(bit=1;(1<<bit)<l1+l2-1;bit++)
		s<<=1;
	get_rev(bit);
	x.resize(s),y.resize(s);
	FFT(x,s,1);
	FFT(y,s,1);
	for(int i=0;i<s;i++)
		x[i]=(LL)x[i]*y[i]%mo;
	FFT(x,s,-1);
	while(!x.empty()&&x.back()==0)
		x.pop_back();
	return x;
}

vector<int> solve(int x)
{
	if(x==0)
		return vector<int>(1,1);
	if(x==1)
		return h;
	vector<int> a=solve(x/2);
	if(x&1)
		return mul(mul(a,a),h);
	else
		return mul(a,a);
}

int main()
{
	scanf("%d%d",&n,&k);
	for(int i=1,tmp;i<=k;i++)
		scanf("%d",&tmp),h[tmp]=1;
	vector<int> ans=solve(n/2);
	for(int i:ans)
		sum=(sum+(LL)i*i%mo)%mo;
	printf("%d",sum);
	return 0;
}

 最后才意识到一件很傻的事情,FFT不只是可以干两个多项式的乘法,多项式快速幂也是OK的,DFT之后每答案取p次方再IDFT就是答案,根本不用分治,自己思维之前有些僵化了。附上最终的代码吧,跑得还挺快的。

#include<cstdio>
#include<vector>
#include<complex>
#define mo 998244353
#define root 3
using namespace std;
using LL=long long;

const int maxn=(1<<21)+5;
int rev[maxn],s,sum;
int n,k;
vector<int> h(10);

int quick_power(int a, int b)
{
	int res=1,base=a;
	while(b)
	{
		if(b&1)
			res=(LL)res*base%mo;
		base=(LL)base*base%mo;
		b>>=1;
	}
	return res;
}

void get_rev(int bit)
{
	for(int i=0;i<(1<<bit);i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
}

void FFT(vector<int> &a, int n, int dft)
{
	int x,y;
	for(int i=0;i<n;i++)
		if(i<rev[i])
			swap(a[i],a[rev[i]]);
	for(int stp=1;stp<n;stp<<=1)
	{
		int wn=quick_power(root,(mo-1)/(stp*2));
		if(dft==-1)
			wn=quick_power(wn,mo-2);
		for(int j=0;j<n;j+=stp<<1)
		{
			int wnk=1;
			for(int k=j;k<j+stp;k++)
			{
				x=a[k];
				y=(LL)wnk*a[k+stp]%mo;
				a[k]=(x+y)%mo;
				a[k+stp]=(x-y+mo)%mo;
				wnk=(LL)wnk*wn%mo;
			}
		}
	}
	if(dft==-1)
	{
		int t=quick_power(n,mo-2);
		for(int i=0;i<n;i++)
			a[i]=(LL)a[i]*t%mo;
	}
}

void ntt_pow(vector<int> &x, int p)
{
	int s=2,l1=x.size(),bit;
	for(bit=1;(1<<bit)<p*l1-p+1;bit++)
		s<<=1;
	get_rev(bit);
	x.resize(s);
	FFT(x,s,1);
	for(int i=0;i<s;i++)
		x[i]=quick_power(x[i],p);
	FFT(x,s,-1);
}

int main()
{
	scanf("%d%d",&n,&k);
	for(int i=1,tmp;i<=k;i++)
		scanf("%d",&tmp),h[tmp]=1;
	ntt_pow(h,n/2);
	for(int i:h)
		sum=(sum+(LL)i*i%mo)%mo;
	printf("%d",sum);
	return 0;
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值