Solution
主要是记录几个小技巧吧。
首先我们发现 n n n 与 m m m 的值互换不影响结果。那么就有了第一个优化:令 n n n 为较小值,则 n n n 的范围就是 [ 1 , N ] [1,\sqrt{N}] [1,N]。( N N N 就是题目给的 n n n)
然后我们沿着对角线往下填。如果列数超过 m m m 就将列置为 1 1 1。可以发现这样填保证每种数最多填 n n n 个。
我们将每种数的出现次数从小到大排序,可以 log ( N ) \log(N) log(N) 维护每个 n n n 时最多能填的数。
最后填数的环节有一个问题:直接开 v [ 4 e 5 ] [ 4 e 5 ] v[4e5][4e5] v[4e5][4e5] 的数组会炸。我们可以用 v e c t o r \mathtt{vector} vector 的 r e s i z e \mathtt{resize} resize。
Code
#include <map>
#include <cmath>
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 4e5 + 5;
int n, Sqrt, a[N], p[N], cnt, pre[N], maxsum, ans, sum, res, lim, b[N];
map <int, int> ma;
vector < vector <int> > v;
struct node {int tot, co;} s[N];
int read() {
int x = 0, f = 1; char s;
while((s = getchar()) > '9' || s < '0') if(s == '-') f = -1;
while(s >= '0' && s <= '9') x = (x << 1) + (x << 3) + (s ^ 48), s = getchar();
return x * f;
}
bool cmp(const int x, const int y) {return s[x].tot < s[y].tot;}
bool Cmp(const node x, const node y) {return x.tot > y.tot;}
int main() {
n = read(); Sqrt = sqrt(n);
for(int i = 1; i <= n; ++ i) {
a[i] = read();
if(! ma[a[i]]) ++ cnt, p[cnt] = cnt, ma[a[i]] = cnt, s[cnt].co = a[i];
++ s[ma[a[i]]].tot;
}
for(int i = 1; i <= cnt; ++ i) b[i] = s[i].tot;
sort(p + 1, p + cnt + 1, cmp); sort(b + 1, b + cnt + 1);
for(int i = 1; i <= cnt; ++ i) pre[i] = pre[i - 1] + b[i];
for(int i = 1; i <= Sqrt; ++ i) {
int pos = upper_bound(b + 1, b + cnt + 1, i) - b - 1;
sum = pre[pos] + (cnt - pos) * i;
if(sum / i >= i && maxsum < sum / i * i) maxsum = sum / i * i, ans = i;
}
printf("%d\n%d %d\n", maxsum, ans, maxsum / ans); res = maxsum / ans;
v.resize(ans);
for(int i = 0; i < ans; ++ i) v[i].resize(res);
int x = 1, y = 1; lim = 1;
sort(s + 1, s + cnt + 1, Cmp);
for(int i = 1; i <= cnt; ++ i) {
s[i].tot = min(s[i].tot, ans);
while(s[i].tot) {
v[x - 1][y - 1] = s[i].co; ++ x; ++ y; -- s[i].tot;
if(x == ans + 1) x = 1, ++ lim, y = lim;
if(y == res + 1) y = 1;
if(lim == res + 1) break;
}
if(lim == res + 1) break;
}
for(int i = 0; i < ans; ++ i) {
for(int j = 0; j < res; ++ j) printf("%d ", v[i][j]);
puts("");
}
return 0;
}