题目:
给一棵N个节点的树,编号从1到N,再给定m对点(u,v),你要将树上的每条无向边变为有向边,使得给定的点对都满足u能到达v或v能到达u。问有多少种不同的方案,答案对(1e9+7)求余。
输入格式
第一行两个正整数N and M(1 ≤ N, M≤ 3*1e5 ),表示树的结点个数,和点对的个数。
接下来N-1行,每行两个整数,表示树上的边。
接下来M行,每行两个不同的正整数(ai,bi),表示对应的点对,点对互不相同。
输出格式
一行一个数,表示不同的方案数模1e9+7
20%的数据树是一个链,即第i个点连在i+1上。
40%的数据N,M≤ 5*1e3 .
样例输入:
4 1
1 2
2 3
3 4
2 4
样例输出:
4
思路:
这道题肯定不能爆搜出每一种情况,
因为是树上对两个点进行操作,所以考虑LCA
结合题意,容易知道,u到v的树上路径其实我们看将它看作一条边
接着我们就可以以此建一棵树,
但是如果原树中还有点没有考虑到怎么办,
答案很见到,直接统计边数
最后还需要考虑树上的路径是否有冲突
最终的答案就是2^(新树的边数)
给出代码辅助理解
#include<iostream>
#include<algorithm>
#include<cstring>
#include<vector>
#include<cstdio>
using namespace std;
void read(int &x)
{
x=0;
int f=1;
char c=getchar();
while('0'>c||c>'9')
{
if(c=='-')
f=-1;
c=getchar();
}
while('0'<=c&&c<='9')
{
x=(x<<3)+(x<<1)+c-'0';
c=getchar();
}
x*=f;
}
void write(int x)
{
if(x<0)
{
putchar('-');
write(-x);
return;
}
if(x<10)
putchar(x+'0');
else
{
write(x/10);
putchar(x%10+'0');
}
}
const int mod=1e9+7;
const int dp_k=25;
int n,m;
int dep[300005];
int hig[300005];
int vis[300005];
int dp[300005][30];
long long ans=1;
vector<int> g[300005];
vector<pair<int,int> > min_tre[300005];
void dfs_depth(int u,int fat)
{
dp[u][0]=fat;
dep[u]=dep[fat]+1;
for(int i=0;i<g[u].size();i++)
{
if(g[u][i]!=fat)
dfs_depth(g[u][i],u);
}
}
int lca(int u,int v)
{
if(dep[u]<dep[v])
swap(u,v);
for(int i=dp_k;i>=0;i--)
if((dep[u]-dep[v])>>i&1)
u=dp[u][i];
if(u==v)
return u;
for(int i=dp_k;i>=0;i--)
{
if(dp[u][i]!=dp[v][i])
{
u=dp[u][i];
v=dp[v][i];
}
}
return dp[u][0];
}
int dfs_need(int u,int fa)
{
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v!=fa)
{
int t=dfs_need(v,u);
hig[u]=min(hig[u],t);
if(t<dep[u])
{
min_tre[v].push_back(make_pair(u,0));
min_tre[u].push_back(make_pair(v,0));
}
}
}
return hig[u];
}
bool dfs(int u,int v)
{
if(vis[u]!=-1)
{
if(vis[u]==v)
return 1;
else
return 0;
}
vis[u]=v;
for(int i=0;i<min_tre[u].size();i++)
{
pair<int,int> so=min_tre[u][i];
if(!dfs(so.first,v^so.second))
return 0;
}
return 1;
}
int main()
{
//freopen("usmjeri.in","r",stdin);
//freopen("usmjeri.out","w",stdout);
memset(vis,-1,sizeof(vis));
read(n);
read(m);
for(int i=1;i<=n-1;i++)
{
int u,v;
read(u);
read(v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs_depth(1,0);
for(int j=1;j<=dp_k;j++)
for(int i=1;i<=n;i++)
dp[i][j]=dp[dp[i][j-1]][j-1];
for(int i=1;i<=n;i++)
hig[i]=dep[i];
for(int i=1;i<=m;i++)
{
int a,b;
read(a);
read(b);
int c=lca(a,b);
hig[a]=min(hig[a],dep[c]);
hig[b]=min(hig[b],dep[c]);
if(a!=c&&b!=c)
{
min_tre[a].push_back(make_pair(b,1));
min_tre[b].push_back(make_pair(a,1));
}
}
dfs_need(1,0);
for(int i=2;i<=n;i++)
{
if(vis[i]==-1)
{
if(!dfs(i,0))
{
write(0);
return 0;
}
ans=2*ans%mod;
}
}
write(ans);
return 0;
}