random.choices 函數
python 官方標准庫 random 中,有個函數 random.choices(population, weights=None, *, cum_weights=None, k=1)
,比起常用的 random.choice(seq)
,這個函數可以指定概率權重和選擇次數。
因為刷題的時候用到了這個函數,題目又對時間復雜度有限制,我就很好奇,然后來分析一下這個函數的時間復雜度。
源碼
def choices(self, population, weights=None, *, cum_weights=None, k=1):
"""Return a k sized list of population elements chosen with replacement.
If the relative weights or cumulative weights are not specified,
the selections are made with equal probability.
"""
random = self.random
n = len(population)
if cum_weights is None:
if weights is None:
floor = _floor
n += 0.0 # convert to float for a small speed improvement
return [population[floor(random() * n)] for i in _repeat(None, k)]
cum_weights = list(_accumulate(weights))
elif weights is not None:
raise TypeError('Cannot specify both weights and cumulative weights')
if len(cum_weights) != n:
raise ValueError('The number of weights does not match the population')
total = cum_weights[-1] + 0.0 # convert to float
if total <= 0.0:
raise ValueError('Total of weights must be greater than zero')
bisect = _bisect
hi = n - 1
return [population[bisect(cum_weights, random() * total, 0, hi)]
for i in _repeat(None, k)]
參數說明
population
: 輸入的待選取序列weights
: 權重序列cum_weights
: 累加的權重序列,相當於weights
的前綴和數組k
: 選取的次數,該函數會返回一個長度為k
的列表
功能說明
參考官方文檔可知,這個函數通過權重隨機選取數字,比如 choices([1, 2], weights=[3, 2])
,相當於使用 choice([1, 1, 1, 2, 2])
,也可以寫成 choices([1, 2], cum_weights=[3, 5])
假設給出了權重(weights
)但是沒有累加權重(cum_weights
):
- 函數內部會把權重累加
cum_weights = list(_accumulate(weights))
; - 使用
random()
函數輸出一個[0.0, 1.0)
區間的數,乘上所有權重的累加和,作為生成的隨機數。權重的累加和也是cum_weights
數組最后一個元素值; - 用二分查找 (標准庫函數:bisect) 在累加序列
cum_weights
中找到隨機數的位置,輸出該位置的數據。
時間復雜度分析
函數共有 2 個出口:
-
weights
和cum_weights
均為None
的情況:return [population[floor(random() * n)] for i in _repeat(None, k)]
時間復雜度:O(k) ,因為
k
為常數,所以也可以認為時間復雜度為 O(1)這種情況和直接使用
choice
沒有差別,所以我就不考慮在最終結果里了。 -
weights
不為None
的情況:return [population[bisect(cum_weights, random() * total, 0, hi)] for i in _repeat(None, k)]
時間復雜度:O(klog(n)),因為
k
為常數,所以也可以認為時間復雜度為 O(log(n)) (注:log(n) 來自二分查找)- 如果
cum_weights
為None
,還需要執行cum_weights = list(_accumulate(weights))
,_accumulate
類似於itertools.accumulate()
,時間復雜度:O(n),與上面的 O(log(n)) 疊加,總時間復雜度為:O(n)
- 如果
所以結論在於用戶有沒有給出累加權重,也就是 cum_weights
數組:
- 如果給出
cum_weights
:O(log(n)) ,精確一點就是 O(klog(n)) ,這個k
就是那個參數k
,是個常數。 - 如果沒有給出:O(n)
所以呢,如果數據規模特別大,還是要謹慎使用這個函數的,尤其是沒有提供 cum_weights
參數的時候。