pytorch簡單實現dropout


def dropout(X,drop_prob):
X=X.float()//將張量變成浮點數張量

assert 0<=drop_prob<=1//drop_prob不滿足0-1則終止程序

keep_prob=1-drop_prob//對未丟棄的函數進行拉伸

if keep_prob==0:

  return torch.zeros_like(X)//返回和X大小相同的全0矩陣

mask=(torch.randn(X.shape)<keep_prob).float()//如果該矩陣的元素小於keep_prob的值返回Fasle大於返回True用float讓布爾值變為浮點數

return mask*x/keep_prob//讓這倆個矩陣進行點對點乘積,再除以keep_prob。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM