一般DQN中的經驗池類,都類似於下面這段代碼。
import random
from collections import namedtuple, deque
Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward'))


# Experience replay memory class
class ReplayMemory(object):
    """Fixed-capacity experience replay memory for DQN.

    Transitions are kept in ``self.memory``, a list used as a ring buffer:
    until ``capacity`` items are stored, new transitions are appended; after
    that, each push overwrites the slot at ``self.position``, which holds the
    oldest stored transition.
    """

    def __init__(self, capacity):
        """Create an empty replay memory.

        Args:
            capacity: maximum number of transitions to keep; must be > 0.

        Raises:
            ValueError: if ``capacity`` is not positive. Without this check a
                zero capacity would only fail later, inside ``push``.
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity!r}")
        self.capacity = capacity  # maximum number of stored transitions
        self.memory = []  # stored Transition objects
        self.position = 0  # slot the next push writes to

    def push(self, *args):
        """Store one transition, built as ``Transition(*args)``.

        ``args`` are (state, next_state, action, reward), in that order. Once
        the memory is full, the oldest transition is overwritten.

        Raises:
            TypeError: if ``args`` does not match the four Transition fields;
                the memory is left unchanged in that case.
        """
        # Build the record before touching the storage so that a failed call
        # cannot leave a placeholder entry behind in self.memory.
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            # Still filling up: position == len(self.memory) at this point.
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Randomly draw ``batch_size`` distinct transitions and regroup them.

        The sampled transitions stay in the memory: ``random.sample`` does not
        modify its population.

        Args:
            batch_size: number of transitions to draw.

        Returns:
            One Transition whose fields are tuples of length ``batch_size``:
            ``state`` holds every sampled state, ``next_state`` every sampled
            next state, and so on. ``zip(*transitions)`` performs this
            transpose from "batch_size records of 4 fields" into
            "4 fields of batch_size values".

        Raises:
            ValueError: from ``random.sample`` if ``batch_size`` is negative
                or larger than ``len(self)``.
        """
        transitions = random.sample(self.memory, batch_size)
        if not transitions:
            # zip() over nothing yields no columns, so Transition(*zip()) would
            # raise TypeError; return one empty column per field instead.
            return Transition._make(() for _ in Transition._fields)
        return Transition(*zip(*transitions))

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)
我對Python不太熟悉，這段代碼裡有兩點讓我比較困惑：一是namedtuple()方法，二是sample方法的倒數第二行為什麼要這樣處理。
第一點，namedtuple()是一個工廠函數，它會創建並返回一個繼承自tuple的子類。用這個子類創建的對象和普通tuple一樣可以按下標訪問、可以拆包，而且還可以通過字段名訪問各個元素（例如transition.state）。
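第二點，也就是sample方法中的倒數第二行，這裡進行了一個轉換：zip(*transitions)先把batch_size個四元組拆包，再按位置重新組合（相當於矩陣轉置），得到四個元組，每個元組一共有batch_size項，分別收集了所有樣本的state、next_state、action和reward；最後Transition(...)再把這四個元組包裝成一個Transition對象。下面放一個程序來解釋一下。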
import random
from collections import namedtuple


def main(batch_size=3):
    """Show what ``Transition(*zip(*t))`` does to a randomly sampled batch.

    Samples ``batch_size`` of five hand-made transitions, then prints the
    sampled transitions, the output of ``zip(*t)``, and that output wrapped
    in a ``Transition``.

    Args:
        batch_size: how many transitions to sample; must be between 1 and 5
            (``random.sample`` rejects larger values, and an empty sample
            leaves ``zip(*t)`` with nothing to pass to ``Transition``).

    Returns:
        The regrouped ``Transition`` whose fields are tuples of length
        ``batch_size``.
    """
    Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward'))
    pool = [
        Transition(state=1, next_state=2, action=3, reward=4),
        Transition(state=11, next_state=12, action=13, reward=14),
        Transition(state=21, next_state=22, action=23, reward=24),
        Transition(state=31, next_state=32, action=33, reward=34),
        Transition(state=41, next_state=42, action=43, reward=44),
    ]
    # Randomly draw batch_size transitions from the pool
    t = random.sample(pool, batch_size)
    print("隨機抽取的batch_size個四元祖是:")
    for transition in t:
        print(transition)
    print()
    # Unpack t: zip(*t) groups the i-th field of every sampled transition
    print("將四元組進行解壓后是:")
    print(*zip(*t))
    print()
    # Unpack t and wrap the result in a Transition:
    # batch_size 4-tuples become 4 tuples of batch_size items each
    print("將四元組進行解壓后再進行Transition轉換后是:")
    batch = Transition(*zip(*t))
    print(batch)
    return batch


if __name__ == '__main__':
    main()
輸出結果（由於random.sample是隨機抽取，每次運行的結果可能不同）:
隨機抽取的batch_size個四元祖是:
Transition(state=21, next_state=22, action=23, reward=24)
Transition(state=11, next_state=12, action=13, reward=14)
Transition(state=41, next_state=42, action=43, reward=44)
將四元組進行解壓后是:
(21, 11, 41) (22, 12, 42) (23, 13, 43) (24, 14, 44)
將四元組進行解壓后再進行Transition轉換后是:
Transition(state=(21, 11, 41), next_state=(22, 12, 42), action=(23, 13, 43), reward=(24, 14, 44))