首先考虑n^2的做法那就是
for(int i = 1; i <= 3; i++){
for(int j = 1; j <= 3 * n; j++){
for(int k = 1; k <= n; k++){
dp[i][j + k] += dp[i - 1][j];
我们从中可以看出 当 j = 1的时候他对所有 j + 1 ~ j + n 那么 j = 2 就对 j + 2 ~ j + n + 1是有贡献的
所以我们只需求一个前缀和 即可
#include<iostream>
using namespace std;
const int N = 3e6 + 10;
typedef long long ll;
ll dp[4][N],dp2[4][N];
int main(){
ll n,m;
cin >> n >> m;
dp[0][0] = 1;
for(int j = 1; j <= n; j++){
dp[1][j] = 1;
}
for(int j = 1; j <= 3 * n; j++){
dp2[1][j] = dp2[1][j - 1] + dp[1][j];
}
for(int i = 2; i <= 3; i++){
for(int j = 1; j <= 3 * n; j++){
dp[i][j] += dp2[i - 1][j - 1];
if(j >= n + 1){
dp[i][j] -= dp2[i - 1][j - n - 1];
}
}
for(int j = 1; j <= 3 * n; j++){
dp2[i][j] = dp2[i][j - 1] + dp[i][j];
}
}
ll x;
for(int i = 3; i <= 3 * n; i++){
if(m <= dp2[3][i]){
m -= dp2[3][i - 1];
x = i;
break;
}
}
for(int i = 1; i <= n; i++){
ll minn = max(1ll,x - i - n);
ll maxn = min(n,x - i - 1);
if(minn > maxn) continue;
if(m > maxn - minn + 1){
m -= maxn - minn + 1;
continue;
}
ll y = minn + m - 1;
ll z = x - i - y;
cout << i << " " << y << " " << z << endl;
return 0;
}
return 0;
}