题目链接:
http://codeforces.com/contest/109/problem/C
题目大意:
在一棵树上,有一些边是幸运边,现要求出点组(i, j, k)的数量,使得i 到j 的路径上和i 到k 的路径上都至少有一条幸运边。
算法:
法一:
将点按照非幸运边缩联通块。
枚举做i 的点,那么以该点为i 的所有解,即是由i 点与任意两个不与它在同一个联通块内的点的组合。
法二:
对于每个点,记录下它的子树内有多少个点到它的路径上有幸运边,以及它的子树外有多少个点到它的路径上有幸运边。
然后枚举这个点做i 时,另两个点分别有0,1,2个在它子树内的情况。
注意这两个值一个只能由下到上更新,一个只能由下到上更新,所以要DFS两次。计算的方法就是分情况树形DP传递一下。
代码:
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <sstream>
#include <cstdlib>
#include <cstring>
#include <string>
#include <climits>
#include <cmath>
#include <queue>
#include <vector>
#include <stack>
#include <set>
#include <map>
#define INF 0x3f3f3f3f
#define eps 1e-8
#define fi first
#define nd second
#define mp make_pair
using namespace std;
const int maxn = 110000;
vector<pair<int, bool> > edge[maxn];
long long cot0[maxn], cot1[maxn], cot2[maxn], cot3[maxn];
long long ans;
inline bool check(int x)
{
if (! x)
{
return false;
}
while (x)
{
if ((x % 10 != 4) && (x % 10 != 7))
{
return false;
}
x /= 10;
}
return true;
}
void dfs0(int u, int p)
{
for (int i = 0; i < edge[u].size(); i ++)
{
int v = edge[u][i].fi;
if (v == p)
{
continue;
}
dfs0(v, u);
if (edge[u][i].nd)
{
cot1[u] += cot0[v] + cot1[v] + 1;
}
else
{
cot0[u] += cot0[v] + 1;
cot1[u] += cot1[v];
}
}
ans += cot1[u] * (cot1[u] - 1);
}
void dfs1(int u, int p)
{
for (int i = 0; i < edge[u].size(); i ++)
{
int v = edge[u][i].fi;
if (v == p)
{
continue;
}
if (edge[u][i].nd)
{
cot3[v] = cot2[u] + cot3[u] + cot0[u] + cot1[u] - cot0[v] - cot1[v];
}
else
{
cot2[v] = cot2[u] + cot0[u] - cot0[v];
cot3[v] = cot3[u] + cot1[u] - cot1[v];
}
ans += cot1[v] * cot3[v] * 2;
ans += cot3[v] * (cot3[v] - 1);
dfs1(v, u);
}
}
int main()
{
int n;
while (scanf("%d", &n) == 1)
{
ans = 0LL;
memset(cot0, 0, sizeof(cot0));
memset(cot1, 0, sizeof(cot1));
memset(cot2, 0, sizeof(cot2));
memset(cot3, 0, sizeof(cot3));
for (int i = 0; i <n; i ++)
{
edge[i].clear();
}
for (int i = 1; i < n; i ++)
{
int u, v, w;
scanf("%d %d %d", &u, &v, &w);
edge[u - 1].push_back(mp(v - 1, check(w)));
edge[v - 1].push_back(mp(u - 1, check(w)));
}
dfs0(0, -1);
dfs1(0, -1);
printf("%I64d\n", ans);
}
return 0;
}