Python(包括其包Numpy)中包含了了許多概率算法,包括基礎的隨機采樣以及許多經典的概率分布生成。我們這個系列介紹幾個在機器學習中常用的概率函數。先來看最基礎的功能——隨機采樣。
1. random.choice
如果我們只需要從序列里采一個樣本(所有樣本等概率被采),只需要使用random.choice
即可:
import random
res1 = random.choice([0, 1, 2, 3, 4])
print(res1) # 3
2. random.choices
(有放回)
當然,很多時候我們不只需要采一個數,而且我們需要設定序列中每一項被采的概率不同。此時我們可以采用random.random.choices
函數, 該函數用於有放回的(即一個數據項可以被重復采多次)對一個序列進行采樣。其函數原型如下:
random.choices(population, weights=None, *, cum_weights=None, k=1)
population
: 欲采樣的序列
weights
: 每個樣本被賦予的權重(又稱相對權重),決定每個樣本被采的概率,如[10, 0, 30, 60, 0]
cum_weights
: 累積權重,相對權重[10, 0, 30, 60, 0]相當於累積權重[10, 10, 40, 100, 100]
我們從[0, 1, 2, 3, 4]
中按照相對權重采樣3個樣本如下:
res2 = random.choices([0, 1, 2, 3, 4], weights=[10, 0, 30, 60, 0], k=3)
# 注意population不是關鍵字參數,在函數調用時不能寫成population=[0,1,2,3,4]來傳參
# 關於關鍵字參數和位置參數,可以參看我的博客《Python技法2:函數參數的進階用法》https://www.cnblogs.com/orion-orion/p/15647408.html
print(res2) # [3, 3, 2]
從[0, 1, 2, 3, 4]
中按照累積權重采樣3和樣本如下:
res3 = random.choices([0, 1, 2, 3, 4], cum_weights=[10, 10, 40, 100, 100], k=3)
print(res3) # [0, 3, 3]
注意,相對權重weights
和累計權重cum_weights
不能同時傳入,否則會報TypeError
異常'Cannot specify both weights and cumulative weights'
。
3. random.sample
(無放回)
random.sample
是無放回,如果我們需要無放回采樣(即每一項只能采一次),那我們需要使用random.sample
。需要注意的是,如果使用該函數,將無法定義樣本權重。該函數原型如下:
random.sample(population, k, *, counts=None)¶
population
: 欲采樣的序列
k
: 采樣元素個數
counts
: 用於population是可重復集合的情況,定義集合元素的重復次數。sample(['red', 'blue'], counts=[4, 2], k=5)
等價於sample(['red', 'red', 'red', 'red', 'blue', 'blue'], k=5)
我們無放回地對序列[0, 1, 2, 3, 4]
采樣3次如下:
res3 = random.sample([0, 1, 2, 3, 4], k=3)
print(res3) # [3, 2, 1]
注意,這里是依次采樣3次、每次采1個元素,而不是一次性采3個元素。故采出元素和原始序列元素的先后順序不一定相同。
無放回地對可重復集合[0, 1, 1, 2, 2, 3, 3, 4]
采樣3次如下:
res4 = random.sample([0, 1, 2, 3, 4], k=3, counts=[1, 2, 2, 2, 1])
print(res4) # [3, 2, 2]
如果counts
長度和population
序列長度不一致,會拋出異常ValueError
:"The number of counts does not match the population"
。
4.rng.choices
和 rng.sample
還有一種有放回采樣實現方法是我在論文代碼中學習到的。即先定義一個隨機數生成器,再調用隨機數生成器的choices
方法或sample
方法,其使用方法和random.choice
/random.sample
函數相同。
rng_seed = 1234
rng = random.Random(rng_seed)
res5 = rng.choices(
population=[0,1,2,3,4],
weights=[0.1, 0, 0.3, 0.6, 0],
k=3,
)
print(res5) # [3, 3, 0]
res6 = rng.sample(
population=[0, 1, 2, 3, 4],
k=3,
)
print(res6) # [4, 0, 2]
這兩個函數在聯邦學習論文的實現代碼中用來隨機選擇任務節點client
:
def sample_clients(self):
"""
sample a list of clients without repetition
"""
rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
self.rng = random.Random(rng_seed)
if self.sample_with_replacement:
self.sampled_clients = \
self.rng.choices(
population=self.clients,
weights=self.clients_weights,
k=self.n_clients_per_round,
)
else:
self.sampled_clients = self.rng.sample(self.clients, k=self.n_clients_per_round)
5. numpy.random.choices
從序列中按照權重分布采樣也可以采用numpy.random.choice
實現。其函數原型如下:
random.choice(a, size=None, replace=True, p=None)
a
: 1-D array-like or int 如果是1-D array-like,那么樣本會從其元素中抽取。如果是int,那么樣本會從np.arange(a)
中抽取;
size
: int or tuple of ints, optional 為輸出形狀大小,如果給定形狀為\((m, n, k)\),那么\(m\times n\times k\)的樣本會從中抽取。默認為None,即返回一個單一標量。
replace
: boolean, optional 表示采樣是又放回的還是無放回的。若replace=True
,則為又放回采樣(一個值可以被采多次),否則是無放回的(一個值只能被采一次)。
p
: 1-D array-like, optional 表示a
中每一項被采的概率。如果沒有給定,則我們假定a
中各項被采的概率服從均勻分布(即每一項被采的概率相同)。
從[0,1,2,3,4,5]
中重復/不重復采樣3次如下:
import numpy as np
res1 = np.random.choice(5, 3, replace=True)
print(res1) # [1 1 4]
res2 = np.random.choice(5, 3, replace=False)
print(res2) # [2 1 4]
同樣是[0,1,2,3,4,5]
中重復/不重復采樣3次,現在來看我們為每個樣本設定不同概率的情況:
res3 = np.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
print(res3) # [2 3 3]
res4 = np.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
print(res4) # [3 2 0]
6. random.shuffle
和np.random.shuffle
將序列打亂可以使用 random.shuffle
和np.random.shuffle
。兩者都可以作用於可變序列(MutableSequence),后者還可以作用於ndarray
(規定沿第一維打亂)。
import random
import numpy as np
list_1 = [1, 2, 3, 4]
random.shuffle(list_1) # [2, 4, 3, 1]
print(list_1)
list_2 = [1, 2, 3, 4]
np.random.shuffle(list_2) # [4, 1, 2, 3]
print(list_2)
matrix_1 = np.array(
[[1, 2],
[3, 4],
[5, 6]]
)
np.random.shuffle(matrix_1)
print(matrix_1)
# [[5 6]
# [1 2]
# [3 4]]
注意,這兩個函數都是原地(in-place)修改,返回值為None。