题解
看重点·这位大佬讲的很详细了
构建图:若满足编号相加为平方数,就两两连边
题目问n根柱子最多放多少个球,
等同于问不超过n个路径,最多可以覆盖多少满足条件(一根柱子就是一条路径)
这里有个是DAG里二分的性质:
最小边覆盖 = 点总数 - 最大匹配
什么?你没学过最小路径覆盖?看看这个吧
大致做法:
因为珠子数随着柱子的增加是不递减,所以可以用二分找出珠子编号的上界,
然后网络流求出最大匹配,用公式求出最小边覆盖路径,如果结果大于柱子数,说明要减小,反之增大,
在网络流dfs()
里记录路径,
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e6+10;
const int INF=0x3f3f3f3f;
int n,m,k;
namespace Network_flows { //网络流板子
//设定起点和终点
int st=0;//起点-源点
int ed=6000;//终点-汇点
struct egde {
int to, next;
int flow;//剩余流量
} e[N * 2];
int head[N], tot = 1;
void add(int u, int v, int w) {
e[++tot] = {v, head[u], w};
head[u] = tot;
e[++tot] = {u, head[v], 0};
head[v] = tot;//网络流反相边流量为0
}
int dep[N];//dep[]=-1时为炸点
queue<int> q;
bool bfs() {
memset(dep, 0, sizeof(dep));//顺便起到vis的功能
q.push(st);
dep[st] = 1;
while (!q.empty()) {
int u = q.front();
q.pop();
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].to;
if (!dep[v] && e[i].flow) {
dep[v] = dep[u] + 1;
q.push(v);
}
}
}
return dep[ed];
}
int nxt[N];//路径
int dfs(int u, int Flow) {
if (u == ed) return Flow;
int now_flow = 0;//跑残流
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].to;
if (dep[v] == dep[u] + 1 && e[i].flow) {
int f = dfs(v, min(Flow - now_flow, e[i].flow));
e[i].flow -= f;
e[i ^ 1].flow += f;
now_flow += f;
if(f)nxt[u]=v;//如果能往下流 就记录路径
if (now_flow == Flow) return Flow;
}
}
if (now_flow == 0)dep[u] = -1;
return now_flow;
}
#define max_flow dinic
int dinic() {//最大流
int res = 0;
while (bfs()) {
res += dfs(st, INF);
}
return res;
}
void init() {
tot = 1;
memset(head, 0, sizeof(head));
while (!q.empty()) q.pop();
memset(nxt, 0, sizeof(nxt));
}
}
using namespace Network_flows;
bool solve(int n){
init();
for (int i = 1; i <= n; ++i) {
for (int j = i + 1; j <= n; ++j) {
int x=sqrt(i+j);
if(x*x==i+j){
add(i,n+j,1);
}
}
}
for (int i = 1; i <= n; ++i) {
add(st,i,1);
add(i+n,ed,1);//加个基准值以区分 st -> i 的i点
}
int res=n-dinic();
return k < res;
}
int main(){
ios::sync_with_stdio(0);
cin>>k;
int l=0,r=2500;
//二分上界n
while(l<r){
int mid=l+r>>1;
if( solve(mid) ){//如果最小边覆盖 超过上限 缩小n
r=mid;
}else l=mid+1;
}
n=l-1;//应该在最后一个mid-1的位置符合
solve(n);//这一步是为了统计路径
printf("%d\n", n);
//输出路径
for (int i = 1; i <= n; ++i) {
if(!nxt[i]) continue;
int u=i;
while(u){
if(u>n)u-=n;//网络流的时候多加了n 要拆掉
printf("%d ", u);
int v=nxt[u];
nxt[u]=0;
u=v;
}
printf("\n");
}
return 0;
}