How to use np.random.choice


While reading the RL source code from 莫煩Python (Morvan), I noticed that his DDPG replay memory, Memory, is implemented like this:

import numpy as np

class Memory(object):
    def __init__(self, capacity, dims):
        self.capacity = capacity
        self.data = np.zeros((capacity, dims))  # one row per stored transition
        self.pointer = 0  # total number of transitions stored so far

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, a, [r], s_))
        index = self.pointer % self.capacity  # replace the old memory with new memory
        self.data[index, :] = transition
        self.pointer += 1

    def sample(self, n):
        assert self.pointer >= self.capacity, 'Memory has not been fulfilled'
        indices = np.random.choice(self.capacity, size=n)
        return self.data[indices, :]

The sample method uses assert to require pointer >= capacity, which means the Memory must be full before any learning can happen.

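I'm designing a scheme in which the memory is initially stocked with good transitions (i.e., those with relatively high reward). Waiting until the memory is full before learning seems wasteful: once it fills up, these good transitions are soon overwritten by poor ones, and there may not even be enough good transitions to fill the Memory in the first place, so the agent cannot learn effectively from the good experience.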
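A minimal sketch of one way to allow learning before the memory is full, assuming we simply sample from the rows that have been written so far. The class name PartialMemory is hypothetical; it copies the Memory class above and only replaces the assert with min(pointer, capacity). Drawing n rows from fewer than n stored transitions works only because np.random.choice samples with replacement by default (replace=True).

```python
import numpy as np


class PartialMemory(object):
    """Hypothetical variant of Memory that can sample before it is full."""

    def __init__(self, capacity, dims):
        self.capacity = capacity
        self.data = np.zeros((capacity, dims))
        self.pointer = 0

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, a, [r], s_))
        index = self.pointer % self.capacity  # overwrite the oldest slot once full
        self.data[index, :] = transition
        self.pointer += 1

    def sample(self, n):
        # Number of rows that actually hold transitions so far
        filled = min(self.pointer, self.capacity)
        assert filled > 0, 'Memory is empty'
        # replace=True (the default) lets n exceed filled
        indices = np.random.choice(filled, size=n)
        return self.data[indices, :]


# Layout per row: s (2 values), a (1), r (1), s_ (2) -> 6 columns
mem = PartialMemory(capacity=100, dims=6)
for t in range(3):
    mem.store_transition([t, t], [t], t * 10.0, [t + 1, t + 1])

batch = mem.sample(8)  # 8 rows drawn from only 3 stored transitions
print(batch.shape)  # (8, 6)
```

Because sampling is with replacement, early batches repeat the few stored rows heavily; whether that is acceptable depends on the scheme.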

This is where np.random.choice comes in. Here is its docstring:

def choice(a, size=None, replace=True, p=None): # real signature unknown; restored from __doc__
    """
    choice(a, size=None, replace=True, p=None)

    Generates a random sample from a given 1-D array

    .. versionadded:: 1.7.0

    Parameters
    -----------
    a : 1-D array-like or int
        If an ndarray, a random sample is generated from its elements.
        If an int, the random sample is generated as if a were np.arange(a)
    size : int or tuple of ints, optional
        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn.  Default is None, in which case a
        single value is returned.
    replace : boolean, optional
        Whether the sample is with or without replacement
    p : 1-D array-like, optional
        The probabilities associated with each entry in a.
        If not given the sample assumes a uniform distribution over all
        entries in a.

    Returns
    --------
    samples : single item or ndarray
        The generated random samples
    """
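To make the size and p parameters concrete, here is a small sketch using a seeded np.random.RandomState (the legacy generator behind np.random.choice). The seed and the arrays are arbitrary illustrative choices.

```python
import numpy as np

rs = np.random.RandomState(0)  # seeded legacy generator, same sampler as np.random.choice

# size as a tuple: the result has that shape (here 2 * 3 = 6 draws from 0..4)
grid = rs.choice(5, size=(2, 3))
print(grid.shape)  # (2, 3)

# p: all probability mass on the last entry, so every draw is 'c'
letters = rs.choice(['a', 'b', 'c'], size=4, p=[0.0, 0.0, 1.0])
print(letters)  # ['c' 'c' 'c' 'c']
```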

The first parameter a is normally an ndarray; if an int is given, numpy samples as if a were np.arange(a), i.e., the integers 0 through a-1.
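A quick check of the int vs. np.arange(a) equivalence. This sketch assumes the legacy RandomState sampler, which (with replace=True and no p) draws random indices first and then indexes into a, so two generators with the same seed should agree. The seed 42 and the array values are arbitrary.

```python
import numpy as np

# Two generators with the same seed produce the same underlying random stream
from_int = np.random.RandomState(42).choice(5, size=8)
from_arange = np.random.RandomState(42).choice(np.arange(5), size=8)
print(np.array_equal(from_int, from_arange))  # True

# With a general array, the same random indices pick out its elements instead
values = np.array([10, 20, 30, 40, 50])
from_values = np.random.RandomState(42).choice(values, size=8)
print(np.array_equal(from_values, values[from_int]))  # True
```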

What matters here is this: when a (we pass an int) is smaller than size, how does numpy draw the samples?

Let's test it:

import numpy as np

samples = np.random.choice(3, 5)
print(samples)

Output:

[2 1 2 1 1]

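So values are drawn repeatedly from np.arange(a), from which we can infer that np.random.choice samples "with replacement". (I didn't read the source; judging from the duplicates, this holds at least when a < size.) The docstring above explains why: replace defaults to True, so by default every draw is made from the full set of a values, whatever the relation between a and size.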
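The replace parameter from the docstring can also be switched off. A sketch of both cases with replace=False (seed chosen arbitrarily): when size <= a each element appears at most once, and when size > a numpy raises ValueError because there are not enough elements to draw from.

```python
import numpy as np

rs = np.random.RandomState(1)

# Without replacement and size <= a: every element appears at most once
unique_draw = rs.choice(10, size=5, replace=False)
print(len(set(unique_draw.tolist())))  # 5

# Without replacement and size > a: not enough elements, so numpy raises
try:
    rs.choice(3, size=5, replace=False)
except ValueError as err:
    print('ValueError:', err)
```

So for the partially filled memory case (a < size), the default replace=True is what makes sampling possible at all.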

Then I tried np.random.choice(5, 5), np.random.choice(10, 5), and so on. Running them a few times shows that samples can indeed contain duplicates even when a >= size:

import numpy as np

samples = np.random.choice(10, 5)
print(samples)

Output:

[3 4 3 4 5]
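How often should duplicates appear? Assuming uniform draws with replacement, the probability that 5 draws from 10 values are all distinct is (10 × 9 × 8 × 7 × 6) / 10^5 = 0.3024, so roughly 70% of calls to np.random.choice(10, 5) should contain at least one duplicate. A small sketch that computes this exactly (math.perm needs Python 3.8+) and checks it by simulation with an arbitrary seed:

```python
import math

import numpy as np

a, size = 10, 5
# Exact probability that 5 draws with replacement from 10 values are all distinct
p_distinct = math.perm(a, size) / a ** size  # 30240 / 100000
p_duplicate = 1 - p_distinct
print(round(p_duplicate, 4))  # 0.6976

# Monte Carlo check using the same sampler as np.random.choice
rs = np.random.RandomState(0)
trials = 20000
hits = sum(len(set(rs.choice(a, size).tolist())) < size for _ in range(trials))
print(hits / trials)  # should be close to 0.70
```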

 

