一颗有n个点的树,然后上面有m条关系链[u, v](就是u到v的路径),选出x条链,相互没有边重合,但是点是可以重合的。求x的最大值。
思路:树形dp+状压dp。
对于以u为根的子树,dp[u]表示这棵子树能有的最大x的值(也就是被选中的x条链满足题目条件且每条链的所有边都在子树中,不经过[fa, u])。
第一部分就是dp[u] = ∑dp[v] && v is a child of u
第二部分就看以v为根节点的子树中是否有一u结尾的链存在,有就dp[u]++,然后删除v子树中的所有点
第三部分就是跨越u的链了(两棵子树中的点行成的),这就先找出那两棵树中的点是可以构成链的,然后状压枚举最优的方案
最后就是存下u后面的所有孩子节点了,包括u本身
/*****************************************
Author :Crazy_AC(JamesQi)
Time :2016
File Name :
*****************************************/
// #pragma comment(linker, "/STACK:1024000000,1024000000")
#include <iostream>
#include <algorithm>
#include <iomanip>
#include <sstream>
#include <string>
#include <stack>
#include <queue>
#include <deque>
#include <vector>
#include <map>
#include <set>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <climits>
using namespace std;
#define MEM(x,y) memset(x, y,sizeof x)
#define pk push_back
#define lson rt << 1
#define rson rt << 1 | 1
#define bug cout << "BUG HERE\n"
#define debug(x) cout << #x << " = " << x << endl
#define ALL(v) (v).begin(), (v).end()
#define lowbit(x) ((x)&(-x))
#define Unique(x) sort(ALL(x)); (x).resize(unique(ALL(x)) - (x).begin())
#define BitOne(x) __builtin_popcount(x)
#define showtime printf("time = %.15f\n",clock() / (double)CLOCKS_PER_SEC)
#define Rep(i, l, r) for (int i = l;i <= r;++i)
#define Rrep(i, r, l) for (int i = r;i >= l;--i)
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> ii;
typedef pair<ii,int> iii;
const double eps = 1e-8;
const double pi = 4 * atan(1);
const int inf = 1 << 30;
const int INF = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
int nCase = 0;
//精度正负、0的判断
int dcmp(double x){if (fabs(x) < eps) return 0;return x < 0?-1:1;}
template<class T> inline bool read(T &n){
T x = 0, tmp = 1;
char c = getchar();
while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
if(c == EOF) return false;
if(c == '-') c = getchar(), tmp = -1;
while(c >= '0' && c <= '9') x *= 10, x += (c - '0'),c = getchar();
n = x*tmp;
return true;
}
template <class T> inline void write(T n){
if(n < 0){putchar('-');n = -n;}
int len = 0,data[20];
while(n){data[len++] = n%10;n /= 10;}
if(!len) data[len++] = 0;
while(len--) putchar(data[len]+48);
}
LL QMOD(LL x, LL k) {
LL res = 1LL;
while(k) {if (k & 1) res = res * x % MOD;k >>= 1;x = x * x % MOD;}
return res;
}
const int N = 1000;
vector<vector<int> > g;
bool p[N][N], local[10][10];
int n, m;
vector<int> up[N];
int dp[1 << 10];
int lowbit[1<< 10];
int dfs(int u,int f) {
int result = 0;
vector<int> sons;
for (int v : g[u]) {
if (v != f) {
result += dfs(v, u);
sons.push_back(v);
}
}
/*以v为根的子树中是否有点与u构成一条路径,有就删除v子树*/
/*因为直接和u相连,不会有比这更优的方案了*/
for (int v : sons) {
for (int x : up[v]) {
if (p[u][x]) {
result ++;
up[v].clear();
break;
}
}
}
/*与u直接相连的判断完了后,在判断两棵子树跨越u形成路径的情况*/
/*子树的数目小于10*/
for (int i = 0;i < sons.size();++i) {
for (int j = i + 1;j < sons.size();++j) {
local[i][j] = false;
for (int a : up[sons[i]]) {
for (int b : up[sons[j]]) {
if (p[a][b]) {
local[i][j] = true;
break;
}
}
if (local[i][j]) break;
}
}
}
/* 因为题目保证了每个节点连接的节点数不回超过10个,故可以状压枚举*/
dp[0] = 0;
/*小状态先算出来,用来更新大状态*/
for (int sta = 1;sta < (1 << sons.size());++sta) {
int a = lowbit[sta];
/*枚举a,b两棵子树中的点匹配,先确定a,在枚举b*/
/*所以是dp[sta] = dp[sta - (1 << a)]*/
dp[sta] = dp[sta - (1 << a)];/*sta > sta - (1 << a)*/
for (int b = a + 1;b < sons.size();++b) {
if (((sta >> b) & 1) && local[a][b])/*sta > sta - (1 << a) > sta - (1 << b)*/
dp[sta] = max(dp[sta], dp[sta - (1 << a) - (1 << b)] + 1);
}
}
result += dp[(1 << sons.size()) - 1];
up[u].clear();
up[u].push_back(u);/*u为根的子树,包含u节点*/
for (int i = 0;i < sons.size();++i) {
/*判断第i棵子树是否必须要,只要剔除i看看是否相等就行了*/
/*只需要判断dp[(1 << sons.size()) - 1] 是否等于 dp[(1 << sons.size()) - 1 - (1 << i)]*/
if (dp[(1 << sons.size()) - 1] == dp[(1 << sons.size()) - 1 - (1 << i)])
for (int x : up[sons[i]])
up[u].push_back(x);
}
return result;
}
int main(int argc, const char * argv[])
{
freopen("in.txt","r",stdin);
// freopen("out.txt","w",stdout);
// ios::sync_with_stdio(false);
// cout.sync_with_stdio(false);
// cin.sync_with_stdio(false);
/*计算出状态i二进制中从左到右第一个1的位置*/
for (int i = 1;i < (1 << 10);++i) {
lowbit[i] = (i & 1) ? 0 : 1 + lowbit[i >> 1];
// printf("%4d", lowbit[i]);
// if (i % 10 == 0) puts("");
}
// puts("");
int kase;cin >> kase;
while(kase--) {
g.clear();
cin >> n ;
Rep(i, 0, n - 1) Rep(j, 0, n - 1) p[i][j] = false;
g.resize(n + 2);
int u, v;
Rep(i, 1, n - 1) {
scanf("%d%d", &u, &v);
u--, v--;
g[u].push_back(v);
g[v].push_back(u);
}
cin >> m;
Rep(i, 1, m) {
scanf("%d%d", &u, &v);
u--, v--;
p[u][v] = p[v][u] = true;
}
cout << dfs(0, -1) << endl;
}
// showtime;
return 0;
}