算法复杂度需要在O(n)附近,log10**8大概在18左右,nlogn应该超时了,不过python做这题,怎么尝试也会超时。
算法只能尽量精简,不断剪枝优化。
刚开始想法是这样:
n, q = map(int, input().split())
rec = []
vis = [True]*(n+1)
for i in range(2, n+1):
if not vis[i]:rec.append(i)
for j in range(2,n+1):
if i*j > n:break
vis[i*j] = False
for i in range(q):
idx = int(input())
print(rec[idx-1])
两个for循环,把i的j倍都设置成false。显然不太ok。
接着改变数据结构,vis改编成字典会节省一部分内存,不过还不行。
看了b站学习线性筛,相当于找到数学上的规律,剪枝。
from collections import defaultdict
n, q = map(int, input().split())
rec = []
vis = defaultdict(int)
for i in range(2, n+1):
if not vis[i]:rec.append(i)
for j in rec:
if i*j > n:break
vis[i*j] = 1
if i%j == 0:break
for i in range(q):
idx = int(input())
print(rec[idx-1])