ABC323G 题解
G - Inversion of Tree
题意:给定一个1-n的排列p,问有多少棵连接1-n节点的无向树,使得其中有k条边 ( u 1 , v 1 ) , . . . , ( u k , v k ) (u_1,v_1),...,(u_k,v_k) (u1,v1),...,(uk,vk)
其中 u i < v i u_i<v_i ui<vi 满足: p u i > p v i p_{u_i}>p_{v_i} pui>pvi
解答:如果不加任何限制条件,就是求完全图 K n K_n Kn有多少个生成树,由矩阵树定理,这个问题很好解答,直接算出图对应的Laplace矩阵,然后求行列式即可。不过这个问题中有一些边比较特殊,我们希望知道含k个这样的边的生成树的个数,那我们在计算Laplace矩阵时可以这样:对 i < j i<j i<j若 p i < p j p_i<p_j pi<pj,则 L i j = 1 L_{ij}=1 Lij=1而若 p i > p j p_i>p_j pi>pj,则 L i j = x L_{ij}=x Lij=x,然后算行列式 p ( x ) = d e t ( L ) p(x)=det(L) p(x)=det(L),其中 x i x^i xi项的系数就是含i条特殊边的生成树个数。
但是算行列式并不是很简单的任务,我们把L看成常数矩阵和只含x的矩阵之和, L = A + x B L=A+xB L=A+xB,然后通过初等变换尽量将 A + x B A+xB A+xB变为 A ′ + x I A'+xI A′+xI(如果中间某行没有x,我们可以将这一行整体乘以x然后继续尝试变换。如果在变换过程中已经乘了x多于n次,可以直接返回一个全零多项式,参考 d e t ( A + x B ) det(A+xB) det(A+xB)的计算方法),然后通过Hessenberg法求 d e t ( A ′ + x I ) det(A'+xI) det(A′+xI)得到结果
代码:
#include<bits/stdc++.h>
using namespace std;
using i64 = long long;
template<typename T>
constexpr T power(T a,i64 b){
T ans=1;
for(;b;b/=2){
if(b%2==1){
ans*=a;
}
a*=a;
}
return ans;
}
template<int P>
struct MInt{
constexpr MInt(): x{} {}
constexpr MInt(i64 x): x{norm(x%P)} {}
int x;
constexpr int norm(int x) const{
if(x<0){
return x+P;
}
return x;
}
constexpr int val() const{
return x;
}
constexpr MInt inv() const{
return power(*this,P-2);
}
constexpr MInt operator-() const{
MInt res;
res.x=P-x;
return res;
}
constexpr MInt& operator+=(const MInt& rhs){
x=(x+rhs.x)%P;
return *this;
}
constexpr MInt& operator-=(const MInt& rhs){
x=norm(x-rhs.x);
return *this;
}
constexpr MInt& operator*=(const MInt& rhs){
x=(1ll*x*rhs.x)%P;
return *this;
}
constexpr MInt& operator/=(const MInt& rhs){
return *this*=rhs.inv();
}
friend constexpr MInt operator+(const MInt& lhs,const MInt& rhs){
MInt res=lhs;
res+=rhs;
return res;
}
friend constexpr MInt operator-(const MInt& lhs,const MInt& rhs){
MInt res=lhs;
res-=rhs;
return res;
}
friend constexpr MInt operator*(const MInt& lhs,const MInt& rhs){
MInt res=lhs;
res*=rhs;
return res;
}
friend constexpr MInt operator/(const MInt& lhs,const MInt& rhs){
MInt res=lhs;
res/=rhs;
return res;
}
friend std::istream& operator>>(std::istream& is,MInt& a){
i64 v;
is>>v;
a=MInt(v);
return is;
}
friend std::ostream& operator<<(std::ostream& os,const MInt& a){
os<<a.val();
return os;
}
friend constexpr bool operator==(const MInt& lhs,const MInt& rhs){
return lhs.x==rhs.x;
}
friend constexpr bool operator!=(const MInt& lhs,const MInt& rhs){
return lhs.x!=rhs.x;
}
};
// toUpperHessenberg
// O(n^3)
template<typename T>
void hessen(vector<vector<T>>& a){
const int n=a.size();
for(int i=0;i<n-1;i++){
for(int j=i+1;j<n;j++){
if(a[j][i]!=0){
swap(a[i+1],a[j]);
for(int k=0;k<n;k++){
swap(a[k][i+1],a[k][j]);
}
break;
}
}
if(a[i+1][i]==0){
continue;
}
T inv=T(1)/a[i+1][i];
for(int j=i+2;j<n;j++){
if(a[j][i]==0) continue;
T tmp=a[j][i]*inv;
for(int k=0;k<n;k++){
a[j][k]-=tmp*a[i+1][k];
}
for(int k=0;k<n;k++){
a[k][i+1]+=tmp*a[k][j];
}
}
}
}
// det(A+xI)
// O(n^3)
template<typename T>
vector<T> charPoly(vector<vector<T>> a){
const int n=a.size();
hessen(a);
vector<vector<T>> p(n+1);
p[0]={1};
for(int i=0;i<n;i++){
p[i+1].assign(i+2,0);
for(int j=0;j<i+1;j++){
p[i+1][j+1]+=p[i][j];
p[i+1][j]+=p[i][j]*a[i][i];
}
T prod=1;
for(int j=i-1;j>=0;j--){
prod*=-a[j+1][j];
const T t=prod*a[j][i];
for(int k=0;k<=j;k++){
p[i+1][k]+=t*p[j][k];
}
}
}
return p[n];
}
// det(A+xB)
// O(n^3)
template<typename T>
vector<T> detPoly(vector<vector<T>> a,vector<vector<T>> b){
const int n=a.size();
T prod=1;
int off=0;
for(int i=0;i<n;i++){
while(true){
for(int j=i;j<n;j++){
if(b[i][j]!=0){
for(int k=0;k<n;k++){
swap(b[k][i],b[k][j]);
swap(a[k][i],a[k][j]);
}
if(i!=j) prod*=-1;
break;
}
}
if(b[i][i]!=0){
break;
}
if(++off>n){
return vector<T>(n+1,0);
}
for(int j=0;j<n;j++){
b[i][j]=a[i][j];
a[i][j]=0;
}
for(int j=0;j<i;j++){
T t=b[i][j];
for(int k=0;k<n;k++){
a[i][k]-=t*a[j][k];
b[i][k]-=t*b[j][k];
}
}
}
prod*=b[i][i];
T t=1/b[i][i];
for(int j=0;j<n;j++){
a[i][j]*=t;
b[i][j]*=t;
}
for(int j=0;j<n;j++){
if(i==j) continue;
T s=b[j][i]/b[i][i];
for(int k=0;k<n;k++){
a[j][k]-=s*a[i][k];
b[j][k]-=s*b[i][k];
}
}
}
vector<T> p=charPoly(a);
vector<T> ans(n+1,0);
for(int i=0;i<=n-off;i++){
ans[i]=prod*p[i+off];
}
return ans;
}
constexpr int P = 998244353;
using Z = MInt<P>;
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin>>n;
vector<int> p(n);
vector<vector<Z>> a(n,vector<Z>(n,0));
vector<vector<Z>> b(n,vector<Z>(n,0));
for(int i=0;i<n;i++){
cin>>p[i];
}
for(int i=0;i<n;i++){
for(int j=i+1;j<n;j++){
if(p[i]<p[j]){
a[i][i]+=1;
a[i][j]-=1;
a[j][i]-=1;
a[j][j]+=1;
}
else{
b[i][i]+=1;
b[i][j]-=1;
b[j][i]-=1;
b[j][j]+=1;
}
}
}
a.resize(n-1);
b.resize(n-1);
vector<Z> ans=detPoly(a,b);
for(int i=0;i<n;i++){
cout<<ans[i]<<" \n"[i==n-1];
}
return 0;
}