def dropout(X,drop_prob):
X=X.float()//將張量變成浮點數張量
assert 0<=drop_prob<=1//drop_prob不滿足0-1則終止程序
keep_prob=1-drop_prob//對未丟棄的函數進行拉伸
if keep_prob==0:
return torch.zeros_like(X)//返回和X大小相同的全0矩陣
mask=(torch.randn(X.shape)<keep_prob).float()//如果該矩陣的元素小於keep_prob的值返回Fasle大於返回True用float讓布爾值變為浮點數
return mask*x/keep_prob//讓這倆個矩陣進行點對點乘積,再除以keep_prob。