蓄水池采樣算法
問題描述分析
采樣問題經常會被遇到,比如:
- 從 100000 份調查報告中抽取 1000 份進行統計。
- 從一本很厚的電話簿中抽取 1000 人進行姓氏統計。
- 從 Google 搜索 "Ken Thompson",從中抽取 100 個結果查看哪些是今年的。
這些都是很基本的采用問題。
既然說到采樣問題,最重要的就是做到公平,也就是保證每個元素被采樣到的概率是相同的。所以可以想到要想實現這樣的算法,就需要擲骰子,也就是隨機數算法。(這里就不具體討論隨機數算法了,假定我們有了一套很成熟的隨機數算法了)
對於第一個問題,還是比較簡單,通過算法生成 \([0, 100000 - 1)\) 間的隨機數 1000 個,並且保證不重復即可。再取出對應的元素即可。
但是對於第二和第三個問題,就有些不同了,我們不知道數據的整體規模有多大。可能有人會想到,我可以先對數據進行一次遍歷,計算出數據的數量 \(N\),然后再按照上述的方法進行采樣即可。這當然可以,但是並不好,畢竟這可能需要花上很多時間。也可以嘗試估算數據的規模,但是這樣得到的采樣數據分布可能並不平均。
算法過程
終於要講到蓄水池采樣算法(Reservoir Sampling)了。先說一下算法的過程:
假設數據序列的規模為 \(n\),需要采樣的數量的為 \(k\)。
首先構建一個可容納 \(k\) 個元素的數組,將序列的前 \(k\) 個元素放入數組中。
然后從第 \(k+1\) 個元素開始,以 \(\frac{k}{n}\) 的概率來決定該元素是否被替換到數組中(數組中的元素被替換的概率是相同的)。 當遍歷完所有元素之后,數組中剩下的元素即為所需采取的樣本。
證明過程
對於第 \(i\) 個數(\(i \le k\))。在 \(k\) 步之前,被選中的概率為 \(1\)。當走到第 \(k + 1\) 步時,被 \(k + 1\) 個元素替換的概率 = \(k + 1\) 個元素被選中的概率 * \(i\) 被選中替換的概率,即為 \(\frac{k}{k + 1} \times \frac{1}{k} = \frac{1}{k + 1}\)。則被保留的概率為 \(1 - \frac{1}{k + 1} = \frac{k}{k + 1}\)。依次類推,不被 \(k + 2\) 個元素替換的概率為 \(1 - \frac{k}{k + 2} \times \frac{1}{k} = \frac{k + 1}{k + 2}\)。則運行到第 \(n\) 步時,被保留的概率 = 被選中的概率 * 不被替換的概率,即:
對於第 \(j\) 個數(\(j > k\))。在第 \(j\) 步被選中的概率為 \(\frac{k}{j}\)。不被 \(j + 1\) 個元素替換的概率為 \(1 - \frac{k}{j + 1} \times \frac{1}{k} = \frac{j}{j + 1}\)。則運行到第 \(n\) 步時,被保留的概率 = 被選中的概率 * 不被替換的概率,即:
所以對於其中每個元素,被保留的概率都為 \(\frac{k}{n}\).
代碼示例
貼出測試用的示例代碼(Java 實現):
public class ReservoirSamplingTest {
private int[] pool; // 所有數據
private final int N = 100000; // 數據規模
private Random random = new Random();
@Before
public void setUp() throws Exception {
// 初始化
pool = new int[N];
for (int i = 0; i < N; i++) {
pool[i] = i;
}
}
private int[] sampling(int K) {
int[] result = new int[K];
for (int i = 0; i < K; i++) { // 前 K 個元素直接放入數組中
result[i] = pool[i];
}
for (int i = K; i < N; i++) { // K + 1 個元素開始進行概率采樣
int r = random.nextInt(i + 1);
if (r < K) {
result[r] = pool[i];
}
}
return result;
}
@Test
public void test() throws Exception {
for (int i : sampling(100)) {
System.out.println(i);
}
}
}
結果就不貼出來了,畢竟每次運行結果都不一樣。
總結
蓄水池算法適用於對一個不清楚規模的數據集進行采樣。以前在某個地方看到過一個面試題,說是從一個字符流中進行采樣,最后保留 10 個字符,而並不知道這個流什么時候結束,且須保證每個字符被采樣到的幾率相同。用的就是這個算法。
在高德納的 TAOCP 中有對於這個算法的描述,可以說這是個很精巧的算法。在看到這個算法實現前,很難想到可以通過這樣的一種方式進行采樣。