Pytorch实战学习(三):多维输入


《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

Multiple Dimension Imput

 

1、糖尿病预测案例

 

 

2、输入8个特征变量

3、Mini-batch

N个样本,每个样本有8个特征变量

 

3、输入8维变量,输出1维,代码部分修改

 

 

 

 4、构造神经网络

增加网络层数,增加网络复杂度。

 

Layer1:从8D降到6D

Layer2:从6D降到4D

Layer3:从4D降到1D

!!通过网络,维度增加也是可以的!!

 

 

 5、不同激活函数

 

 

 

 

6、代码实现

import torch
import numpy as np

## 载入数据集,delimiter--分隔符
xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
#从numpy中生成Tensor
x_data = torch.from_numpy(xy[:, :-1])
y_data = torch.from_numpy(xy[:, [-1]])



##Design Model

##构造类,继承torch.nn.Module类
class Model(torch.nn.Module):
    ## 构造函数,初始化对象
    def __init__(self):
        ##super调用父类
        super(Model, self).__init__()
        ##构造三层神经网络
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        ##激活函数,进行非线性变换
        self.sigmoid = torch.nn.Sigmoid()
        
    ## 构造函数,前馈运算
    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x
    
# =============================================================================
#     # 激活函数,进行非线性变换
#         self.activate = torch.nn.ReLU()
#         
#     # 构造函数,前馈运算
#     def forward(self, x):
#         x = self.activate(self.linear1(x))
#         x = self.activate(self.linear2(x))
#         #最后一层为了保证输出结果(概率)在[0,1],要用sigmoid
#         x = self.sigmoid(self.linear3(x))
#         return x
# =============================================================================
    
model = Model()

##Construct Loss and Optimizer

##损失函数,传入y和y_pred,size_average--是否取平均
criterion = torch.nn.BCELoss(size_average = True)

##优化器,model.parameters()找出模型所有的参数,Lr--学习率
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


## Training cycle

for epoch in range(100):
    ##前向传播
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())
    
    ##梯度归零
    optimizer.zero_grad()
    ##反向传播
    loss.backward()
    ##更新
    optimizer.step()

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM