pytorch简单实现dropout


def dropout(X,drop_prob):
X=X.float()//将张量变成浮点数张量

assert 0<=drop_prob<=1//drop_prob不满足0-1则终止程序

keep_prob=1-drop_prob//对未丢弃的函数进行拉伸

if keep_prob==0:

  return torch.zeros_like(X)//返回和X大小相同的全0矩阵

mask=(torch.randn(X.shape)<keep_prob).float()//如果该矩阵的元素小于keep_prob的值返回Fasle大于返回True用float让布尔值变为浮点数

return mask*x/keep_prob//让这俩个矩阵进行点对点乘积,再除以keep_prob。

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM