Pytorch打印查看模型參數總數


1.用法:
G_skeleton, D_skeleton, start_epoch = self.build_models()
optimizerG_skeleton, optimizerD_skeleton = self.define_optimizers(G_skeleton, D_skeleton)
criterion = nn.BCELoss()

其中G_skeleton D_skeleton是我們用到的模型。使用以下代碼打印參數總數:

# 打印G和D的總參數數量
print("Total number of param in Generator is ", sum(x.numel() for x in G_skeleton.parameters()))
print("Total number of param in Discriminator is ", sum(x.numel() for x in D_skeleton.parameters()))

2.解析:

my_model.parameters() :用來返回模型中的參數

numel():獲取tensor中一共包含多少個元素
例:image

sum():python內置函數,對元組或列表求和


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM