【前言】
比赛的时候1013写了个假做法,然后数学题没做出来,罚时GG
rk65,校4/9
1001. Miserable Faith
【题意】
给一颗树,每个节点有一个颜色,初始 c i = i c_i=i ci=i,定义边权为:链接两点颜色相同为1,否则为0
- 点到根的颜色换成一个新颜色
- 询问两点的距离
- 询问子树所有点到它的距离和
- 询问全局每个节点最多往上能走多少条边权为0的边
【思路】
LCT,不妨让实边边权为1,虚边边权为0,然后考虑贡献。
操作1实际上是在模拟LCT的 a c c e s s access access操作,然后虚实边切换的时候需要用线段树区间修改,同时维护操作4的答案。
操作2相当于单点询问,操作3就是询问子树和然后扣掉单点 u u u的答案。
然而这种LCT的题已经不太写的来了所以比赛没写233.
复杂度 O ( n log 2 n ) O(n\log ^2n) O(nlog2n)
【参考代码】
std
#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define mid (l + r >> 1)
#define lc (o << 1)
#define rc (o << 1 | 1)
typedef pair<int, int> pii;
typedef long long ll;
typedef double db;
inline ll getint()
{
char c = getchar();
ll ret = 0, f = 1;
while (c < '0' || c > '9')
{
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
ret = ret * 10 + c - '0', c = getchar();
return ret * f;
}
const int maxn = 1e5 + 10;
const int segn = maxn << 2;
vector<int> E[maxn];
int n, m;
int Fa[maxn], top[maxn], son[maxn], Siz[maxn], dep[maxn], dfn[maxn], tme;
int rt[maxn], ch[maxn][2], pfa[maxn], fa[maxn], siz[maxn];
ll sum[segn], at[segn], Ans;
inline void up(int o)
{
sum[o] = sum[lc] + sum[rc];
}
inline void addtag(int o, int l, int r, ll x)
{
at[o] += x; sum[o] += (r - l + 1) * x;
}
inline void down(int o, int l, int r)
{
if (at[o])
{
addtag(lc, l, mid, at[o]);
addtag(rc, mid + 1, r, at[o]);
at[o] = 0;
}
}
inline void modify(int o, int l, int r, int al, int ar, ll x)
{
if (al > ar) return;
if (al <= l && r <= ar) return addtag(o, l, r, x), void();
down(o, l, r);
if (al <= mid) modify(lc, l, mid, al, ar, x);
if (mid < ar) modify(rc, mid + 1, r, al, ar, x);
up(o);
}
inline ll query(int o, int l, int r, int al, int ar)
{
if (al > ar) return 0;
if (al <= l && r <= ar) return sum[o];
down(o, l, r);
ll ret = 0;
if (al <= mid) ret += query(lc, l, mid, al, ar);
if (mid < ar) ret += query(rc, mid + 1, r, al, ar);
return ret;
}
inline void clear(int o, int l, int r)
{
if (l > r) return;
sum[o] = at[o] = 0;
if (l == r) return;
clear(lc, l, mid);
clear(rc, mid + 1, r);
}
inline void maintain(int o)
{
siz[o] = siz[ch[o][0]] + siz[ch[o][1]] + 1;
rt[o] = ch[o][0] ? rt[ch[o][0]] : o;
}
inline void rotate(int x)
{
int o = fa[x], y = fa[o];
pfa[x] = pfa[o]; pfa[o] = 0;
int d = ch[o][1] == x ? 0 : 1;
ch[o][d ^ 1] = ch[x][d]; maintain(o);
if (ch[x][d]) fa[ch[x][d]] = o;
ch[x][d] = o; maintain(x);
fa[o] = x; fa[x] = y;
if (y) ch[y][ch[y][1] == o] = x, maintain(y);
}
inline void splay(int x)
{
for (int y = fa[x]; y; rotate(x), y = fa[x])
if (fa[y]) rotate((ch[y][1] == x) ^ (ch[fa[y]][1] == y) ? x : y);
}
inline void access(int x)
{
for (int u = x, v = 0; u; v = u, u = pfa[u])
{
splay(u);
if (ch[u][1])
{
int p = ch[u][1];
modify(1, 1, n, dfn[rt[p]], dfn[rt[p]] + Siz[rt[p]] - 1, 1);
Ans -= 1ll * siz[p] * (siz[u] - siz[p]);
fa[p] = 0, pfa[p] = u;
}
ch[u][1] = v; maintain(u);
if (v)
{
Ans += 1ll * siz[v] * (siz[u] - siz[v]);
modify(1, 1, n, dfn[rt[v]], dfn[rt[v]] + Siz[rt[v]] - 1, -1);
fa[v] = u, pfa[v] = 0;
}
}
}
inline void dfs(int u)
{
dfn[u] = ++tme; son[u] = 0;
dep[u] = dep[Fa[u]] + 1; Siz[u] = 1;
for (auto v : E[u])
{
if (v == Fa[u]) continue;
Fa[v] = u; dfs(v); Siz[u] += Siz[v];
if (Siz[v] > Siz[son[u]]) son[u] = v;
}
}
inline void cut(int u)
{
if (son[u]) top[son[u]] = top[u], cut(son[u]);
for (auto v : E[u])
if (v != Fa[u] && v != son[u]) top[v] = v, cut(v);
}
inline int getlca(int u, int v)
{
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]]) swap(u, v);
u = Fa[top[u]];
}
return dep[u] > dep[v] ? v : u;
}
inline void solve()
{
Ans = tme = 0; clear(1, 1, n);
rep(i, 1, n) E[i].clear(), ch[i][0] = ch[i][1] = 0, fa[i] = 0;
n = getint(); m = getint();
rep(i, 1, n - 1)
{
int u = getint(), v = getint();
E[u].pb(v); E[v].pb(u);
}
dfs(1); top[1] = 1; cut(1);
rep(i, 2, n) pfa[i] = Fa[i];
rep(i, 1, n) siz[i] = 1, rt[i] = i;
rep(i, 2, n)
modify(1, 1, n, dfn[i], dfn[i] + Siz[i] - 1, 1);
while (m--)
{
int typ = getint();
if (typ == 1)
{
access(getint());
int tmp = getint();
}
if (typ == 2)
{
int u = getint(), v = getint(), lca = getlca(u, v);
ll ans = query(1, 1, n, dfn[u], dfn[u]) + query(1, 1, n, dfn[v], dfn[v]);
ans -= query(1, 1, n, dfn[lca], dfn[lca]) << 1;
printf("%lld\n", ans);
}
if (typ == 3)
{
int u = getint();
ll ans = query(1, 1, n, dfn[u], dfn[u] + Siz[u] - 1);
ans -= query(1, 1, n, dfn[u], dfn[u]) * Siz[u];
printf("%lld\n", ans);
}
if (typ == 4) printf("%lld\n", Ans);
}
}
int main()
{
int t = getint();
while (t--) solve();
return 0;
}
1002. String Mod
【题意】
对于长度为 n n n的字符范围为 ′ a ′ ∼ ′ a ′ + k − 1 'a'\sim 'a'+k-1 ′a′∼′a′+k−1的字符串,有 k L k^L kL种。
现在对于每个数对 ( i , j ) ( 0 ≤ i , j ≤ n − 1 ) (i,j)(0\leq i,j\leq n-1) (i,j)(0≤i,j≤n−1),求出有 p p p个字母 a a a和 q q q个字母 b b b的字符串个数,其中 p ≡ i ( mod n ) , q ≡ j ( mod n ) p\equiv i(\text{ mod }n),q\equiv j(\text{ mod }n) p≡i( mod n),q≡j( mod n),输出矩阵 A [ i ] [ j ] A[i][j] A[i][j]表示答案。
k ≤ 26 , L ≤ 1 0 18 , n ≤ 500 k\leq 26,L\leq 10^{18},n\leq 500 k≤26,L≤1018,n≤500,保证 n n n是模数 P P P的约数。
【思路】
不会。
可以通过枚举字符 a a a, b b b在字符串出现的次数 x x x, y y y得到:
a n s [ i ] [ j ] = ∑ x = 0 L ∑ y = 0 L − x [ n ∣ x − i ] [ n ∣ y − j ] ( L x ) ( L − x y ) ( k − 2 ) L − x y ans[i][j] = \sum_{x=0}^L\sum_{y=0}^{L-x}[n\mid x-i][n\mid y-j]\tbinom{L}{x}\tbinom{L - x}{y}(k - 2)^{L-xy} ans[i][j]=x=0∑Ly=0∑L−x[n∣x−i][n∣y−j](xL)(yL−x)(k−2)L−xy
= ∑ x = 0 L ∑ y = 0 L − x 1 n ∑ p = 0 n − 1 w n p × ( x − i ) 1 n ∑ q = 0 n − 1 w n q × ( y − j ) ( L x ) ( L − x y ) ( k − 2 ) L − x y \quad \quad\quad \quad =\sum_{x=0}^L\sum_{y=0}^{L-x}\frac{1}{n}\sum_{p = 0}^{n - 1}w_n^{p\times(x-i)}\frac{1}{n}\sum_{q=0}^{n - 1}w_n^{q\times(y-j)}\tbinom{L}{x}\tbinom{L - x}{y}(k - 2)^{L-xy} =x=0∑Ly=0∑L−xn1p=0∑n−1wnp×(x−i)n1q=0∑n−1wnq×(y−j)(xL)(yL−x)(k−2)L−xy
由单位根反演可得
= 1 n 2 ∑ x = 0 L ∑ y = 0 L − x ∑ p = 0 n − 1 w n p × ( x − i ) ∑ q = 0 n − 1 w n q × ( y − j ) ( L x ) ( L − x y ) ( k − 2 ) L − x y \quad \quad\quad \quad =\frac{1}{n^2}\sum_{x=0}^L\sum_{y=0}^{L-x}\sum_{p = 0}^{n - 1}w_n^{p\times(x-i)}\sum_{q=0}^{n - 1}w_n^{q\times(y-j)}\tbinom{L}{x}\tbinom{L - x}{y}(k - 2)^{L-xy} =n21x=0∑Ly=0∑L−xp=0∑n−1wnp×(x−i)q=0∑n−1wnq×(y−j)(xL)(yL−x)(k−2)L−xy
= 1 n 2 ∑ p = 0 n − 1 ∑ q = 0 n − 1 ∑ x = 0 L ∑ y = 0 L − x w n p × ( x − i ) w n q × ( y − j ) ( L x ) ( L − x y ) ( k − 2 ) L − x y \quad \quad\quad \quad=\frac{1}{n^2}\sum_{p=0}^{n - 1}\sum_{q = 0}^{n - 1}\sum_{x=0}^L\sum_{y=0}^{L-x}w_n^{p\times(x-i)}w_n^{q\times(y-j)}\tbinom{L}{x}\tbinom{L - x}{y}(k - 2)^{L-xy} =n21p=0∑n−1q=0∑n−1x=0∑Ly=0∑L−xwnp×(x−i)wnq×(y−j)(xL)(yL−x)(k−2)L−xy
= 1 n 2 ∑ p = 0 n − 1 ∑ q = 0 n − 1 ∑ x = 0 L ∑ y = 0 L − x w n p x w n q y ( L x ) ( L − x y ) ( k − 2 ) L − x y w n − p i w n − q j \quad \quad\quad \quad =\frac{1}{n^2}\sum_{p=0}^{n - 1}\sum_{q = 0}^{n - 1}\sum_{x=0}^L\sum_{y=0}^{L-x}w_n^{px}w_n^{qy}\tbinom{L}{x}\tbinom{L - x}{y}(k - 2)^{L-xy}w_n^{-pi}w_n^{-qj} =n21p=0∑n−1q=0∑n−1x=0∑Ly=0∑L−xwnpxwnqy(xL)(yL−x)(k−2)L−xywn−piwn−qj
= 1 n 2 ∑ p = 0 n − 1 ∑ q = 0 n − 1 ( w n p + w n q + k − 2 ) L w n − p i w n − q j \quad \quad\quad \quad =\frac{1}{n^2}\sum_{p=0}^{n - 1}\sum_{q = 0}^{n - 1}(w_n^{p} + w_n^{q} + k - 2) ^{ L}w_n^{-pi}w_n^{-qj} =n21p=0∑n−1q=0∑n−1(wnp+wnq+k−2)Lwn−piwn−qj
我们令
A [ i ] [ p ] = w n − i p A[i][p] = w_n^{-ip} A[i][p]=wn−ip B [ p ] [ q ] = 1 n 2 ( w n p + w n q + k − 2 ) L B[p][q] = \frac{1}{n ^2} (w_n^{p} + w_n^{q} + k - 2) ^{ L} B[p][q]=n21(wnp+wnq+k−2)L C [ q ] [ j ] = w n − q j C[q][j] = w_n^{-qj} C[q][j]=wn−qj
则 a n s = A × B × C ans = A \times B \times C ans=A×B×C
总时间复杂度为 O ( n 3 + n 2 l o g L ) O(n ^ 3 + n ^ 2 logL ) O(n3+n2logL)
【参考代码】
std
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
typedef pair<ll,ll> pii;
#define rep(i,x,y) for(int i=x;i<y;i++)
#define rept(i,x,y) for(int i=x;i<=y;i++)
#define all(x) x.begin(),x.end()
#define fi first
#define se second
#define mes(a,b) memset(a,b,sizeof a)
#define mp make_pair
#define pb push_back
#define dd(x) cout<<#x<<"="<<x<<" "
#define de(x) cout<<#x<<"="<<x<<"\n"
const int inf=0x3f3f3f3f;
const int maxn=500;
const int mod=1e9+9;
ll x[maxn],y[maxn];
int p[maxn],q[maxn];
class matrix
{
public:
ll arrcy[maxn][maxn];//?????§³???¡À?0??row-1,0??column-1
int row,column;//row??§µ?column???
matrix()
{
memset(arrcy,0,sizeof arrcy);
column=row=0;
}
friend matrix operator *(matrix s1,matrix s2)
{
int i,j;
matrix s3;
for (i=0;i<s1.row;i++)
{
for (j=0;j<s2.column;j++)
{
for (int k=0;k<s1.column;k++)
{
s3.arrcy[i][j]+=s1.arrcy[i][k]*s2.arrcy[k][j];
s3.arrcy[i][j]%=mod;
}
}
}
s3.row=s1.row;
s3.column=s2.column;
return s3;
}
void show()
{
for(int i=0;i<row;i++)
{
for (int j=0;j<column;j++)
cout<<arrcy[i][j]<<" ";
cout<<"\n";
}
}
}mat1,mat2,mat3;
ll qpow(ll a,ll b)
{
ll ans=1;
for(;b;b>>=1,a=a*a%mod)
if(b&1)
ans=ans*a%mod;
return ans;
}
/*
matrix quick_pow(matrix s1,long long n)
{
matrix mul=s1,ans;
ans.row=ans.column=s1.row;
memset(ans.arrcy,0,sizeof ans.arrcy);
for(int i=0;i<ans.row;i++)
ans.arrcy[i][i]=1;
while(n)
{
if(n&1) ans=ans*mul;
mul=mul*mul;
n/=2;
}
return ans;
}
*/
int g=13;
void solve()
{
int k,n;
ll l;
cin>>k>>l>>n;
k-=2;
x[0]=1;
g=qpow(13,(mod-1)/n);
rept(i,1,n) x[i]=x[i-1]*g%mod;
mat1.row=mat1.column=n;
rep(i,0,n)
rep(j,0,n)
mat1.arrcy[i][j]=qpow(g,i*j);
// mat1.show();
mat3=mat1;
rep(i,0,n) rep(j,i+1,n) swap(mat3.arrcy[i][j],mat3.arrcy[j][i]);
mat2.row=mat2.column=n;
rep(i,0,n) rep(j,0,n) mat2.arrcy[i][j]=qpow(x[i]+x[j]+k,l);
//mat1.show();
//mat2.show();
//mat3.show();
mat1=mat1*mat2;
mat1=mat1*mat3;
p[0]=q[0]=0;
rep(i,1,n) p[i]=q[i]=n-i;
ll nn=qpow(n*n,mod-2);
rep(i,0,n)
{
rep(j,0,n)
{
cout<<mat1.arrcy[p[i]][q[j]]*nn%mod;
if(j==n-1) cout<<"\n";
else cout<<" ";
}
}
return ;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
int T;
cin>>T;
while(T--)
solve();
return 0;
}
1003. VC Is All You Need
【题意】
n n n维空间中,求 m m m的最大值,使得你可以找到 m m m个点(自己给定坐标),满足:
无论对这 m m m个点如何二染色,也就是对于 2 m 2^m 2m种染色方案中的每一种,都总存在一个 n − 1 n-1 n−1维超平面,严格分开这两种颜色的点。
【结论】
m max = n + 1 m_\text{max}=n+1 mmax=n+1
【证明思路】
显然有单调性,可以分两步证明结论:
- 证明
m
=
n
+
1
m=n+1
m=n+1可行
- 即:可构造一组点的坐标( x 0 , x 1 , x 2 , ⋯ , x n ∈ R n x_0,x_1,x_2,⋯,x_n\in R^n x0,x1,x2,⋯,xn∈Rn),证明其任意一组染色方案,都可以找到一个线性超平面将其染成这个方案
- 这里染成这个方案指的是,超平面( > 0 >0 >0)一侧是黑,一侧( < 0 <0 <0)是白
- 证明
m
≥
n
+
2
m\ge n+2
m≥n+2无解
- 即:证明 n + 2 n+2 n+2个点无论如何放置(设计坐标),都总存在一种染色方案,使得没有一个线性超平面可以将其染成这个方案
1004. Another String
给定一个字符串, ∀ i ∈ [ 1 , n − 1 ] \forall i\in[1,n-1] ∀i∈[1,n−1],将字符串分为 S [ 1.. i ] , S [ i + 1.. n ] S[1..i],S[i+1..n] S[1..i],S[i+1..n],求两个部分失配至多 k k k次能匹配的子串个数。
n ≤ 3000 n\leq 3000 n≤3000
【思路】
首先我们考虑这样一个问题:如果我们知道了 S [ x . . n ] S[x..n] S[x..n]和 S [ y . . n ] S[y..n] S[y..n]失配至多 k k k次能匹配到多长,那么原问题如何解决?
不难发现,我们会给分割点在 [ x , y − 1 ] [x,y-1] [x,y−1]的所有答案都加上一个数,而且这段加的数是一段公差为 1 1 1的等差数列和一段相同的数组成,这个问题可以通过二阶差分来处理。
考虑如何解决我们提出的问题:
设
f
[
i
]
[
j
]
f[i][j]
f[i][j],第一个串从
i
i
i开始,第二个从
j
j
j开始,在原串什么位置(相对于
i
i
i)失配
k
+
1
k+1
k+1次,
l
a
s
[
i
]
[
j
]
las[i][j]
las[i][j]表示从
i
,
j
i,j
i,j开始第一个失配的位置(不包括这个位置)。
s
[
i
−
1
]
=
=
s
[
j
−
1
]
⇒
l
a
s
[
i
]
[
j
]
=
l
a
s
[
i
−
1
]
[
j
−
1
]
s
[
i
−
1
]
≠
s
[
j
−
1
]
⇒
l
a
s
[
i
]
[
j
]
=
i
−
1
\begin{aligned} &s[i-1]==s[j-1]\Rightarrow las[i][j]=las[i-1][j-1]\\ &s[i-1]\neq s[j-1]\Rightarrow las[i][j]=i-1 \end{aligned}
s[i−1]==s[j−1]⇒las[i][j]=las[i−1][j−1]s[i−1]=s[j−1]⇒las[i][j]=i−1
f
f
f我们从后往前转移,如果失配没有
k
+
1
k+1
k+1次,那么:
f
[
i
]
[
j
]
=
i
+
(
n
−
j
)
+
1
f[i][j]=i+(n-j)+1
f[i][j]=i+(n−j)+1
否则,先找到第一次失配
k
+
1
k+1
k+1次的位置,然后:
s
[
i
−
1
]
=
=
s
[
j
−
1
]
⇒
f
[
i
−
1
]
[
j
−
1
]
=
f
[
i
]
[
j
]
s
[
i
−
1
]
≠
s
[
j
−
1
]
⇒
f
[
i
−
1
]
[
j
−
1
]
=
l
a
s
[
f
[
i
]
[
j
]
]
[
f
[
i
]
[
j
]
−
i
+
j
]
\begin{aligned} &s[i-1]==s[j-1]\Rightarrow f[i-1][j-1]=f[i][j]\\ &s[i-1]\neq s[j-1]\Rightarrow f[i-1][j-1]=las[f[i][j]][f[i][j]-i+j] \end{aligned}
s[i−1]==s[j−1]⇒f[i−1][j−1]=f[i][j]s[i−1]=s[j−1]⇒f[i−1][j−1]=las[f[i][j]][f[i][j]−i+j]
注意一下边界条件即可。
复杂度 O ( n 2 ) O(n^2) O(n2)
1005. Random Walk 2
【题意】
给定一个邻接矩阵 W W W表示走每条边的权重(可以求出概率矩阵 p p p),如果在一步中留在了原地,那就不会再动,对于每个 ( u , v ) (u,v) (u,v)求 u u u能走到 v v v的概率。
n ≤ 300 n\leq 300 n≤300
【思路】
考虑枚举终点
t
t
t,设
x
i
x_i
xi表示在
i
i
i停下的概率我们有:
x
i
=
{
∑
j
=
1
n
x
j
p
j
i
x
≠
t
p
i
i
+
∑
j
=
1
n
x
j
p
j
i
x
=
t
x_i=\left\{\begin{matrix} \sum_{j=1}^n x_jp_{ji}& x\neq t\\ p_{ii}+\sum_{j=1}^n x_jp_{ji}& x= t \end{matrix}\right.
xi={∑j=1nxjpjipii+∑j=1nxjpjix=tx=t
对于一个点来说我们直接高斯消元即可。但这样总的复杂度就是
O
(
n
4
)
O(n^4)
O(n4)的了。
p p p实际上是个经典的矩阵转移,观察到实际上每次更换终点 t t t,只有常数项不一样,而前面消元的过程都是一样的。于是我们可以对矩阵求逆,然后每次只需要做一个矩阵乘法就行了。
这样复杂度就是 O ( n 3 ) O(n^3) O(n3)了。
题解的思路大概是观察到原矩阵可以拆成一个对角矩阵和一个对角为0的矩阵的和,然后利用矩阵幂级数做:
记 λ i = W i , i , d i = ∑ j = 1 n W i , j \displaystyle \lambda_i=W_{i,i},d_i=\sum_{j=1}^nW_{i,j} λi=Wi,i,di=j=1∑nWi,j
则有转移方程
A i , i = λ i λ i + d i ∗ 1 + ∑ j = 1 , j ≠ i n W i , j λ i + d i A j , i \displaystyle A_{i,i}=\frac{\lambda_i}{\lambda_i+d_{i}}*1+\sum_{j=1,j\not=i}^n\frac{W_{i,j}}{\lambda_i+d_{i}}A_{j,i} Ai,i=λi+diλi∗1+j=1,j=i∑nλi+diWi,jAj,i
A i , j = ∑ k = 1 , k ≠ i n W i , k λ i + d i A k , j ( i ≠ j ) \displaystyle A_{i,j}=\sum_{k=1,k\not=i}^n\frac{W_{i,k}}{\lambda_i+d_{i}}A_{k,j}(i\not=j) Ai,j=k=1,k=i∑nλi+diWi,kAk,j(i=j)
构造无自环矩阵 W W W,自环的对角矩阵 Λ \Lambda Λ, W W W度数矩阵 D D D
( I − ( Λ + D ) − 1 W ) A = ( Λ + D ) − 1 Λ (I-(\Lambda +D)^{-1}W)A=(\Lambda+D)^{-1}\Lambda (I−(Λ+D)−1W)A=(Λ+D)−1Λ
第一步列转移方程
第二步写出矩阵形式
第三步解矩阵方程,套板子
算出 A − 1 A^{-1} A−1矩阵,套个求逆板子即可
【参考代码】
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(a),i##ss=(b);i<=i##ss;i++)
#define dwn(i,a,b) for(int i=(a),i##ss=(b);i>=i##ss;i--)
#define deb(x) cerr<<(#x)<<":"<<(x)<<'\n'
#define pb push_back
#define mkp make_pair
#define fi first
#define se second
#define hvie '\n'
using namespace std;
typedef pair<int,int> pii;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
ll yh(){
ll x;scanf("%lld",&x);return x;
}
const int maxn=1e5+5,N=1e5,mod=998244353;
ll w[555][555],p[555][555],t[555][555];
ll a[555][1555],ans[555][555],v[555];
// ll inv[maxn*301];
int n,m;
ll ksm(ll x,ll p,ll a=1){
for(;p;p>>=1,x=x*x%mod) if(p&1) a=a*x%mod;
return a;
}
void INV(){
rep(i,1,n)rep(j,1,n) a[i][j+n]=(i==j);
// puts("??");
// rep(i,1,n){
// rep(j,1,m){
// cout<<a[i][j]<<" \n"[j==m];
// }
// }
// puts("??");
rep(i,1,n){
rep(j,i,n){
if(a[j][i]){
rep(k,1,m) swap(a[j][k],a[i][k]);
break;
}
}
if(!a[i][i]) {assert(0);break;}
ll r=ksm(a[i][i],mod-2);
rep(j,i,m) a[i][j]=a[i][j]*r%mod;
rep(j,1,n){
if(j!=i){
ll tmp=a[j][i];
rep(k,i,m){
a[j][k]=(a[j][k]-tmp*a[i][k]%mod+mod)%mod;
}
}
}
// puts("-------");
// rep(k,1,n)
// rep(j,1,m){
// cout<<a[k][j]<<" \n"[j==m];
// }
}
}
int main(){
// cout<<sizeof(inv)/1024./1024;
// cout<<(mod-2*ksm(3,mod-2)%mod)<<hvie;
// inv[1]=1;
// rep(i,2,N*301) inv[i]=(mod-(mod/i)*inv[mod%i]%mod)%mod;
dwn(_,yh(),1){
scanf("%d",&n);
m=n*2;
rep(i,1,n){
ll sum=0;
rep(j,1,n)
sum+=(p[i][j]=yh());
rep(j,1,n) p[i][j]=ksm(sum,mod-2,p[i][j]);
}
rep(i,1,n){
rep(j,1,n){
if(i!=j) a[i][j]=p[i][j];
else a[i][j]=mod-1;
}
}
INV();
// puts("??");
// rep(i,1,n){
// rep(j,1,m){
// cout<<a[i][j]<<" \n"[j==m];
// }
// }
// puts("??");
rep(t,1,n){
rep(i,1,n) v[i]=(i==t)?(mod-p[i][i])%mod:0;
rep(i,1,n){
ll x=0;
rep(j,1,n){
x=(x+a[i][j+n]*v[j]%mod)%mod;
}
ans[i][t]=x;
}
}
rep(i,1,n){
rep(j,1,n){
cout<<(ans[i][j]%mod+mod)%mod<<" \n"[j==n];
}
}
}
return 0;
}
1006. Cut Tree
签到模拟题,略
1007. Banzhuan
【题意】
要在一个 n × n × n n\times n\times n n×n×n的空间中放 1 × 1 × 1 1\times 1\times 1 1×1×1的小方块,放置在 ( x , y , z ) (x,y,z) (x,y,z)的权值是 x × y 2 × z x\times y^2\times z x×y2×z,可以虚空放,但是会下落。要最后放成三视图都是正方形,问最小和最大权值。
n ≤ 1 0 18 n\leq 10^{18} n≤1018
【思路】
最大的很容易求,就全部在 z = n z=n z=n的位置放就行。
最小的也很容易求,不难发现最小的权值就是先放最下面一层,然后竖着的话放 x = 1 x=1 x=1和 y = 1 y=1 y=1整个面就行,注意第 1 1 1行 1 1 1列是不用放的。
【参考代码】
#include<bits/stdc++.h>
#define pb push_back
#define mkp make_pair
#define fi first
#define se second
#define ri register int
#define int long long
using namespace std;
typedef double db;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<int,ll> pil;
const int N=1e5+10,M=3e5+10,mod=1e9+7;
ll inv2,inv6;
ll mul(ll x,ll y){return 1ll*x*y%mod;}
ll qpow(ll x,ll y)
{
ll ret=1;
for(;y;y>>=1,x=mul(x,x))
if(y&1) ret=mul(ret,x);
return ret;
}
ll calc1(ll x)
{
x%=mod;
return (x+1)*x%mod*inv2%mod;
}
ll calc2(ll x)
{
x%=mod;
return x*(x+1)%mod*(2ll*x%mod+1ll)%mod*inv6%mod;
}
signed main()
{
inv2=qpow(2,mod-2);inv6=qpow(6,mod-2);
int T;scanf("%lld",&T);
while(T--)
{
ll n;
scanf("%lld",&n);n%=mod;
//ll ans1=(calc1(n)*calc2(n)%mod*2ll%mod+calc1(n)*calc1(n)%mod)%mod-(calc1(n)+calc2(n)+calc2(n))%mod+1ll;
//ll ans1=(2ll*calc1(n)%mod*calc2(n)%mod+calc1(n)*calc1(n)%mod-2ll*calc1(n)%mod-calc2(n)%mod+1)%mod;
//ll ans2=calc1(n)*calc1(n)%mod*calc2(n)%mod;
ll ans1=((2ll*calc1(n)%mod*calc2(n)%mod+calc1(n)*calc1(n)%mod-2ll*calc1(n)%mod-calc2(n)%mod+1)%mod-calc1(n)+1)%mod;
ll ans2=n*n%mod*calc2(n)%mod*calc1(n)%mod;
ans1=(ans1%mod+mod)%mod;
ans2=(ans2%mod+mod)%mod;
printf("%lld\n%lld\n",ans1,ans2);
}
return 0;
}
1008. Supermarket
【题意】
这个题意十分诡异。
n n n个物品, m m m个买物品的方式, a i a_i ai用一个 n n n位二进制描述
求
∑
S
∑
T
P
(
T
∣
S
)
\sum_S\sum_T P(T|S)
S∑T∑P(T∣S)
也就是我们先确定一个
S
S
S和一个
T
T
T,选到
S
S
S的所有
a
i
a_i
ai里面再随便选一个包含
T
T
T的概率和。
n ≤ 20 , m ≤ 2 × 1 0 5 n\leq 20,m\leq 2\times10^5 n≤20,m≤2×105
【思路】
∑
S
∑
T
∑
S
∈
a
i
,
T
∈
a
i
1
∑
S
∈
a
i
1
\sum_S\sum_T\frac {\sum_{S\in a_i,T\in a_i}1}{\sum_{S\in a_i} 1}
S∑T∑∑S∈ai1∑S∈ai,T∈ai1
也就是
∑
S
∑
S
∈
a
i
∑
T
∈
a
i
1
∑
S
∈
a
i
1
\sum_S\frac {\sum_{S\in a_i}\sum_{T\in a_i}1}{\sum_{S\in a_i} 1}
S∑∑S∈ai1∑S∈ai∑T∈ai1
考虑设
g
(
S
)
=
∑
S
∈
a
i
1
g(S)=\sum_{S\in a_i} 1
g(S)=∑S∈ai1,实际上就是对于每个
a
i
a_i
ai,找到它的所有子集给它加上1。
对分母,我们发现实际上
T
T
T的选取和
S
S
S是无关的,而和
a
i
a_i
ai有关,因此对于每个
a
i
a_i
ai我们可以求出来
f
(
S
)
=
∑
T
∈
a
i
=
2
popcount
(
a
i
)
f(S)=\sum_{T\in a_i}=2^{\text{popcount}(a_i)}
f(S)=T∈ai∑=2popcount(ai)
其中
popcount
(
a
i
)
\text{popcount}(a_i)
popcount(ai)表示
a
i
a_i
ai中1的个数。
那么分母对每个 S S S找到所有能贡献给它的 f f f,这个同样从反面考虑,对每个 a i a_i ai,我们找到它的所有子集加上 f ( a i ) f(a_i) f(ai)就行。
于是两个都是高维前缀和问题。
复杂度 O ( n 2 n ) O(n2^n) O(n2n)
1009. Array
【题意】
给定一个长度为 n n n的序列,问有多少个区间满足区间众数的出现次数严格大于长度的一半。
n ≤ 1 0 6 n\leq 10^6 n≤106
【思路】
考虑每个数成为众数贡献的区间个数。当前考虑 x x x,我们可以将 x x x看作+1,其他数字看作-1,那么问题就变成了区间和大于0的区间个数,前缀和一下可以变为:对于每个 r r r, s u m [ r ] > s u m [ l − 1 ] sum[r]>sum[l-1] sum[r]>sum[l−1]的 l l l的个数和。
用线段树暴力做这个问题,就是 O ( n log n ) O(n\log n) O(nlogn)每种数字,我们并不能接受,因此考虑我们仅在+1的地方统计答案。
由于每个 s u m [ r ] sum[r] sum[r]对应的可行的 s u m sum sum是一个前缀,即 [ − i n f , s u m [ r ] − 1 ] [-inf,sum[r]-1] [−inf,sum[r]−1],而连续一段-1,实际上是前缀和的前缀和。
这就是一个经典问题了,差分一次可以用线段树做到 O ( n log n ) O(n\log n) O(nlogn),但常数太大了,事实上可以通过差分两次做到树状数组的 O ( n log n ) O(n\log n) O(nlogn)就可以过了。
详细可以看[BJOI2018链上二次求和],比如https://zhuanlan.zhihu.com/p/35963206
【参考代码】
#include<bits/stdc++.h>
#define pb push_back
#define mkp make_pair
#define fi first
#define se second
#define ri register int
using namespace std;
typedef double db;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<int,ll> pil;
const int N=1e6+10;
int n,m,MX;
int a[N],fg[N];
vector<int>vec[N],num;
ll ans;
void init()
{
for(auto i:num) vec[i].clear(),fg[i]=0;
num.clear();
ans=0;MX=0;
}
/*struct Seg
{
#define ls (x<<1)
#define rs (x<<1|1)
struct node
{
ll sum[2],tag;
node(ll a=0,ll b=0,ll c=0)
{
sum[0]=a;sum[1]=b;tag=c;
}
}t[N<<3];
void pushdown(int x,int l,int r)
{
if(!t[x].tag) return;
int mid=(l+r)>>1;ll val=t[x].tag;
t[ls].tag+=val;t[rs].tag+=val;
t[ls].sum[0]+=1ll*val*(mid-l+1);
t[rs].sum[0]+=1ll*val*(r-mid);
t[ls].sum[1]+=1ll*val*(mid+l-m)*(mid-l+1)/2;
t[rs].sum[1]+=1ll*val*(mid+r-m+1)*(r-mid)/2;
t[x].tag=0;
}
void pushup(int x)
{
t[x].sum[0]=t[ls].sum[0]+t[rs].sum[0];
t[x].sum[1]=t[ls].sum[1]+t[rs].sum[1];
}
void update(int x,int l,int r,int L,int R,int k)
{
if(L>R) return;
if(L<=l && r<=R)
{
t[x].tag+=k;
t[x].sum[0]+=1ll*(r-l+1)*k;
t[x].sum[1]+=1ll*(l+r-m)*(r-l+1)/2*k;
return;
}
pushdown(x,l,r);
int mid=(l+r)>>1;
if(L<=mid) update(ls,l,mid,L,R,k);
if(R>mid) update(rs,mid+1,r,L,R,k);
pushup(x);
}
ll query(int x,int l,int r,int L,int R,int op)
{
if(L>R) return 0;
if(L<=l && r<=R) return t[x].sum[op];
pushdown(x,l,r);
int mid=(l+r)>>1;ll ret=0;
if(L<=mid) ret+=query(ls,l,mid,L,R,op);
if(R>mid) ret+=query(rs,mid+1,r,L,R,op);
return ret;
}
#undef ls
#undef rs
}tr;
*/
struct BIT
{
#define lowbit(x) (x&(-x))
ll s1[N<<1],s2[N<<1],s3[N<<1];
ll query(int x)
{
ll ret=0;
for(int t=x;x>0;x-=lowbit(x))
ret+=s1[x]*(t+2)*(t+1)-s2[x]*(2*t+3)+s3[x];
return ret/2;
}
void update(int x,int d)
{
for(int t=x;x<=m;x+=lowbit(x))
{
s1[x]+=d;
s2[x]+=1ll*d*t;
s3[x]+=1ll*d*t*t;
}
}
}bit;
int main()
{
int T;scanf("%d",&T);
while(T--)
{
init();
scanf("%d",&n);m=n<<1|1;
for(int i=1;i<=n;++i) scanf("%d",&a[i]),vec[a[i]].pb(i),fg[a[i]]=1,MX=max(MX,a[i]);
for(int i=0;i<=MX;++i) if(fg[i]) num.pb(i);
for(auto i:num) vec[i].pb(n+1);
for(auto i:num)
{
int len=vec[i].size();
if(len<=1) continue;
ll st=0;
for(int j=0;j<len;++j)
{
ll r=2*j-st+n+1,l=2*j-vec[i][j]+1+n+1;
ans+=bit.query(r-1)-bit.query(l-2);
bit.update(l,1);bit.update(r+1,-1);
st=vec[i][j];
}
st=0;
for(int j=0;j<len;++j)
{
ll r=2*j-st+n+1,l=2*j-vec[i][j]+1+n+1;
bit.update(l,-1);bit.update(r+1,1);
st=vec[i][j];
}
}
printf("%lld\n",ans);
}
return 0;
}
/*int main()
{
int T;scanf("%d",&T);
while(T--)
{
init();
scanf("%d",&n);m=n<<1;
for(int i=1;i<=n;++i) scanf("%d",&a[i]),vec[a[i]].pb(i),fg[a[i]]=1,MX=max(MX,a[i]);
for(int i=0;i<=MX;++i) if(fg[i]) num.pb(i);
for(auto i:num) vec[i].pb(n+1);
for(auto i:num)
{
int len=vec[i].size();
if(len<=1) continue;
ll st=0,ed;
for(int j=0;j<len;++j)
{
ed=2*j+1-vec[i][j];
ans+=(st-ed+1)*tr.query(1,1,m,1,ed+n-1,0)+st*tr.query(1,1,m,ed+n,st-1+n,0)-tr.query(1,1,m,ed+n,st+n-1,1);
tr.update(1,1,m,ed+n,st+n,1);
st=ed+1;
}
st=0;
for(int j=0;j<len;++j)
{
ed=2*j+1-vec[i][j];
tr.update(1,1,m,ed+n,st+n,-1);
st=ed+1;
}
}
printf("%lld\n",ans);
}
return 0;
}*/
1010. Guess Or Not 2
不会,略了。
【参考代码】
std
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define rep(i, a, b) for(int i=(a); i<(b); i++)
#define per(i, a, b) for(int i=(b)-1; i>=(a); i--)
#define sz(a) (int)a.size()
#define de(a) cout << #a << " = " << a << endl
#define dd(a) cout << #a << " = " << a << " "
#define all(a) a.begin(), a.end()
#define pw(x) (1ll<<(x))
#define endl "\n"
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef double db;
#define rep_it(it,x) for (__typeof((x).begin()) it=(x).begin(); it!=(x).end(); it++)
#define ____ puts("\n_______________\n\n")
#define debug(x) ____; cout<< #x << " => " << (x) << endl
#define debug_pair(x) cout<<"\n{ "<<(x).fir<<" , "<<(x).sec<<" }\n"
#define debug_arr(x,n) ____; cout<<#x<<":\n"; rep(i,0,n+1) cout<<#x<<"["<<(i)<<"] => "<<x[i]<<endl
#define debug_arr2(x,n,m) ____; cout<<#x<<":\n"; rep(i,0,n+1) rep(j,0,m+1) cout<<#x<<"["<<(i)<<"]["<<(j)<<"]= "<<x[i][j]<<((j==m)?"\n\n":" ")
#define debug_set(x) ____; cout<<#x<<": \n"; rep_it(it,x) cout<<(*it)<<" "; cout<<endl
#define debug_map(x) ____; cout<<#x<<": \n"; rep_it(it,x) debug_pair(*it)
void file_put() {
freopen("filename.in", "r", stdin);
freopen("filename.out", "w", stdout);
}
const int P=998244353,N=1e6+5;
int k,t,n,T; ll x[N],y[N],fac[N],ans,s;
ll Pow(ll x,ll k) {
assert(k>=0);
ll ret=1;
for (; k; k>>=1,x=x*x%P) if (k&1) ret=ret*x%P;
return ret;
}
ll Inv(ll x) {
assert(x>0);
return Pow(x,P-2);
}
void Mul(ll &x,ll y){
x*=y;
x%=P;
}
ll mul(ll x,ll y) {
return x*y%P;
}
void Add(ll &x,ll y) {
x+=y;
x%=P;
}
ll add(ll x,ll y) {
return (x+y)%P;
}
void work() {
ans=fac[k-1];
Mul(ans,Pow(t,k-1));
ll sum=0;
rep(i,1,k+1) {
Add(sum,mul(x[i],Inv(Pow(y[i],t))));
Mul(ans,mul(x[i],Inv(Pow(y[i],t+1))));
}
Mul(ans,Inv(Pow(sum,k)));
Add(ans,P);
}
int main() {
// file_put();
scanf("%d",&T);
fac[0]=1;
rep(i,1,N) fac[i]=fac[i-1]*i%P;
while (T--) {
scanf("%d%d",&k,&t),s=0;
rep(i,1,k+1) scanf("%lld",&x[i]);
rep(i,1,k+1) scanf("%lld",&y[i]),s+=y[i];
rep(i,1,k+1) Mul(y[i],Inv(s%P));
work();
printf("%lld\n",ans);
}
return 0;
}
1011. Jslgame
【题意】
n n n堆石子的NIM游戏,先手不能恰好拿 x x x个石子,后手不能恰好拿 y y y个石子,问谁赢。
【思路】
不会
【参考代码】
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
int a[N],t,n,x,y,cnt,sum;
int main(){
scanf("%d",&t);
while(t--){
cnt=sum=0;
scanf("%d%d%d",&n,&x,&y);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
if(a[i]>=min(x,y)) cnt++;
}
if(x==y){
for(int i=1;i<=n;i++){
sum^=a[i]/(2*x)*x+a[i]%x;
}
}
else if(!cnt){
for(int i=1;i<=n;i++){
sum^=a[i];
}
}
else if(x<y&&cnt>=2){
sum=0;
}
else if(x<y&&cnt==1){
for(int i=1;i<=n;i++){
if(a[i]<x) sum^=a[i];
}
int mx=*max_element(a+1,a+n+1);
if(mx>sum&&sum<x&&mx-sum!=x) sum=1;
else sum=0;
}
else if(x>y){
sum=1;
}
if(sum) printf("Jslj\n");
else printf("yygqPenguin\n");
}
}
1012. Yet Another Matrix Problem
【题目】
给定 n , m n,m n,m,再令 r = n m r=n^m r=nm。
∀ x ∈ [ 0 , m ] \forall x\in[0,m] ∀x∈[0,m],问有多少个矩阵 A n , r , B r , n A_{n,r},B_{r,n} An,r,Br,n,每个数字在 [ 0 , m ] [0,m] [0,m],满足 A × B A\times B A×B得到的矩阵元素和为 x x x
n , m ≤ 1 0 5 n,m\leq 10^5 n,m≤105,答案对998244353取模
【思路】
我不会,我队友会。
【参考代码】
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define rep(i, a, b) for(int i=(a); i<(b); i++)
#define per(i, a, b) for(int i=(b)-1; i>=(a); i--)
#define sz(a) (int)a.size()
#define de(a) cout << #a << " = " << a << endl
#define dd(a) cout << #a << " = " << a << " "
#define all(a) a.begin(), a.end()
#define pw(x) (1ll<<(x))
#define endl "\n"
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef double db;
#define rep_it(it,x) for (__typeof((x).begin()) it=(x).begin(); it!=(x).end(); it++)
#define ____ puts("\n_______________\n\n")
#define debug(x) ____; cout<< #x << " => " << (x) << endl
#define debug_pair(x) cout<<"\n{ "<<(x).fir<<" , "<<(x).sec<<" }\n"
#define debug_arr(x,n) ____; cout<<#x<<":\n"; rep(i,0,n+1) cout<<#x<<"["<<(i)<<"] => "<<x[i]<<endl
#define debug_arr2(x,n,m) ____; cout<<#x<<":\n"; rep(i,0,n+1) rep(j,0,m+1) cout<<#x<<"["<<(i)<<"]["<<(j)<<"]= "<<x[i][j]<<((j==m)?"\n\n":" ")
#define debug_set(x) ____; cout<<#x<<": \n"; rep_it(it,x) cout<<(*it)<<" "; cout<<endl
#define debug_map(x) ____; cout<<#x<<": \n"; rep_it(it,x) debug_pair(*it)
void file_put() {
freopen("filename.in", "r", stdin);
freopen("filename.out", "w", stdout);
}
const int P=998244353;
const int _N=300005; ll inv[_N<<2],fac[_N<<2],fac_inv[_N<<2];
inline ll add(ll x,ll y) { x+=y; return x%P; }
inline ll mul(ll x,ll y) { return (ll)x*y%P; }
inline ll Pow(ll x,ll k) { ll ans=1; for (;k;k>>=1,x=x*x%P) if (k&1) (ans*=x)%=P; return ans; }
inline void init_inv(int n) { inv[1]=1; rep(i,2,n+1) inv[i]=mul(P-P/i,inv[P%i]); }
inline void init_fac(int n) {
fac[0]=fac_inv[0]=1;
rep(i,1,n+1) fac[i]=mul(fac[i-1],i),fac_inv[i]=mul(fac_inv[i-1],inv[i]);
}
template <class V>
struct FT{
int n,nn; V w[2][_N<<2],rev[_N<<2],tmp;
inline int init_len(int _n) { for (n=1; n<=_n; n<<=1); return n; }
inline int Init(int _n) {
init_len(_n); if (n==nn) return n; nn=n;
V w0=Pow(3,(P-1)/n); w[0][0]=w[1][0]=1;
rep(i,1,n) w[0][i]=w[1][n-i]=mul(w[0][i-1],w0);
rep(i,0,n) rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1)); return n;
}
void FFT(V A[],int op){
rep(i,0,n) if (i<rev[i]) swap(A[i],A[rev[i]]);
for (int i=1; i<n; i<<=1)
for (int j=0,t=n/(i<<1); j<n; j+=i<<1)
for (int k=j,l=0; k<j+i; k++,l+=t) {
V x=A[k],y=mul(w[op][l],A[k+i]);
A[k]=add(x,y),A[k+i]=add(x-y,P);
}
if (op) { tmp=inv[n]; rep(i,0,n) A[i]=mul(A[i],tmp); }
}
};
template <class V>
struct Calculator{
FT<V> T; V X[_N<<2],Y[_N<<2],A[_N<<2],B[_N<<2],C[_N<<2];
inline void Fill(V a[],V b[],int n,int len) {
if (a!=b) memcpy(a,b,sizeof(V)*n); fill(a+n,a+len,0);
}
inline void Add(V a[],int n,V b[],int m,V c[],int t=1) {
n=max(n,m); rep(i,0,n) c[i]=add(a[i],t*b[i]);
}
inline void Dot_Mul(V a[],V b[],int len,V c[]) {
rep(i,0,len) c[i]=mul(a[i],b[i]);
}
inline void Dot_Mul(V a[],int len,V v,V c[]) {
rep(i,0,len) c[i]=mul(a[i],v);
}
inline void Mul(V a[],int n,V b[],int m,V c[]) {
int len=T.Init(n+m-1); Fill(X,a,n,len),Fill(Y,b,m,len);
T.FFT(X,0),T.FFT(Y,0),Dot_Mul(X,Y,len,c),T.FFT(c,1);
}
inline void Int(V a[],int n,V b[]) {
per(i,0,n) b[i+1]=mul(a[i],inv[i+1]); b[0]=0;
}
inline void Der(V a[],int n,V b[]) {
rep(i,1,n) b[i-1]=mul(a[i],i); b[n-1]=0;
}
inline void Inv(V a[],int n,V b[]) {
if (n==1) { b[0]=Pow(a[0],P-2),b[1]=0; return; }
Inv(a,(n+1)>>1,b); int len=T.Init(2*n-1);
Fill(X,a,n,len),Fill(b,b,n,len),T.FFT(X,0),T.FFT(b,0);
rep(i,0,len) b[i]=mul(b[i],2-mul(b[i],X[i]));
T.FFT(b,1),Fill(b,b,n,len);
}
inline void Log(V a[],int n,V b[]) {
static V A[_N<<2],B[_N<<2];
Der(a,n,A),Inv(a,n,B),Mul(A,n,B,n,b);
Int(b,n,b),Fill(b,b,n,T.n);
}
inline void Exp(V a[],int n,V b[]) {
if (n==1) { b[0]=exp(a[0]),b[1]=0; return; }
Exp(a,(n+1)>>1,A),Log(A,n,B),Add(a,n,B,n,B,-1);
(B[0]+=1)%=P,Mul(A,n,B,n,b),Fill(b,b,n,T.n);
}
inline void Power(V a[],int n,ll k,V b[]) {
Log(a,n,C),Dot_Mul(C,n,k,C),Exp(C,n,b),Fill(b,b,n,T.n);
}
inline void Dirichlet_Mul(V a[],int n,V b[],int m,V c[],int L) {
int len=min((ll)n*m,(ll)L); Fill(c,c,0,L+1);
rep(i,1,n+1) for (int j=1; j<=m && (ll)i*j<=len; j++)
c[i*j]=add(c[i*j],mul(a[i],b[j]));
}
};
ll C(ll n, ll m){
if (n<0 || m<0 || n<m) return 0;
return fac[n]*fac_inv[m]%P*fac_inv[n-m]%P;
}
ll POW(ll x, ll k){
ll ret=1;
for (;k;k>>=1,x=x*x%P) if (k&1) ret=ret*x%P;
return ret;
}
ll POW(ll x, ll k, ll P){
ll ret=1;
for (;k;k>>=1,x=x*x%P) if (k&1) ret=ret*x%P;
return ret;
}
ll a[_N<<2],b[_N<<2],c[_N<<2],s=0;
int n,m,k,_T;
Calculator<ll> T;
int main() {
//file_put();
init_inv(800005);
init_fac(800005);
scanf("%d",&_T);
while (_T--) {
scanf("%d%d",&n,&m);
rep(k,0,m+1) a[k]=C(k+n-1,n-1), s=(s+a[k])%P;
s=POW(m+1,n);
//debug_arr(a,20);
T.Dirichlet_Mul(a,m,a,m,b,m);
b[0]=2*s-1;
T.Dot_Mul(b,m+1,POW(2*s-1,P-2),b);
T.Power(b,m+1,POW(n,m),c);
//T.Power(b,m+1,n,c);
T.Dot_Mul(c,m+1,POW(2*s-1,POW(n,m,P-1)),c);
rep(k,0,m+1) printf("%lld\n",(c[k]+P)%P);
}
return 0;
}
1013. Penguin Love Tour
【题意】
一颗 n n n个点带边权和点权的树,现在每个点 i i i可以使与自己连接的边边权减少 a i a_i ai,最少减为0,问最长路最短是多少。
n , a i , w i ≤ 1 0 5 n,a_i,w_i\leq 10^5 n,ai,wi≤105
【思路】
一个可以做的很暴力的题。
问题等价于操作完 n n n个点后树的直径最短,我们考虑某个点 x x x处合并直径。
处理出 f [ i ] [ 0 / 1 ] f[i][0/1] f[i][0/1]表示点 i i i的1权值贡献给父边/某个儿子后,往叶子走的最长链的最小值, f r [ i ] [ 0 / 1 ] fr[i][0/1] fr[i][0/1]表示点 i i i的权值贡献給夫边/某个儿子后往父边走的最长链最小值。
f f f是很容易处理出来的,只需要记往下的最长链和次长链,然后考虑当前点往下贡献给每条边以后分别的答案是什么就行。
f r [ i ] [ 0 ] fr[i][0] fr[i][0]也很容易处理,但是我们要在考虑 f f f的贡献后来转移,因此需要记录最长链、次长链、第三长链,同时这个链可以是父边方向来的。
但是 f r [ i ] [ 1 ] fr[i][1] fr[i][1]的贡献极其复杂。
考虑合并,我们不仅要记最长链的长度,还要记是从哪里来的,这样转移到的儿子才不会重复。为了解决这个问题,我们可以通过对考虑该点贡献的每种情况下的最长链排序,然后再对 f r [ i ] [ 1 ] fr[i][1] fr[i][1]进行更新来得到一个复杂度正确的做法。
写这个题的时候由于心态爆炸,用了一个map来记录最长链是什么,以及当前情况下我这个点的权值贡献到了哪条边,当我要dfs的儿子是这两个中的一个,就暴力更新 f r [ i ] fr[i] fr[i],否则可以通过之前的信息来更新。
然后总的复杂度(排序做法)应该是 O ( n log n ) O(n\log n) O(nlogn)的,常数比较大。
std的做法:
【参考代码】
屎山
#include<bits/stdc++.h>
#define pb push_back
#define mkp make_pair
#define fi first
#define se second
#define ri register int
using namespace std;
typedef double db;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<int,ll> pil;
const int N=1e5+10,mod=998244353;
const ll inf=0x3f3f3f3f3f3f3f3f;
int n,tot;
int head[N],a[N];
ll fans,f[N][2],fr[N][2];
struct Tway
{
int v,w,nex;
Tway(int v=0,int w=0,int nx=0):v(v),w(w),nex(nx){}
}e[N<<1];
void add(int u,int v,int w)
{
e[++tot]=Tway(v,w,head[u]);head[u]=tot;
}
struct node
{
ll val,v;
node(ll val=0,ll v=0):val(val),v(v){}
};
ll check(ll x,ll y,ll z)
{
return x+y+z-min(x,min(y,z));
}
void gmin(ll &x,ll y){x=min(x,y);}
void dfs1(int x,int fa)//0:give fa,1:give son
{
int son=0;
for(int i=head[x];i;i=e[i].nex)
{
int v=e[i].v;
if(v==fa) continue;
++son;dfs1(v,x);
}
if(!son)
{
f[x][0]=0,f[x][1]=inf;
return;
}
node fir=node(0,0),sec=node(0,0);
for(int i=head[x];i;i=e[i].nex)
{
int v=e[i].v,w=e[i].w;
if(v==fa) continue;
ll now=min(f[v][0]+max(w-a[v],0),f[v][1]+w);
if(now>=fir.val) sec=fir,fir=node(now,v);
else if(now>=sec.val) sec=node(now,v);
}
f[x][0]=fir.val;f[x][1]=inf;
for(int i=head[x];i;i=e[i].nex)
{
int v=e[i].v,w=e[i].w;
if(v==fa) continue;
ll t=min(max(0,w-a[v]-a[x])+f[v][0],max(0,w-a[x])+f[v][1]);
if(v==fir.v) gmin(f[x][1],max(t,sec.val));
else gmin(f[x][1],max(t,fir.val));
}
}
void dfs2(int x,int fa,int tw)//fr 0 = pushdown
{
node fir=node(0,0),sec=node(0,0),thd=node(0,0);
int son=0;
for(int i=head[x];i;i=e[i].nex)
{
int v=e[i].v,w=e[i].w;
if(v==fa) continue;
++son;
ll now=min(f[v][0]+max(w-a[v],0),f[v][1]+w);
if(now>=fir.val)
{
thd=sec;sec=fir;
fir=node(now,v);
}
else if(now>=sec.val)
{
thd=sec;
sec=node(now,v);
}
else if(now>=thd.val)
{
thd=node(now,v);
}
}
if(!son) return;
if(fa)
{
ll now=min(fr[fa][0]+max(tw-a[fa],0),fr[fa][1]+tw);
if(now>=fir.val)
{
thd=sec;sec=fir;
fir=node(now,fa);
}
else if(now>=sec.val)
{
thd=sec;
sec=node(now,fa);
}
else if(now>=thd.val)
{
thd=node(now,fa);
}
}
//fr[0] done (contribute down)
//现在不仅要记最长链的长度,还要记是从哪来的,这样转移到孩子才不会重复
//其实这一步可以通过排序来搞
map<int,bool>mp;mp.clear();
ll ans=inf;ll remval=inf;int rem1=0,rem2=0;
for(int i=head[x];i;i=e[i].nex)
{
int v=e[i].v,w=e[i].w;
if(v==fa) continue;
ll now=min(max(0,w-a[x]-a[v])+f[v][0],max(0,w-a[x])+f[v][1]);
if(v==fir.v) gmin(ans,check(now,sec.val,thd.val));
else if(v==sec.v) gmin(ans,check(now,fir.val,thd.val));
else gmin(ans,check(now,fir.val,sec.val));
ll tmp,tmpid;
if(v==fir.v)
{
if(now>=sec.val) tmp=now,tmpid=v;
else tmp=sec.val,tmpid=sec.v;
}
else
{
if(now>=fir.val) tmp=now,tmpid=v;
else tmp=fir.val,tmpid=fir.v;
}
if(tmp<=remval) remval=tmp,rem1=v,rem2=tmpid;
mp[rem1]=mp[rem2]=1;
}
if(fa)
{
ll now=min(max(0,tw-a[x]-a[fa])+fr[fa][0],max(0,tw-a[x])+fr[fa][1]);
if(fa==fir.v) gmin(ans,check(now,sec.val,thd.val));
else if(fa==sec.v) gmin(ans,check(now,fir.val,thd.val));
else gmin(ans,check(now,fir.val,sec.val));
ll tmp,tmpid;
if(fa==fir.v)
{
if(now>=sec.val) tmp=now,tmpid=fa;
else tmp=sec.val,tmpid=sec.v;
}
else
{
if(now>=fir.val) tmp=now,tmpid=fa;
else tmp=fir.val,tmpid=fir.v;
}
if(tmp<=remval) remval=tmp,rem1=fa,rem2=tmpid;
mp[rem1]=mp[rem2]=1;
}
fans=max(fans,ans);
for(int i=head[x];i;i=e[i].nex)
{
int v=e[i].v,w=e[i].w;
if(v==fa) continue;
if(v==fir.v) fr[x][0]=sec.val;
else fr[x][0]=fir.val;
if(mp[v])
{
ll t=inf;
for(int j=head[x];j;j=e[j].nex)
{
int nv=e[j].v,nw=e[j].w;
if(nv==v || nv==fa) continue;
ll now=min(max(0,nw-a[x]-a[nv])+f[nv][0],max(0,nw-a[x])+f[nv][1]);
if(nv==fir.v || v==fir.v)
{
if(nv==sec.v || v==sec.v) gmin(t,max(now,thd.val));
else gmin(t,max(now,sec.val));
}
else gmin(t,max(now,fir.val));
}
if(fa)
{
ll now=min(max(0,tw-a[x]-a[fa])+fr[fa][0],max(0,tw-a[x])+fr[fa][1]);
if(fa==fir.v || v==fir.v)
{
if(fa==sec.v || v==sec.v) gmin(t,max(now,thd.val));
else gmin(t,max(now,sec.val));
}
else gmin(t,max(now,fir.val));
}
fr[x][1]=t;
}
else fr[x][1]=remval;
dfs2(v,x,w);
}
}
void init()
{
for(int i=0;i<=n;++i)
{
f[i][0]=f[i][1]=fr[i][0]=fr[i][1]=head[i]=0;
}
tot=0;fans=0;
}
int main()
{
//freopen("input.in","r",stdin);
//freopen("my.out","w",stdout);
int T;scanf("%d",&T);
for(int tt=1;tt<=T;++tt)
{
init();
scanf("%d",&n);
for(int i=1;i<=n;++i) scanf("%d",&a[i]);
for(int i=1,u,v,w;i<n;++i)
{
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);add(v,u,w);
}
if(n==1)
{
puts("0");
continue;
}
else if(n==2)
{
printf("%d\n",max(0,e[1].w-a[1]-a[2]));
continue;
}
dfs1(1,0);dfs2(1,0,0);
printf("%lld\n",fans);
}
return 0;
}
std
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define rep(i, a, b) for(int i=(a); i<(b); i++)
#define per(i, a, b) for(int i=(b)-1; i>=(a); i--)
#define sz(a) (int)a.size()
#define de(a) cout << #a << " = " << a << endl
#define dd(a) cout << #a << " = " << a << " "
#define all(a) a.begin(), a.end()
#define pw(x) (1ll<<(x))
#define endl "\n"
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef double db;
const int N=1e5+5;
const ll MAX=1e11;
vector<pii> v[N];
int a[N];
ll l,r,mid;
ll dp[N][2];
void dfs(int now,int pre){
dp[now][0]=dp[now][1]=0;
for(auto i:v[now]){
int to=i.first,w=i.second;
if(to==pre) continue;
dfs(to,now);
if(dp[now][0]+min(dp[to][1]+max(0,w-a[now]),dp[to][0]+max(0,w-a[to]-a[now]))<=mid){
if(dp[now][1]+min(dp[to][1]+w,dp[to][0]+max(0,w-a[to]))<=mid)
dp[now][1]=max(dp[now][1],min(max(dp[now][0],min(dp[to][1]+max(0,w-a[now]),dp[to][0]+max(0,w-a[to]-a[now]))),min(dp[to][1]+w,dp[to][0]+max(0,w-a[to]))));
else dp[now][1]=max(dp[now][0],min(dp[to][1]+max(0,w-a[now]),dp[to][0]+max(0,w-a[to]-a[now])));
}
else if(dp[now][1]+min(dp[to][1]+w,dp[to][0]+max(0,w-a[to]))<=mid)
dp[now][1]=max(dp[now][1],min(dp[to][1]+w,dp[to][0]+max(0,w-a[to])));
else dp[now][1]=MAX;
if(dp[now][0]+min(dp[to][1]+w,dp[to][0]+max(0,w-a[to]))<=mid)
dp[now][0]=max(dp[now][0],min(dp[to][1]+w,dp[to][0]+max(0,w-a[to])));
else dp[now][0]=MAX;
}
}
int main(){
int t;
scanf("%d",&t);
while(t--){
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
}
l=0,r=0;
for(int i=1;i<n;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
v[x].pb({y,z});v[y].pb({x,z});
r+=z;
}
ll ans=-1;
while(l<=r){
mid=(l+r)/2;
dfs(1,0);
if(dp[1][0]<=mid||dp[1][1]<=mid){
ans=mid;
r=mid-1;
}
else{
l=mid+1;
}
}
printf("%lld\n",ans);
for(int i=1;i<=n;i++){
v[i].clear();
}
}
}