题目:https://www.luogu.com.cn/problem/P3884
C语言代码:
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* One tree node; the nodes live in a table indexed by node id. */
struct BinTree
{
    int left;  /* id of the left child, written by build() */
    int right; /* id of the right child, written by build() */
    int deep;  /* depth of the node; main() gives the root depth 1 */
    int fa;    /* never accessed in this file; presumably the parent id */
};
struct BinTree bt[105]; /* node table, indexed by node id */
int t[105];             /* t[x] is set to 1 when x receives a left child */
int width[105];         /* width[d]: number of nodes at depth d */
int fa[105][105];       /* fa[x][y]: the ancestor of node x on level y */
int MaxW = 0;           /* largest width[] value recorded by build() */
int MaxD = 0;           /* largest depth recorded by build() */
/* Returns the larger of x and y (y when they are equal). */
int max(int x, int y)
{
    return (x > y) ? x : y;
}
/*
 * Attaches node y as a child of node x and updates the global bookkeeping.
 *
 * x: id of the parent; bt[x].deep and fa[x][1..bt[x].deep] must already be
 *    filled in (main() does this for the root, build() for every other node).
 * y: id of the new child.
 *
 * Effects: sets bt[x].left or bt[x].right, sets bt[y].deep, increments
 * width[] for y's level, raises MaxW / MaxD when needed and fills
 * fa[y][1..bt[y].deep].
 */
void build(int x, int y)
{
    int i;

    /*
     * t[x] tells whether the parent already owns a left child: the first
     * child goes left, a later one goes right. The flag of the parent must
     * be tested here; testing t[y] would send every child to the left slot
     * because a fresh child has no flag yet.
     */
    if (t[x] == 0)
    {
        bt[x].left = y;
        t[x] = 1;
    }
    else
    {
        bt[x].right = y;
    }

    /* the child sits one level below its parent */
    bt[y].deep = bt[x].deep + 1;
    width[bt[y].deep]++;
    MaxW = max(width[bt[y].deep], MaxW);
    MaxD = max(bt[y].deep, MaxD);

    /* y shares every ancestor of x; on its own level y stands for itself */
    for (i = 1; i <= bt[x].deep; i++)
        fa[y][i] = fa[x][i];
    fa[y][bt[y].deep] = y;
}
/*
 * Walks upward from the level of node y until nodes x and y have the same
 * entry in fa[][] and returns that level.
 *
 * main() passes the node that is not deeper as y. With fa[][] filled by
 * build(), the returned level is the depth of the lowest ancestor the two
 * nodes have in common.
 */
int find(int x, int y)
{
    int level;

    for (level = bt[y].deep; fa[x][level] != fa[y][level]; level--)
    {
        /* keep climbing: the ancestors on this level still differ */
    }
    return level;
}
/*
 * Entry point: reads a tree from stdin and prints its depth, its width and
 * the distance between two queried nodes.
 *
 * Input:  n, then n-1 pairs "x y" (each handed to build(x, y)), then the
 *         query pair "x y".
 * Output: MaxD, MaxW and the distance, one value per line.
 * Returns 0 on success and 1 when the input is truncated or non-numeric,
 * or when a node id or depth would not fit the global tables.
 */
int main()
{
    /* valid node ids and depths are 1 .. cap-1 (length of the global tables) */
    const int cap = (int)(sizeof(t) / sizeof(t[0]));
    int i, n, x, y, z;

    if (scanf("%d", &n) != 1 || n < 1 || n >= cap)
        return 1;

    memset(t, 0, sizeof(t));
    memset(width, 0, sizeof(width));

    /* node 1 is the root: depth 1, its own ancestor on level 1 */
    bt[1].deep = 1;
    fa[1][1] = 1;
    width[1] = 1;
    /*
     * build() only accounts for non-root nodes, so the root's contribution
     * is recorded here; otherwise a single-node tree would report 0 and 0.
     */
    MaxD = 1;
    MaxW = 1;

    for (i = 1; i < n; i++)
    {
        if (scanf("%d%d", &x, &y) != 2)
            return 1;
        if (x < 1 || x >= cap || y < 1 || y >= cap)
            return 1;
        /* the child would land on level bt[x].deep + 1, which must stay below cap */
        if (bt[x].deep >= cap - 1)
            return 1;
        build(x, y);
    }

    if (scanf("%d%d", &x, &y) != 2)
        return 1;
    if (x < 1 || x >= cap || y < 1 || y >= cap)
        return 1;

    /* find() expects the deeper node first */
    if (bt[x].deep > bt[y].deep)
        z = find(x, y);
    else
        z = find(y, x);

    printf("%d\n%d\n", MaxD, MaxW);
    /* every level climbed from x up to level z costs 2, every level descended to y costs 1 */
    printf("%d\n", (bt[x].deep - z)*2 + bt[y].deep - z);
    return 0;
}
Python:
数组初始化:
# One independent BinTree object per node id 0..n.
NodeList = [BinTree() for _ in range(n + 1)]
# Flat list of n+1 zeros.
width = [0] * (n + 1)
# (n+1) x (n+1) table of zeros; every row is its own list.
fa = [[0] * (n + 1) for _ in range(n + 1)]
定义结构体:
class BinTree:
    """Record with three fields (left, right, deep) that all start as None."""

    def __init__(self):
        # None marks a field that has not been given a value yet.
        self.left = self.right = self.deep = None
None表示没有初始值
定义变量:
# Create one record; all of its fields start out as None.
x = BinTree()
# Fields are assigned through attribute access ("..." is presumably a
# placeholder for the real value).
x.left = ...
代码:
class BinTree:
    """Tree node record.

    left / right hold the ids of the children and deep holds the depth of
    the node (all three are assigned by build() and the driver code);
    None means "not assigned yet".
    """

    def __init__(self):
        self.left = self.right = self.deep = None
def build(x, y):
    """Attach node y as a child of node x and update the global statistics.

    The first child of x becomes its left child, a later one its right
    child. The child is placed one level below x, counted in width[] and
    reflected in MaxD / MaxW, and receives the ancestor row of x extended
    by itself on its own level.
    """
    global MaxD, MaxW
    parent, child = NodeList[x], NodeList[y]

    if parent.left is None:
        parent.left = y
    else:
        parent.right = y

    depth = parent.deep + 1
    child.deep = depth
    width[depth] += 1
    MaxD = max(MaxD, depth)
    MaxW = max(MaxW, width[depth])

    # Levels 1 .. depth-1: same ancestors as the parent; level depth: y itself.
    fa[y][1:depth] = fa[x][1:depth]
    fa[y][depth] = y
def find(x, y):
    """Return the highest level, starting at the level of y and moving up,
    on which nodes x and y have the same entry in fa.

    The driver code passes the node that is not deeper as y.
    """
    level = NodeList[y].deep
    row_x, row_y = fa[x], fa[y]
    while row_x[level] != row_y[level]:
        level -= 1
    return level
# Number of nodes. int() is used instead of eval(): the text comes from stdin
# and must be parsed as a number, never executed as Python code.
n = int(input())
# The root alone already forms a tree of depth 1 and width 1. build() only
# accounts for non-root nodes, so starting from 0 would print 0 for n == 1.
MaxD = 1
MaxW = 1
NodeList = [BinTree() for i in range(n+1)]
width = [0 for i in range(n+1)]
fa = [[0 for i in range(n+1)] for j in range(n+1)]
# Node 1 is the root: depth 1, its own ancestor on level 1.
NodeList[1].deep = 1
fa[1][1] = 1
width[1] = 1
for i in range(n - 1):
    # split() without an argument tolerates repeated or trailing blanks.
    x = input().split()
    build(int(x[0]), int(x[1]))
y = input().split()
x1 = int(y[0])
y1 = int(y[1])
# find() expects the deeper node first.
if NodeList[x1].deep > NodeList[y1].deep:
    z = find(x1, y1)
else:
    z = find(y1, x1)
print(MaxD)
print(MaxW)
# Every level climbed from x1 up to level z costs 2, every level descended
# to y1 costs 1.
print((NodeList[x1].deep - z) * 2 + NodeList[y1].deep - z)