import random
from itertools import accumulate
from typing import List


class Solution:
    """Weighted random index picker using a linear scan over prefix sums.

    ``pickIndex`` returns index ``i`` with probability ``w[i] / sum(w)``.
    """

    def __init__(self, w: List[int]):
        """Precompute the running totals of ``w``.

        Args:
            w: Non-negative integer weights with a positive sum.

        Raises:
            ValueError: If any weight is negative, or the weights sum to zero
                (this includes an empty ``w``).
        """
        if any(x < 0 for x in w):
            raise ValueError("weights must be non-negative")
        # Exact integer prefix sums. Normalized float probabilities can
        # accumulate rounding error so the last bucket ends below 1.0 and a
        # draw matches no bucket; integers cannot drift.
        self.pre = list(accumulate(w))
        self.total = self.pre[-1] if self.pre else 0
        if self.total <= 0:
            raise ValueError("weights must have a positive sum")

    def pickIndex(self) -> int:
        """Return a random index, weighted by ``w``.

        Draws an integer ticket in ``[1, total]`` and returns the first index
        whose running total reaches it. Index ``i`` therefore owns exactly
        ``w[i]`` tickets, and zero-weight indices are never returned.
        """
        ticket = random.randint(1, self.total)
        # Linear scan; always finds a match because pre[-1] == total >= ticket.
        return next(i for i, running in enumerate(self.pre) if ticket <= running)
# Version 2: binary-search the prefix sums to speed up the lookup (用二分加快搜索速度).
import random
from typing import List
from bisect import bisect_left
from itertools import accumulate
class Solution:
    """Weighted random index picker backed by prefix sums and binary search.

    ``pickIndex`` returns index ``i`` with probability ``w[i] / sum(w)``.
    """

    def __init__(self, w: List[int]):
        """Store the running totals of ``w`` and their grand total."""
        # prefix[k] == w[0] + ... + w[k]; the first entry is w[0] itself.
        prefix = []
        for weight in w:
            prefix.append(prefix[-1] + weight if prefix else weight)
        self.pre = prefix
        self.total = sum(w)

    def pickIndex(self) -> int:
        """Return a random index, weighted by the constructor's weights."""
        # Each index i owns the integer tickets (pre[i-1], pre[i]]; bisect_left
        # finds the leftmost owner, so zero-weight slots are skipped.
        ticket = random.randint(1, self.total)
        owner = bisect_left(self.pre, ticket)
        return owner
if __name__ == "__main__":
    # Demo only; guarded so that importing this module has no side effects.
    # With weights [1, 3], index 1 should be picked about 75% of the time.
    solution = Solution([1, 3])
    print(solution.pickIndex())