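# 题目 (Problem): count the ways to split the sequence into three contiguous, non-empty parts with equal sums.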
def count_three_way_splits(nums):
    """Count the ways to cut *nums* into three contiguous parts of equal sum.

    A split is a pair of cut positions ``(i, j)`` with ``1 <= i < j <= n - 1``
    such that ``sum(nums[:i]) == sum(nums[i:j]) == sum(nums[j:])``; all three
    parts are therefore non-empty.

    Args:
        nums: Sequence of integers (may contain zeros and negatives).

    Returns:
        The number of valid ``(i, j)`` pairs; 0 when there are fewer than
        three elements or the total is not divisible by 3.
    """
    n = len(nums)
    if n <= 2:
        # Three non-empty parts need at least three elements.
        return 0

    total = sum(nums)
    if total % 3 != 0:
        return 0

    third = total // 3
    two_thirds = 2 * third

    ways = 0
    first_cuts = 0  # earlier positions whose prefix sum equals `third`
    prefix = 0
    # Only positions 1..n-1 are examined: the last part must be non-empty.
    for k in range(n - 1):
        prefix += nums[k]
        # Treat this position as the second cut BEFORE registering it as a
        # first cut. When total == 0 both targets coincide, and the first cut
        # must lie strictly before the second one.
        if prefix == two_thirds:
            ways += first_cuts
        if prefix == third:
            first_cuts += 1
    return ways


def main():
    """Read ``n`` and then ``n`` integers from stdin; print the split count."""
    n = int(input())
    nums = [int(x) for x in input().split()]
    if n <= 2:
        print(0)
        return
    if len(nums) < n:
        raise ValueError(
            "expected {} integers on the second line, got {}".format(n, len(nums))
        )
    # Values beyond the first n are ignored.
    print(count_three_way_splits(nums[:n]))


if __name__ == "__main__":
    main()