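// 思路：用并查集统计每个连通块（只通过权值为 0 的边相连）的节点个数，最后用总数减去它们。
// 设共有 o 个连通块，第 i 个连通块有 x_i 个节点，则答案为 n^k - (x_1^k + x_2^k + ... + x_o^k)，对 1e9+7 取模。
// Approach: use a disjoint-set union to count the vertices x_i of each component joined only by
// 0-weight edges (o components in total); the answer is n^k - sum_{i=1..o} x_i^k (mod 1e9+7).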
#include<bits/stdc++.h>
// Shorthand for std::pair members.
#define fi first
#define se second
// Large sentinel value 0x3f3f3f3f (every byte is 0x3f); not used in the code shown here.
#define INF 0x3f3f3f3f
// Type shorthands.
#define ll long long
#define ld long double
// memset helpers: fill an array with a byte value / with zero bytes. Not used in the code shown here.
#define mem(ar,num) memset(ar,num,sizeof(ar))
#define me(ar) memset(ar,0,sizeof(ar))
// Lowest set bit of x. Not used in the code shown here.
#define lowbit(x) (x&(-x))
// Fast-iostream setup; note it is not invoked by main in the code shown here.
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
// Least common multiple; (a)*(b) is evaluated before dividing, so it can overflow for large
// arguments. Not used in the code shown here.
#define lcm(a,b) ((a)*(b)/(__gcd((a),(b))))
// Capacity of the mp/sz arrays; must be greater than the largest vertex index n.
#define Max 200010
// Modulus for all arithmetic in poww and main (1e9+7).
#define mod 1000000007
using namespace std;
// Global state shared by F and main:
//   n, k  - read in main: the vertex count and the exponent passed to poww
//   ans   - running sum (mod `mod`) of sz[root]^k over component roots, built in main
//   f, u  - not referenced anywhere in this file
//   mp    - disjoint-set parent array; mp[x] == x marks a root (initialized for 1..n in main)
//   sz    - component size, only meaningful at root indices
int n, ans, k, f, mp[Max], u, sz[Max];

// Returns the root of x's set in the disjoint-set forest `mp`, with full path
// compression: every node on the path from x to the root ends up pointing
// directly at the root.
// Implemented iteratively so that a long parent chain (possible when unions do
// not balance tree heights) cannot overflow the call stack.
int F(int x) {
    int root = x;
    while (mp[root] != root)
        root = mp[root];
    // Second pass: re-point each node on the path straight at the root.
    while (mp[x] != root) {
        int next = mp[x];
        mp[x] = root;
        x = next;
    }
    return root;
}
int poww(int a, int b) {
ll ans = 1, base = a;
while (b != 0) {
if (b & 1 != 0)
ans *= base, ans %= mod;
base *= base, base %= mod;
b >>= 1;
}
return ans % mod;
}
int main() {
cin >> n >> k;
for(int i = 1; i <= n; i++)
mp[i] = i, sz[i] = 1;
for(int a, b, c, i = 1; i < n; i++) {
cin >> a >> b >> c;
if(!c) {
int x = F(a), y = F(b);
mp[x] = y, sz[y] += sz[x];
}
}
for(int i = 0; i <= n; i++)
if(mp[i] == i)
(ans += poww(sz[i], k)) %= mod;
printf("%d", (poww(n, k) - ans + mod) % mod);
return 0;
}