problem
solution
想必已经看了 CQOI2017 老C的键盘 一题题解了。
这里直接考虑优化状态转移方程。
我们发现 f ( u , i ) , f ( v , j ) f(u,i),f(v,j) f(u,i),f(v,j),当枚举 j j j 后,对应的 k k k 是一段连续区间。
f ( u , i ) ∗ f ( v , j ) ∗ ( k − 1 i − 1 ) ∗ ( s i z u + s i z v − k s i z u − i ) f(u,i)*f(v,j)*\binom{k-1}{i-1}*\binom{siz_u+siz_v-k}{siz_u-i} f(u,i)∗f(v,j)∗(i−1k−1)∗(sizu−isizu+sizv−k) 转移贡献变化的只有 k k k,且有两项都与之挂钩。
我们不妨反过来固定 k k k,能转移到这个 k k k 的肯定也是一段连续的 j j j。
所以我们求出来 f ( v , j ) f(v,j) f(v,j) 后,直接做个前缀和, f ( v , j ) = ∑ j ′ ≤ j f ( v , j ′ ) f(v,j)=\sum_{j'\le j}f(v,j') f(v,j)=∑j′≤jf(v,j′)。
以要求 h u < h v h_u<h_v hu<hv 为例。
其原来是 i + j ≤ k ≤ i + s i z v i+j\le k\le i+siz_v i+j≤k≤i+sizv,固定 i , k i,k i,k 后,能转移过来的 j j j 范围则为 1 ∼ k − i 1\sim k-i 1∼k−i。
另一种则是 k − i < j ≤ s i z v k-i<j\le siz_v k−i<j≤sizv。
两种的 k k k 范围也稍稍不同。
具体可以看代码实现。
code
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define mod 1000000007
#define maxn 1005
char s[5];
int n;
int c[maxn][maxn], f[maxn][maxn], g[maxn], siz[maxn];
vector < pair < int, int > > G[maxn];
void dfs( int u, int fa ) {
f[u][1] = siz[u] = 1;
for( int i = 0;i < G[u].size();i ++ ) {
int v = G[u][i].first, w = G[u][i].second;
if( v == fa ) continue;
dfs( v, u );
memcpy( g, f[u], sizeof( f[u] ) );
memset( f[u], 0, sizeof( f[u] ) );
for( int i = 1;i <= siz[u];i ++ )
if( w ) {
for( int k = i + 1;k <= i + siz[v];k ++ )
(f[u][k] += f[v][k - i] % mod * g[i] % mod * c[k - 1][i - 1] % mod * c[siz[u] + siz[v] - k][siz[u] - i]) %= mod;
}
else {
for( int k = i;k < i + siz[v];k ++ )
(f[u][k] += (f[v][siz[v]] - f[v][k - i]) % mod * g[i] % mod * c[k - 1][i - 1] % mod * c[siz[u] + siz[v] - k][siz[u] - i]) %= mod;
}
siz[u] += siz[v];
}
for( int i = 1;i <= siz[u];i ++ ) (f[u][i] += f[u][i - 1]) %= mod;
}
signed main() {
for( int i = 0;i <= 1000;i ++ ) {
c[i][0] = c[i][i] = 1;
for( int j = 1;j < i;j ++ )
c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
}
int T;
scanf( "%lld", &T );
while( T -- ) {
scanf( "%lld", &n );
memset( f, 0, sizeof( f ) );
for( int i = 0;i < n;i ++ ) G[i].clear();
for( int i = 1, u, v;i < n;i ++ ) {
scanf( "%lld %s %lld", &u, s, &v );
G[u].push_back( make_pair( v, s[0] == '<' ) );
G[v].push_back( make_pair( u, s[0] == '>' ) );
}
dfs( 0, 0 );
printf( "%lld\n", ( f[0][n] + mod ) % mod );
}
return 0;
}