最近一個月的時間,基本上都在加班加點的寫業務,在寫代碼的時候,也遇到了一個有趣的問題,值得記錄一下。
簡單來說,需求是從一個字典(python dict)中隨機選出K個滿足條件的key。代碼如下(python2.7):
1 def choose_items(item_dict, K, filter): 2 '''item_dict = {id:info} ''' 3 candidate_ids = [id for id in item_dict if filter(item_dict[id])] 4 if len(candidate_ids) <= K: 5 return set(candidate_ids) 6 else: 7 return set(random.sample(candidate_ids, K))
代碼邏輯很簡單,也能正常工作。但我知道這個函數調用的頻率會很高,len(item_dict)也會比較大,那么這段代碼會不會有效率問題呢。當然,一切都要基於profile,如果確實有問題,那么就需要優化。但首先,我想搞明白的是,我使用了random.sample這個函數,這個函數的時間復雜度如何呢。另外,我也經常會使用random.shuffle函數,也就想一並搞清楚。
本文記錄對shuffle,sample兩個算法的思考,參考的是python2.7.3中的random模塊。當然,這兩個算法與語言並不相關。另外,本人對算法研究甚少,認識錯誤之處還請大家不吝賜教。
本文地址:http://www.cnblogs.com/xybaby/p/8280936.html
Shuffle
shuffle的意思就是讓序列亂序,本質上就是讓序列里面的每一個元素等概率的重新分布在序列的任何位置。在使用MP3聽歌(是不是暴露的年齡)的時候,就有兩個功能:shuffle,random,二者的區別在於,前者打亂播放順序,保證所有的歌曲都會播放一遍;而后者每次隨機選擇一首。
Python里面random.shuffle源碼如下:
1 def shuffle(self, x, random=None, int=int): 2 """x, random=random.random -> shuffle list x in place; return None. 3 4 Optional arg random is a 0-argument function returning a random 5 float in [0.0, 1.0); by default, the standard random.random. 6 """ 7 8 if random is None: 9 random = self.random 10 for i in reversed(xrange(1, len(x))): 11 # pick an element in x[:i+1] with which to exchange x[i] 12 j = int(random() * (i+1)) 13 x[i], x[j] = x[j], x[i]
核心的代碼就3行,其實就是非常經典的Fisher–Yates shuffle算法的實現,Fisher–Yates shuffle算法偽碼如下:
-- To shuffle an array a of n elements (indices 0..n-1): for i from n−1 downto 1 do j ← random integer such that 0 ≤ j ≤ i exchange a[j] and a[i]
第一步 即從0到N-1個元素中隨機選擇一個與第N-1個替換
第二步 從0到N-2個元素中隨機選擇一個與第N-2個替換
第k步 從0到N-k個元素中隨機選擇一個與第N-K個替換
要證明算法的正確性也很簡單,即任何一個元素shuffle之后出現在任意位置的概率都是1/N。任意一個元素,放在第N-1個位置的概率是1/N, 放在pos N-2的位置是 (N-1)/N * 1 / (N-1) = 1/N 。需要注意的是,一個元素一旦被交換到了序列的尾部,那么就不會再被選中,這也是算法一目了然的原因。
上面的實現是從后到前的,當然也可以從前到后,即先從0到N-1個元素中隨機選擇一個與第0個交換,然后從1到N-1個元素中隨機選擇一個與第1個交換 。。。只不過寫代碼的時候會稍微麻煩一點點,wiki上也有相應的偽碼。
但是,我也看到網上有這么一種實現:
1 void get_rand_number(int array[], int length) 2 { 3 int index; 4 int value; 5 int median; 6 7 if(NULL == array || 0 == length) 8 return ; 9 10 /* 每次發牌的時候任意分配待交換的數據 */ 11 for(index = 0; index < length; index ++){ 12 value = rand() % length; 13 14 median = array[index]; 15 array[index] = array[value]; 16 array[value] = median; 17 } 18 }
與Fisher–Yates shuffle算法的區別在於,上面的算法每次都是從整個序列中選擇一個元素作為被交換的元素,即先從整個序列選擇一個元素與第0個元素交換,然后再從整個序列選擇一個元素與第1個元素交換.。。。這個直覺就有點問題,比如一個元素(X)第一步就放到了第0個位置,但是之后有可能被交換到其他位置,以后X就再也不會回到第0個元素,當然,X也可能再第二步 第三步被交換到第0個位置。
但要證明該算法有問題似乎不是這么容易,那么首先用事實(數據)說話,於是我用python重寫了上述代碼,並做了測試,代碼如下
1 import random 2 def myshuffle(lst): 3 length = len(lst) 4 for idx in xrange(length): 5 t_idx = random.randint(0, length-1) 6 lst[idx], lst[t_idx] = lst[t_idx], lst[idx] 7 if __name__ == '__main__': 8 random.seed() 9 10 pre_lst = ['a', 'b', 'c'] 11 count = dict((e, {}) for e in pre_lst) 12 TRY = 1000000 13 14 for i in xrange(TRY): 15 lst = pre_lst[:] 16 myshuffle(lst) 17 for alpha in pre_lst: 18 idx = lst.index(alpha) 19 count[alpha][idx] = count[alpha].get(idx, 0) + 1 20 21 for alpha, alpha_count in sorted(count.iteritems(), key=lambda e: e[0]): 22 result_lst = [] 23 for k, v in sorted(alpha_count.iteritems(), key=lambda e: e[0]): 24 result_lst.append(round(v * 1.0 / TRY, 3)) 25 print alpha, result_lst
運算的結果是:
('a', [0.333, 0.334, 0.333])('b', [0.371, 0.296, 0.333])('c', [0.296, 0.37, 0.334])
如果將pre-list改成 pre_list = ['a', 'b', 'c', 'd', 'e'],那么輸出結果是:
('a', [0.2, 0.2, 0.2, 0.2, 0.199])('b', [0.242, 0.18, 0.186, 0.191, 0.2])('c', [0.209, 0.23, 0.175, 0.186, 0.2])('d', [0.184, 0.205, 0.23, 0.18, 0.2])('e', [0.164, 0.184, 0.209, 0.242, 0.2])
這里稍微解釋一下輸出,每一行是字母在shuffle之后,出現在每一個位置的概率。比如元素‘e',在pre_list的位置是4(從0開始),shuffle之后,出現在第0個位置的統計概率為0.164,出現在第1個位置的統計概率是0.184,顯然不是等概率的。
假設P[i][j]是原來序列種第i個元素shuffle之后移動到第j個位置的概率,那么這個公式怎么推導昵?我嘗試過,不過沒有推出來。
在stackoverflow上,我提問了這個問題,並沒有得到直接的答案,不過有一個回答很有意思,指出了從理論上這個shuffle算法就不可能是正確的
This algorithm has
n^n
different ways to go through the loop (n
iterations picking one ofn
indexes randomly), each equally likely way through the loop producing one ofn!
possible permutations. Butn^n
is almost never evenly divisible byn!
. Therefore, this algorithm cannot produce an even distribution.
就是說,myshuffle由N^N種可能,但按照排隊組合,N個元素由N的階乘種排列方式。N^N不能整除N的階乘,所以不可能是等概率的。
歡迎大家幫忙推倒這個公式,我自己只能推出P[N-1][0], P[N-2][0],真的頭大。
Sample
Python中random.sample的document是這樣的:
random.sample(population, k)
Return a k length list of unique elements chosen from the population sequence. Used for random sampling without replacement.
上面的document並不完整,不過也可以看出,是從序列(sequence)中隨機選擇k個元素,返回的是一個新的list,原來的序列不受影響。
但是從document中看不出時間復雜度問題。所以還是得看源碼:
1 def sample(self, population, k): 2 """Chooses k unique random elements from a population sequence. 3 4 Returns a new list containing elements from the population while 5 leaving the original population unchanged. The resulting list is 6 in selection order so that all sub-slices will also be valid random 7 samples. This allows raffle winners (the sample) to be partitioned 8 into grand prize and second place winners (the subslices). 9 10 Members of the population need not be hashable or unique. If the 11 population contains repeats, then each occurrence is a possible 12 selection in the sample. 13 14 To choose a sample in a range of integers, use xrange as an argument. 15 This is especially fast and space efficient for sampling from a 16 large population: sample(xrange(10000000), 60) 17 """ 18 19 # Sampling without replacement entails tracking either potential 20 # selections (the pool) in a list or previous selections in a set. 21 22 # When the number of selections is small compared to the 23 # population, then tracking selections is efficient, requiring 24 # only a small set and an occasional reselection. For 25 # a larger number of selections, the pool tracking method is 26 # preferred since the list takes less space than the 27 # set and it doesn't suffer from frequent reselections. 28 29 n = len(population) 30 if not 0 <= k <= n: 31 raise ValueError("sample larger than population") 32 random = self.random 33 _int = int 34 result = [None] * k 35 setsize = 21 # size of a small set minus size of an empty list 36 if k > 5: 37 setsize += 4 ** _ceil(_log(k * 3, 4)) # table size for big sets 38 if n <= setsize or hasattr(population, "keys"): 39 # An n-length list is smaller than a k-length set, or this is a 40 # mapping type so the other algorithm wouldn't work. 41 pool = list(population) 42 for i in xrange(k): # invariant: non-selected at [0,n-i) 43 j = _int(random() * (n-i)) 44 result[i] = pool[j] 45 pool[j] = pool[n-i-1] # move non-selected item into vacancy 46 else: 47 try: 48 selected = set() 49 selected_add = selected.add 50 for i in xrange(k): 51 j = _int(random() * n) 52 while j in selected: 53 j = _int(random() * n) 54 selected_add(j) 55 result[i] = population[j] 56 except (TypeError, KeyError): # handle (at least) sets 57 if isinstance(population, list): 58 raise 59 return self.sample(tuple(population), k) 60 return result
咋眼一看,不同的情況下有兩種方案(對應38行的if、46行的else),一種方案類似shuffle,復雜度是O(K);而另一種方案,看代碼的話,復雜度是O(NlogN) (后面會說明,事實並非如此)
我當時就驚呆了,這個時間復雜度不能接受吧,在Time complexity of random.sample中,也有網友說是O(NlogN)。這個我是不能接受的,這個是官方模塊,怎么可能這么不給力,那我自然想看明白這段代碼。
Sample的各種實現
在這之前,不妨考慮一下,如果要自己實現這個sample函數,那么有哪些方法呢。
我們首先放寬sample的定義,就是從有N個元素的序列中隨機取出K個元素,不考慮是否影響原序列
第一種,隨機抽取且不放回
跟抽牌一樣,隨機從序列中取出一個元素,同時從原序列中刪除,那么不難驗證每個元素被取出的概率都是K/N(N是序列長度),滿足Sample需求。
若不考慮元素從列表中刪除的代價,那么時間復雜度是O(K)。但問題也很明顯,就是會修改原序列
第二種,隨機抽取且放回
除了記錄所有被選擇的元素,還需要維護被選擇的元素在序列中的位置(selected_pos_set)。隨機從序列中取出一個元素,如果抽取到的元素的位置在selected_pos_set中,那么重新抽取;否則將新元素的位置放到selected_pos_set中。
不難發現,這個就是python random.sample代碼中第二種實現。
這個算法的好處在於,不影響原序列。
那么時間復雜度呢?在抽取第i個元素的時候,抽取到重復位置元素的概率是(i - 1)/N,那么平均抽取次數就是N/(N - i +1)。那么抽取K個元素的平均抽取測試就是,sum(N/(N - i +1) ), 1 <= i <= K; 等於N(logN - log(N-K+1)) 。當K等於N時,也就是NlogN
第三種,先shuffle整個序列,然后取前K個元素
算法的正確性很容易驗證,時間復雜度是O(N),而且原序列會被修改(亂序也算做修改)
第四種,部分shuffle,得到K個元素就返回
如果了解shuffle算法,那么算法理解還是很容易的。random.sample中第一種方案也是這樣的算法。
單獨實現這個算法的話就是這個樣子的:
1 def sample_with_shuffle(self, population, k): 2 n = len(population) 3 result = [None] * k 4 for i in xrange(k): # invariant: non-selected at [0,n-i) 5 j = int(random.random() * (n-i)) 6 result[i] = population[j] 7 population[j] = population[n-i-1] # move non-selected item into vacancy 8 return result
時間復雜度是O(K),但缺點就是原序列會被改變。
第五種,水塘抽樣算法
水塘抽樣算法(Reservoir_sampling)解決的是 樣本總體很大,無法一次性放進內存;或者是在數據流上的隨機采樣問題。即不管有多少個元素,被選中的K個元素都是等概率的。算法很巧妙,也是非常經典的面試題。
算法偽碼是這樣的:
1 ReservoirSample(S[1..n], R[1..k]) 2 // fill the reservoir array 3 for i = 1 to k 4 R[i] := S[i] 5 6 // replace elements with gradually decreasing probability 7 for i = k+1 to n 8 j := random(1, i) // important: inclusive range 9 if j <= k 10 R[j] := S[i]
算法的時間復雜度是O(N),且不會影響原序列。
回到random.sample
通過上面的思考可見,最低復雜度是O(K),但需要改變原序列。如果不改變原序列,時間復雜度最低為O(N)。
但是如果重新拷貝一份原序列,那么是可以使用部分shuffle,但拷貝操作本身,需要時間與額外的空間。
其實python random.sample這個函數的注釋說明了實現的原因:
# Sampling without replacement entails tracking either potential
# selections (the pool) in a list or previous selections in a set.# When the number of selections is small compared to the
# population, then tracking selections is efficient, requiring
# only a small set and an occasional reselection. For
# a larger number of selections, the pool tracking method is
# preferred since the list takes less space than the
# set and it doesn't suffer from frequent reselections.
當K相對N較小時,那么使用python set記錄已選擇的元素位置,重試的概率也會較小。當K較大時,就用list拷貝原序列。顯然,這是一個 hybrid algorithm實現,不管輸入如何,都能有較好的性能。
因此,算法的實現主要考慮的是額外使用的內存,如果list拷貝原序列內存占用少,那么用部分shuffle;如果set占用內存少,那么使用記錄已選項的辦法。
因此核心的問題,就是對使用的內存的判斷。看代碼,有幾個magic num,其實就其中在這三行:
涉及到兩個magic num:21 與 5, 還有一個公式。
magic 21
代碼中是有注釋的,即21是small 減去 empty list的大小。但是,我並沒有搞懂為啥是21.
對於64位的python:
>>> import sys
>>> sys.getsizeof(set())
232
>>> sys.getsizeof([])
72
可以看到,二者相差160。另外,在Linux下,這個size應該都是8的倍數,所以至今不知道21是咋來的
magic 5
這個比較好理解,新創建的set默認會分配一個長度為8的數組。
當set中的元素超過了容量的2/3,那么會開辟新的存儲空間,因此,所以當set中的元素小於等於5個時,使用默認的小數組,無需額外的空間
公式: 4 ** _ceil(_log(k * 3, 4))
log是取對數,base是4, ceil是向上取整。所以上面的公式,其范圍在[3K, 12K]之間。
為什么有這么公式,應該是來自這段代碼(setobject.c::set_add_entry)
set的容量臨界值是3/2 * len(set), 超過了這個值,那么會分配四倍的空間。那么set分配的容量(N)與元素數目(K)的比例大致是 [3/2, 12/2]。
由於set中,一個setentry包含16個字節(8個字節的元素本身,以及8個字節的hash值),而list中一個元素只占用8個字節。所以當對比set與list的內存消耗是,上述的比例乘以了2.
random.sample有沒有問題
當序列的長度小於K個元素所占用的空間時,使用的是部分shuffle的算法,當然,為了避免修改原序列,做了一個list拷貝。
否則使用隨機抽取且放回的算法,需要注意的是,在這個時候, N的范圍是[3K, 12K],即此時K是不可能趨近於N的,按照之前推導的公式 N(logN - log(N-K+1)), 時間復雜度均為O(K)。
因此,不管序列的長度與K的大小關系如何,時間復雜度都是O(K),且保證使用的內存最少。
這里也吐槽一下,在這個函數的docsting里面,提到的是對sequence進行隨機采樣,沒有提到支持dict set,按照對ABC的理解,collections.Sequence 是不包含dict,set的。但事實上這個函數又是支持這兩個類型的參數的。更令人費解的是,對set類型的參數,是通過捕獲異常之后轉換成tuple來來支持的。
random.sample還有這么一個特性:
The resulting list is in selection order so that all sub-slices will also be valid random samples. This allows raffle winners (the sample) to be partitioned into grand prize and second place winners (the subslices).
就是說,對於隨機采樣的結果,其sub slice也是符合隨機采樣結果的,即sample(population, K)[0, M] === sample(population, M), M<=K。在上面提到的各種sample方法中,水塘抽樣算法是不滿足這個特性的。
總結
本文記錄了在使用python random模塊時的一些思考與測試。搞清楚了random.shuffle, random.sample兩個函數的實現原理與時間復雜度。
不過,還有兩個沒有思考清楚的問題
第一:myshuffle的實現中, p[i][j]的公式推導
第二:random.sample中,21 這個magic num是怎么來的
如果園友知道答案,還望不吝賜教