水塘抽样(Reservoir sampling)

题目:给出一个数据流,这个数据流的长度很大或者未知。并且对该数据流中的数据只能访问一次。请写出一个随机选择算法,使得数据流中所有数据被选中的概率相等。

  这个问题的扩展就是:如何从未知或者很大样本空间随机的取k个数?或者说,数据流长度为N行,要随机抽取k行,则每一行被抽取的概率为k/N。

  这个问题也就是在大数据流中的随机抽样问题:当内存无法加载全部数据时,如何从包含未知大小的数据流中随机选取k个数据,并且要保证每个数据被抽取到的概率相等。

k = 1 k=1 k=1

首先考虑最简单的情况,当 k = 1 k=1 k=1时,如何选取:

  • 假设数据流含有N个数据,要保证每条数据被抽取到的概率相等,那么每个数被抽取的概率应该为 1 N \frac{1}{N} N1
    • 遇到第1个数 n 1 n_1 n1的时候,保留它, p ( n 1 ) = 1 p(n_1)=1 p(n1)=1
    • 遇到第2个数 n 2 n_2 n2的时候,以 1 2 \frac{1}{2} 21的概率保留它,那么 p ( n 1 ) = 1 × 1 2 = 1 2 p(n_1)=1×\frac{1}{2}=\frac{1}{2} p(n1)=1×21=21 p ( n 2 ) = 1 2 p(n_2)=\frac{1}{2} p(n2)=21
    • 遇到第3个数 n 3 n_3 n3的时候,以 1 3 \frac{1}{3} 31的概率保留它,那么 p ( n 1 ) = p ( n 2 ) = 1 2 × ( 1 1 3 ) = 1 3 p ( n 3 ) = 1 3 p(n_1)=p(n_2)=\frac{1}{2}×(1-\frac{1}{3})=\frac{1}{3},p(n_3)=\frac{1}{3} p(n1)=p(n2)=21×(131)=31p(n3)=31
    • 遇到第i个数 n i n_i ni的时候,以 1 i \frac{1}{i} i1的概率保留它,那么 p ( n 1 ) = p ( n 2 ) = p ( n 3 ) = . . . = p ( n i 1 ) = 1 i 1 × ( 1 1 i ) = 1 i p ( n i ) = 1 i p(n_1)=p(n_2)=p(n_3)=...=p(n_{i-1})=\frac{1}{i-1}×(1-\frac{1}{i})=\frac{1}{i},p(n_i)=\frac{1}{i} p(n1)=p(n2)=p(n3)=...=p(ni1)=i11×(1i1)=i1p(ni)=i1

  通过以上规律可以看出,对于 k = 1 k=1 k=1的情况,数据流中第i个数被保留的概率为 1 i \frac{1}{i} i1。只要采取这种策略,只需要遍历一遍数据流就可以得到采样值,并且保证所有数据被选中的概率均为 1 N \frac{1}{N} N1

k > 1 k>1 k>1

对于 k > 1 k>1 k>1的情况,我们可以采取类似的策略:

  • 假设数据流中含有N个数据,要保证每条数据被抽取到的概率相等,那么每个数被抽取的概率必然是 k N \frac{k}{N} Nk
    • 对于前k个数 n 1 , n 2 , . . . , n k n_1,n_2,...,n_k n1,n2,...,nk,我们保留下来,则 p ( n 1 ) = p ( n 2 ) = . . . = p ( n k ) = 1 p(n_1)=p(n_2)=...=p(n_k)=1 p(n1)=p(n2)=...=p(nk)=1(下面连等采用 p ( n 1 k ) p(n_{1-k}) p(n1k)的形式
    • 对于第k+1个数 n k + 1 n_{k+1} nk+1,以 k k + 1 \frac{k}{k+1} k+1k的概率保留它(这里只是指本次保留下来),那么前k个数中的 n r ( r 1 k ) n_r(r∈1-k) nr(r1k)被保留的概率可以这样表示: p ( n r ) = p ( n r ) × ( p ( n k + 1 ) + p ( n k + 1 ) × p ( n r ) ) p(n_r被保留)=p(上一轮n_r被保留)×(p(n_{k+1}被丢弃)+p(n_{k+1}被保留)×p(n_r未被替换)) p(nr)=p(nr)×(p(nk+1)+p(nk+1)×p(nr)),即 p 1 k = 1 k + 1 + k k + 1 × k 1 k = k k + 1 p_{1-k}=\frac{1}{k+1}+\frac{k}{k+1}×\frac{k-1}{k}=\frac{k}{k+1} p1k=k+11+k+1k×kk1=k+1k
    • 对于第k+2个数 n k + 2 n_{k+2} nk+2,以 k k + 2 \frac{k}{k+2} k+2k的概率保留它(这里只是指本次保留下来),那么前k+1个被保留下来的数中的 n r ( r 1 k + 1 ) n_r(r∈1-k+1) nr(r1k+1)被保留的概率为: p 1 k = k k + 1 × 2 k + 2 + k k + 1 × k 1 k + 2 p_{1-k}=\frac{k}{k+1}×\frac{2}{k+2}+\frac{k}{k+1}×\frac{k-1}{k+2} p1k=k+1k×k+22+k+1k×k+2k1
    • 对于第i(i>k)个数 n i n_i ni,以 k i \frac{k}{i} ik的概率保留它,前i-1个数中的 n r ( r 1 i 1 ) n_r(r∈1-i-1) nr(r1i1)被保留的概率为: p 1 k = k i 1 × i k i + k i 1 × k 1 i = k i p_{1-k}=\frac{k}{i-1}×\frac{i-k}{i}+\frac{k}{i-1}×\frac{k-1}{i}=\frac{k}{i} p1k=i1k×iik+i1k×ik1=ik

  对于前k个数,全部保留,对于第i(i>k)个数,以 k i \frac{k}{i} ik的概率保留第i个数,并以 1 k \frac{1}{k} k1的概率与前面已选择的k个数中的任意一个替换。

总结

  也就是说,在取第i个数据的时候,生成一个01的随机数p,如果 p < k i p<\frac{k}{i} p<ik,替换池中任意一个为第i个数;当 p > k i p>\frac{k}{i} p>ik,继续保留前面的数。直到数据流结束,返回此k个数。但是为了保证计算准确性,一般是生成一个0i的随机数,跟k相比。

Scala代码实现
import scala.util.Random

/**
 * @author xiaoer
 * @date 2020/1/11 23:53
 */
object ReservoirSampling {
    def main(args: Array[String]): Unit = {
        val s: Array[Int] = Array(1, 2, 3, 4, 5, 6, 7)
        val k: Array[Int] = new Array[Int](3)
        val result: Array[Int] = reservoirSampling(k, s)
        println(result.toBuffer)
    }

    /**
     * 水塘抽样算法
     *
     * @param k 抽样结果
     * @param s 样本总数
     * @return k 样本结果
     */
    def reservoirSampling(k: Array[Int], s: Array[Int]): Array[Int] = {
        // 将前 k 个数据全部抽取
        for (i <- k.indices) {
            k(i) = s(i)
        }

        // k+1 往后的数据
        for (i <- k.length until s.length) {
            val seed: Int = Random.nextInt(i)
            if (seed < k.length) {
                k(seed) = s(i)
            }
        }

        // 返回值
        k
    }
}