import torch
from torch.autograd import Variable
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import display
torch.manual_seed(2019)

def get_fake_data(batch_size=16):
    # generate noisy samples of y = 2x + noise; x and y are 2-D tensors of shape (batch_size, 1)
    x = torch.randn(batch_size, 1) * 20
    y = x * 2 + (1 + torch.rand(batch_size, 1)) * 3
    return x, y
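Since torch.rand draws from U[0, 1), the noise term (1 + torch.rand(...)) * 3 has mean 4.5, so the data roughly follow y = 2x + 4.5 and the learned w and b should converge toward 2 and 4.5 respectively.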
x_train, y_train = get_fake_data()
# squeeze() flattens the (batch_size, 1) tensors to 1-D before plotting
plt.scatter(x_train.squeeze().numpy(), y_train.squeeze().numpy())
w = Variable(torch.rand(1, 1), requires_grad=True)
b = Variable(torch.zeros(1, 1), requires_grad=True)
lr = 1e-6  # the learning rate must not be set too large, otherwise the gradients explode
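To see why the learning rate has to be this small, note that x is drawn from N(0, 20²), so the x² factors in the gradient of w are on the order of several hundred per sample and the summed batch gradient easily reaches the thousands. A quick illustrative check (the *_probe names are only for this sketch, not part of the book's code):

x_probe, y_probe = get_fake_data()
w_probe = torch.rand(1, 1, requires_grad=True)
b_probe = torch.zeros(1, 1, requires_grad=True)
loss_probe = (0.5 * (x_probe.mm(w_probe) + b_probe.expand_as(y_probe) - y_probe) ** 2).sum()
loss_probe.backward()
# w's gradient is usually in the thousands, so lr around 1e-6 keeps |lr * grad| well below 1
print(w_probe.grad.abs().item(), b_probe.grad.abs().item())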
for i in range(100000):
    x_train, y_train = get_fake_data()
    x_train, y_train = Variable(x_train), Variable(y_train)

    # forward pass: y_pred = x * w + b
    y_pred = x_train.mm(w) + b.expand_as(y_train)
    loss = 0.5 * (y_pred - y_train) ** 2
    loss = loss.sum()

    # backward pass and manual gradient-descent update
    loss.backward()
    w.data.sub_(lr * w.grad.data)
    b.data.sub_(lr * b.grad.data)
    w.grad.data.zero_()
    b.grad.data.zero_()

    if i % 1000 == 0:
        # redraw the fitted line against a fresh batch of data
        display.clear_output(wait=True)
        x_test = torch.arange(0, 20).view(-1, 1).float()
        y_test = x_test.mm(w.data) + b.data.expand_as(x_test)
        plt.plot(x_test.numpy(), y_test.numpy())

        x_train, y_train = get_fake_data(batch_size=20)
        plt.scatter(x_train.numpy(), y_train.numpy())

        plt.xlim(0, 20)
        plt.ylim(0, 41)
        plt.show()
        plt.pause(0.5)

print(w.data.squeeze().item(), b.data.squeeze().item())
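In PyTorch 1.0 the Variable wrapper is deprecated: a plain tensor created with requires_grad=True behaves the same way, and the manual w.data/b.data updates can be delegated to an optimizer. Below is a minimal alternative sketch of the same regression using nn.Linear and optim.SGD (it reuses the get_fake_data defined above and is not the book's code):

import torch
from torch import nn, optim

torch.manual_seed(2019)

model = nn.Linear(1, 1)  # y = w*x + b
optimizer = optim.SGD(model.parameters(), lr=1e-6)
criterion = nn.MSELoss(reduction='sum')

for i in range(100000):
    x_train, y_train = get_fake_data()
    y_pred = model(x_train)
    loss = 0.5 * criterion(y_pred, y_train)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(model.weight.item(), model.bias.item())

Here optimizer.zero_grad() plays the role of the explicit grad.data.zero_() calls, and optimizer.step() applies the same w -= lr * grad update.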
Final result:
The code is from 《深度學習框架PyTorch:入門與實踐》, and the environment is PyTorch 1.0 + Jupyter.