sgd學習率選擇問題


關於使用SGD時如何選擇初始的學習率(這里SGD是指帶動量的SGD,momentum=0.9):

訓練一個epoch,把學習率從一個較小的值(10-8)上升到一個較大的值(10),畫出學習率(取log)和經過平滑后的loss的曲線,根據曲線來選擇合適的初始學習率。

從上圖可以看出學習率和loss之間的關系,最曲線的最低點的學習率已經有了使loss上升的趨勢,曲線的最低點不選。最低點左邊的點都是可供選擇的點,但是選擇太小的學習率會導致收斂的速度過慢,所以根據上圖我們可以選擇0.01(10-2)為初始的學習率。

關於學習率的調整策略,在使用SGD時不建議使用指數型連續下降的調節方法,建議使用階梯式調節學習率的方法。每隔一定數量的epoch學習率調節為之前的0.1倍(根據自己實際任務調節每個階段迭代epoch的數量)。

如果不想使用上述方法,這里提供幾個經驗值供選擇,fine-tune模型初始學習率可設置為0.01,從頭開始訓練模型學習率可設置為0.1(僅供參考)。

供參考的尋找初始學習率的pytorch代碼(根據自己的任務進行修改):

def find_lr(init_value = 1e-8, final_value=10., beta = 0.98):
    num = len(train_loader)-1
    mult = (final_value / init_value) ** (1/num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
    batch_num = 0
    losses = []
    log_lrs = []
    for data in train_loader:
        batch_num += 1
        #As before, get the loss for this mini-batch of inputs/outputs
        inputs,labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        #Compute the smoothed loss
        avg_loss = beta * avg_loss + (1-beta) *loss.data[0]
        smoothed_loss = avg_loss / (1 - beta**batch_num)
        #Stop if the loss is exploding
        if batch_num > 1 and smoothed_loss > 4 * best_loss:
            return log_lrs, losses
        #Record the best loss
        if smoothed_loss < best_loss or batch_num==1:
            best_loss = smoothed_loss
        #Store the values
        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))
        #Do the SGD step
        loss.backward()
        optimizer.step()
        #Update the lr for the next step
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr
    return log_lrs, losses
參考論文《Cyclical Learning Rates for Training Neural Networks》
和博客https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM