题目链接
思路:如果三者之和不能被3整除,显然无解。
设三者平均数为m,,从大到小为x,y,z 。那么我们先将最大的数拿x-m出来,因为对于次数来说,顺序没有任何意义。
接下来再枚举分到z上的数有多少,设分到z上的为t。那么三者就成了m,y+x-m-t,m+t
三者构成等差数列,等价于0,c,2c。
我们设
f
c
f_c
fc为0 ,c,2c状态下的期望次数,那么显然
f
0
=
0
f_0=0
f0=0
2c要拿c个出来分给其他两个小的,分完之后,得到三个数又等价于一个新的状态,即得转移方程:
f
x
=
x
+
∑
i
=
0
x
(
i
x
)
2
x
f
i
f_x=x+\sum_{i=0}^{x} \frac{(^x_i)}{2^x}f_i
fx=x+∑i=0x2x(ix)fi,因为组合数
(
i
x
)
=
(
x
−
i
x
)
(^x_i)=(^x_{x-i})
(ix)=(x−ix)所以后面写成
f
i
f_i
fi是 没关系的,代码中体现出来的更直观一点…
f x = x + ∑ i = 0 x − 1 ( i x ) 2 x f i + f x 2 x f_x=x+\sum_{i=0}^{x-1}\frac{(^x_i)}{2^x}f_i+\frac{f_x}{2^x} fx=x+∑i=0x−12x(ix)fi+2xfx
2 x − 1 2 x f x = x + ∑ i = 0 x − 1 ( i x ) 2 x f i \frac{2^x-1}{2^x}f_x=x+\sum_{i=0}^{x-1}\frac{(^x_i)}{2^x}f_i 2x2x−1fx=x+∑i=0x−12x(ix)fi
2 x − 1 2 x f x = x + ∑ i = 0 x − 1 f i 2 x − 1 2 x \frac{2^x-1}{2^x}f_x=x+\sum_{i=0}^{x-1}f_i\frac{2^x-1}{2^x} 2x2x−1fx=x+∑i=0x−1fi2x2x−1
f x = x 2 x 2 x − 1 + ∑ i = 0 x − 1 ( i x ) f i 2 x − 1 f_x=\frac{x2^x}{2^x-1}+\sum_{i=0}^{x-1}\frac{(^x_i)f_i}{2^x-1} fx=2x−1x2x+∑i=0x−12x−1(ix)fi
打表发现,
f
x
=
2
x
f_x=2x
fx=2x
那么就做完了…
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e6 + 10;
#define fi first
#define se second
#define pb push_back
#define wzh(x) cerr<<#x<<'='<<x<<endl;
int x, y, z;
int t;
LL fac[N], inv[N];
const LL mod = 998244353;
LL pm(LL x, LL y) {
LL z = 1;
while (y) {
if (y & 1)z = z * x % mod;
x = x * x % mod;
y >>= 1;
}
return z;
}
LL s[N], f[N];
LL fa[N], in[N];
void P() {
fa[0] = 1;
for (int i = 1; i < N; i++) {
fa[i] = fa[i - 1] * i % mod;
}
in[N - 1] = pm(fa[N - 1], mod - 2);
for (int i = N - 2; i >= 0; i--)in[i] = in[i + 1] * (i + 1) % mod;
}
LL get(int x, int y) {
return fa[x] * in[y] % mod * in[x - y] % mod;
}
int main() {
ios::sync_with_stdio(false);
fac[0] = 1;
for (int i = 1; i < N; i++) {
fac[i] = fac[i - 1] * 2 % mod;
inv[i] = pm(fac[i] - 1, mod - 2);
}
P();
for (cin >> t; t; t--) {
cin >> x >> y >> z;
if ((x + y + z) % 3) {
cout << -1 << '\n';
} else {
int m = (x + y + z) / 3;
if (y < z)swap(y, z);
if (x < y)swap(x, y);
if (y < z)swap(y, z);
LL ans = 0;
for (int i = 0; i <= (x - m); i++) {
int now = min({z + i, m, y + x - m + i});
vector<int>v;
v.pb(z + i - now);
v.pb(m - now);
v.pb(y + x - m + i - now);
sort(v.begin(), v.end());
ans = ans + get(x - m, i) % mod * 2 * v[1] % mod;
ans %= mod;
}
cout << ans*pm(fac[x - m], mod - 2) % mod + (x - m) % mod << '\n';
}
}
return 0;
}