Proximal Policy Optimization(PPO)算法 / 2017


Intro

2016年Schulman等人提出了Trust Region Policy Optimization算法。后來他們又發現TRPO算法在scalable(用於大模型和並行實現), data efficient(高效利用采樣數據), robust(同一套超參,在大量不同的env上取得成功)上可以改進,於是作為TRPO的改進版本提出了PPO。

PPO在2017年被Schulman等人提出后就刷新了continous control領域的SOTA記錄,並且成為了OPENAI的default algorithm。雖然現在它已經不是領域的SOTA算法了,但因為PPO易於部署而且迭代過程方差小,訓練較穩定,關鍵是使用方便,所以目前(2020.11)它還是大多數場景下的default algorithm。

PPO造出來前,其他的流行RL算法缺點在哪?

  • DQN pooly understood,而且在很多簡單任務上失敗; 不支持continious control;訓練過程poorly robust。

  • vanilla policy gradient 數據利用效率低,訓練過程poorly robust。

  • trust region policy gradient 算法結構復雜,而且兼容性差,

PPO算法結構設計思想

  • 為actor設計新的損失函數。clipped surrogate objective

  • 采樣得到的數據,在更新agent的時候重復使用。multiple epochs of minibatch updates

基本算法構造

神經網絡架構

在PPO中critic的價值函數是V(s),而不是Q(s,a)。這和DDPG就相反,DDPG中critic的價值函數是Q(s,a)。

PPO的paper並未寫明算法的具體實現。因此我瀏覽github,調查了兩種實現方法。

我和作者都proposed的版本是讓actor和critic參數共享。一個輸入,即obs。兩個輸出,即actor輸出和value輸出。如:

class PPO(nn.Module):
    def __init__(self, num_inputs, num_actions):
        super(PPO, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.linear = nn.Linear(32 * 6 * 6, 512)
        self.critic_linear = nn.Linear(512, 1)
        self.actor_linear = nn.Linear(512, num_actions)
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, nn.init.calculate_gain('relu'))
                # nn.init.xavier_uniform_(module.weight)
                # nn.init.kaiming_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.linear(x.view(x.size(0), -1))
        return self.actor_linear(x), self.critic_linear(x)

GAE(generalized advantage estimator)

一項在Schulman, John, et al. "High-dimensional continuous control using generalized advantage estimation." arXiv preprint arXiv:1506.02438 (2015)中提出的技術。它是針對過去policy gradient系列算法中的returns作了調整,用GAE方法獲取returns。

def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):
    values = values + [next_value]
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * tau * masks[step] * gae
        returns.insert(0, gae + values[step])
    return returns

actor損失函數設計

PPO paper中對比了三個損失函數設計,CPI版本(作為baseline,就是沒有clipping也沒有KL penalty的),Clipping版本,KL penalty版本。
根據實驗結果,Clipping版本的損失函數作為proposed損失函數。

Clipping版本如下:

\(L^{CLIP}(\theta) = E_{t}( min<r_t(\theta)A_{t}, clip(r_t(\theta), 1-\epsilon,1+\epsilon)A_{t}> )\)

\(\1\) \(A_{t}\)

這里的A表示advantage。\(advantage = returns - values\)。returns是計算得到的GAE值,values是模型的一步輸出。

\(\2\) \(r(\theta)_{t}\)

這里的r表示ratio,是新舊策略得到的logit_prob的比值。舊,指從env采樣的時候用的policy。新,指在多步更新agent時的實時policy。

\(r(\theta)_{t} = \frac{\pi(a_{t}|s_{t})}{\pi_{old}(a_{t}|s_{t})}\)

PPO整體損失函數設計

PPO的actor和critic參數共享,用一個loss來同時更新actor和critic。loss設計如下:

\(LOSS = actor_loss + 0.5*critic_loss\)

actor-loss即前面的\(L^{CLIP}(\theta)\)。critic_loss如下:

\(critic_loss = \sum(return - value)\)

這里return就是前面advantage用的return,value是多步更新時實時從agent輸出出來的value。

采樣(sample from environment)與更新(update our agent)

強化學習算法,無外乎兩個東西的交替迭代,即“sample from environment”和“update our agent”。在PPO算法中,采用這樣的時序設計。

LOOP for l rides:
      LOOP for s steps:
            sample from env, select action, according to current agent
            compute advantages for each transition
            compute GAE for each transition
      compute GAE as return for each step
      enough (states, actions, log_probs, returns, advantages) were collected 
      LOOP for p minibatchs:
            compute logit_probs, value, according to current agent
            compute loss 
            update our agent
            

大規模、並行場景下實現PPO算法

./keep


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM