题目传送门:http://acm.hdu.edu.cn/showproblem.php?pid=1796
题意:给一个集合set和一个数n,让你构建一个新集合。新集合中的任意一个元素都能被set中某个元素整除,而且新集合中的元素要小于n,问你这个新集合中最多有多少个元素
题解:这是cls挂的容斥专题的B题,一开始我的思路就是错的,然后去瞟了一眼题解,然后恍然大悟,又重新搞一发,结果还是无限wa。而且自认为题解求lcm完全没必要,然后就各种想不通。最后终于在hdu的discuss中找到真相= =,终于算是a了
思路就是容斥,先加上小于n且能被set中一个元素整除的数字个数,然后减去能被两个数字整除(这两个数的lcm)的个数。接着就是反复这个过程,奇数个元素的就加,偶数个元素的就减,直到达到最大元素。
还有一些注意事项就是:
1、原数组中0要挖去。
2、元素中若一个元素能被另一个元素整除,则挖去这个元素。(这个不难理解)
ps:我刚开始一直wa就是因为我自认为我做了注意事项中的2后,集合中的元素就互质了。但这是不可能的,比如集合中存在6,9,我的想法很明显就错了。还是too naive = =
代码如下:
#include <iostream>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <stdio.h>
#include <string>
#include <cmath>
#include <queue>
#include <set>
#include <map>
#include <stack>
#include <bitset>
#include <cstdlib>
#include <vector>
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define ll long long
#define ull unsigned long long
#define mem(n,v) memset(n,v,sizeof(n))
#define MAX 200005
#define MAXN 300005
#define PI 3.1415926
#define E 2.718281828459
#define opnin freopen("input.txt","r",stdin)
#define opnout freopen("output.txt","w",stdout)
#define clsin fclose(stdin)
#define clsout fclose(stdout)
const int INF = 0x3f3f3f3f;
const ll INFF = 0x3f3f3f3f3f3f3f3f;
const double pi = 3.141592653589793;
const double inf = 1e18;
const double eps = 1e-8;
const ll mod = 1e18;
const ull mx = 133333331;
/**************************************************************************/
ll a[25];
ll b[25];
ll sum;
int flag;
int cnt;
ll n,m;
ll gcd(ll a, ll b)
{
if(b==0)return a;
return gcd(b,a%b);
}
ll lcm(ll a,ll b)
{
return a/gcd(a,b)*b;
}
ll dfs(int num,int cur,ll value,int id)
{
if(cur == num){
return n/(value);
}
ll temp = 0;
for(int i=id+1;i<m;i++){
temp += dfs(num,cur+1,lcm(value,b[i]),i);
}
return temp;
}
int main()
{
while(cin >> n >> m){
n--;
sum = 0;
cnt = 0;
for(int i=0;i<m;i++){
ll x;
scanf("%lld",&x);
if(x) a[cnt++] = x;
}
m = cnt;
sort(b,b+m);
cnt = 0;
for(int i=0;i<m;i++){
flag = 0;
for(int j=0;j<i;j++){
if(a[i] % a[j] == 0){
flag = 1;
break;
}
}
if(!flag){
b[cnt++] = a[i];
}
}
m = cnt;
cnt = 0;
ll temp;
for(int i=1;i<=m;i++){
temp = 0;
for(int j=0;j<m;j++){
temp += dfs(i,1,b[j],j);
}
// cout << "i,tmp " << i << ' ' << temp << endl;
if(i % 2 == 1) sum += temp;
else sum -= temp;
}
cout << sum <<endl;
}
return 0;
}