PyTorch之具體顯存占用分析


PyTorch之具體顯存占用分析

原始文檔:https://www.yuque.com/lart/ugkv9f/cpionp

前言

PyTorch 使用中,由於顯卡顯存是固定的,並且短期內難以進一步提升,所以掌握顯存具體占用的細節有助於我們寫出更加高效的代碼,甚至跑出更好的結果。

所以本文結合 Connolly 的文章 《PyTorch 顯存機制分析》 按照自己的需求進行了修改,同時梳理了 checkpoint 機制使用過程中的顯存變換情況。

分析

直接看代碼。注釋中表明了特定的顯存占用和參數數量。

首先導入相關的包:

import torch
from torch.utils.checkpoint import checkpoint

初始化必要的數據和結構:

initial_usage = torch.cuda.memory_allocated()
print("0", initial_usage)  # 0

# 模型初始化
linear1 = torch.nn.Linear(1024, 1024, bias=False).cuda()
after_init_linear1 = torch.cuda.memory_allocated()
print("1", after_init_linear1 - initial_usage, linear1.weight.numel())  # 4194304 1048576

linear2 = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False), torch.nn.Linear(1024, 1, bias=False)).cuda()
after_init_linear2 = torch.cuda.memory_allocated()
print("2", after_init_linear2 - after_init_linear1, sum([m.weight.numel() for m in linear2]))  # 4198400 1049600

# 輸入定義
inputs = torch.randn(size=(1024, 1024), device="cuda:0")
after_init_inputs = torch.cuda.memory_allocated()
print("3", after_init_inputs - after_init_linear2, inputs.numel())  # 4194304 1048576

第一次迭代:

print("Iter: 0")

# 前向傳播
o = linear1(inputs)
after_linear1 = torch.cuda.memory_allocated()
print("4", after_linear1 - after_init_inputs, o.numel())  # 4194304 1048576

o = checkpoint(linear2, o)
after_linear2 = torch.cuda.memory_allocated()
# 4096 1024 這里使用了checkpoint,可以看到這里並沒有存儲linear2內部的結果,僅包含輸出o
print("5", after_linear2 - after_linear1, o.numel())

"""
在PyTorch中,顯存是按頁為單位進行分配的,這可能是CUDA設備的限制。
就算我們只想申請4字節的顯存,pytorch也會先向CUDA設備申請2MB的顯存到自己的cache區中,
然后pytorch再為我們分配512字節或者1024字節的空間。
這個在使用torch.cuda.memory_allocated()的時候可以看出來512字節;
用torch.cuda.memory_cached()可以看出向CUDA申請的2MB。
"""
loss = sum(o)
after_loss = torch.cuda.memory_allocated()
# 16785920 512
print("6", after_loss, after_loss - after_linear2)

# 后向傳播
"""
后向傳播會將模型的中間激活值給消耗並釋放掉掉,並為每一個模型中的參數計算其對應的梯度。
在第一次執行的時候,會為模型參數(即葉子結點)分配對應的用來存儲梯度的空間。
所以第一次之后,僅有中間激活值空間在變換。
"""
loss.backward()
after_backward = torch.cuda.memory_allocated()
# 20984320 4198400=-4194304(釋放linear1輸出的o)+4194304(申請linear1權重對應的梯度)+4198400(申請linear2權重對應的梯度)
# 由於checkpoint的使用,所以linear2沒有存儲中間激活值,但是保留了最終的激活值,因為變量o對其引用依然在,所以linear2的輸出未被釋放。
# linear1本身不涉及到中間激活值,而其輸出則由於變量o指向了新的內存,所以會被自動回收。
print("7", after_backward, after_backward - after_loss)

第二次迭代:

print("Iter: 1")

# 前向傳播
o = linear1(inputs)
after_linear1 = torch.cuda.memory_allocated()
print("8", after_linear1 - after_backward, o.numel())  # 4190208 1048576

o = checkpoint(linear2, o)
after_linear2 = torch.cuda.memory_allocated()
# 4096 1024
print("9", after_linear2 - after_linear1, o.numel())

"""
因為前一次計算的loss的引用還在,所以這里沒有再新申請空間。
"""
loss = sum(o)
after_loss = torch.cuda.memory_allocated()
print("10", after_loss, after_loss - after_linear2)  # 25178624 0

# 后向傳播
loss.backward()
after_backward = torch.cuda.memory_allocated()
# 20984320 -4194304
# 這減去部分的恰好等於中間激活值的占用:-4190208(linear1的輸出o)-4096(linear2輸出o)
# 這里的linaer2使用了checkpoint,則不存linear2中間特征的額外占用,因為這部分是在運算內部申請並實時釋放的
print("11", after_backward, after_backward - after_loss)

第三次迭代:

del loss  # 用於驗證loss對應的內存的回收情況

print("Iter: 2")

# 前向傳播
o = linear1(inputs)
after_linear1 = torch.cuda.memory_allocated()
print("12", after_linear1 - after_backward, o.numel())  # 4190208 1048576

o = linear2(o)
after_linear2 = torch.cuda.memory_allocated()
# 4198400=1024*1024*4(linear2的中間特征)+1024*4(linear2輸出o) 1024
print("13", after_linear2 - after_linear1, o.numel())

"""
在前一次計算后,del loss的話,可以看到這里會申請512字節的空間
"""
loss = sum(o)
after_loss = torch.cuda.memory_allocated()
print("14", after_loss, after_loss - after_linear2)  # 29372928 512

# 后向傳播
loss.backward()
after_backward = torch.cuda.memory_allocated()
# 20984320 -8388608
# 這減去部分的恰好等於中間激活值的占用:-4190208(linear1的輸出o)-4194304(1024*1024*4(linear2中間特征))-4096(linear2輸出o)
print("15", after_backward, after_backward - after_loss)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM