TensorFlow.js之根據數據擬合曲線


  這篇文章中,我們將使用TensorFlow.js來根據數據擬合曲線。即使用多項式產生數據然后再改變其中某些數據(點),然后我們會訓練模型來找到用於產生這些數據的多項式的系數。簡單的說,就是給一些在二維坐標中的散點圖,然后我們建立一個系數未知的多項式,通過TensorFlow.js來訓練模型,最終找到這些未知的系數,讓這個多項式和散點圖擬合。

  

一、運行代碼

  這篇文章關注的是創建模型以及學習模型的系數,完整的代碼在這里可以找到。為了在本地運行,如下所示:

$ git clone https://github.com/tensorflow/tfjs-examples.git
$ cd tfjs-examples/polynomial-regression-core
$ yarn
$ yarn watch

  即首先將核心代碼下載到本地,然后進入polynomial-regression-core(即多項式回歸核心)部分,最后進行yarn安裝並運行。

 

二、輸入數據

  我們的數據在x坐標軸和y坐標軸內,看上去就是將之放在了笛卡爾坐標系,如下所示:

  4

  

  即這個圖是 y = ax3 + bx2 + cx + d得到的,而在上圖中,我們也看到了其真實系數為a=-0.800,b=-0.200,c=0.900,d=0.500,然后這些點是根據真實的點做了一定的偏移。  

  我們的任務就是通過機器學習得到這個函數的系數a、b、c以及d來最好的匹配這些數據。接下來,我們就看看如何通過TensorFlow.js來學習得到這些數據。

 

三、學習步驟

第一步 :設置變量

  首先,我們需要創建一些變量。即開始我們是不知道a、b、c、d的值的,所以先給他們一個隨機數,入戲所示:

const a = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
const c = tf.variable(tf.scalar(Math.random()));
const d = tf.variable(tf.scalar(Math.random()));

 

 

第二步:創建模型

  我們可以通過TensorFlow.js中的鏈式調用操作來實現這個多項式方程  y = ax3 + bx2 + cx + d,下面的代碼就創建了一個 predict 函數,這個函數將x作為輸入,y作為輸出:

function predict(x) {
  // y = a * x ^ 3 + b * x ^ 2 + c * x + d
  return tf.tidy(() => {
    return a.mul(x.pow(tf.scalar(3))) // a * x^3
      .add(b.mul(x.square())) // + b * x ^ 2
      .add(c.mul(x)) // + c * x
      .add(d); // + d
  });
}

  其中,在上一篇文章中,我們講到tf.tify函數用來清除中間張量,其他的都很好理解。

  接着,讓我們把這個多項式函數的系數使用之前得到的隨機數,可以看到,得到的圖應該是這樣:

  

  因為開始時,我們使用的系數是隨機數,所以這個函數和給定的數據匹配的非常差,而我們寫的模型就是為了通過學習得到更精確的系數值。

 

第三步:訓練模型

  最后一步就是要訓練這個模型使得系數和這些散點更加匹配,而為了訓練模型,我們需要定義下面的三樣東西:

  • 損失函數(loss function):這個損失函數代表了給定多項式和數據的匹配程度。 損失函數值越小,那么這個多項式和數據就跟匹配。
  • 優化器(optimizer):這個優化器實現了一個算法,它會基於損失函數的輸出來修正系數值。所以優化器的目的就是盡可能的減小損失函數的值。
  • 訓練迭代器(traing loop):即它會不斷地運行這個優化器來減少損失函數。

  所以,上面這三樣東西的 關系就非常清楚了: 訓練迭代器使得優化器不斷運行,使得損失函數的值不斷減小,以達到多項式和數據盡可能匹配的目的。這樣,最終我們就可以得到a、b、c、d較為精確的值了。

 

  

四、定義損失函數

  這篇文章中,我們使用MSE(均方誤差,mean squared error)作為我們的損失函數。MSE的計算非常簡單,就是先根據給定的x得到實際的y值與預測得到的y值之差 的平方,然后在對這些差的平方求平均數即可

  

  於是,我們可以這樣定義MSE損失函數:

function loss(predictions, labels) {
  // 將labels(實際的值)進行抽象
  // 然后獲取平均數.
  const meanSquareError = predictions.sub(labels).square().mean();
  return meanSquareError;
}

   即這個損失函數返回的就是一個均方差,如果這個損失函數的值越小,顯然數據和系數就擬合的越好。

  

 

五、定義優化器

  對於我們的優化器而言,我們選用 SGD (Stochastic Gradient Descent)優化器,即隨機梯度下降SGD的工作原理就是利用數據中任意的點的梯度以及使用它們的值來決定增加或者減少我們模型中系數的值

  TensorFlow.js提供了一個很方便的函數用來實現SGD,所以你不需要擔心自己不會這些特別復雜的數學運算。 即 tf.train.sdg 將一個學習率(learning rate)作為輸入,然后返回一個SGDOptimizer對象,它與優化損失函數的值是有關的。

  在提高它的預測能力時,學習率(learning rate)會控制模型調整幅度將會有多大。低的學習率會使得學習過程運行的更慢一些(更多的訓練迭代獲得更符合數據的系數),而高的學習率將會加速學習過程但是將會導致最終的模型可能在正確值周圍搖擺。簡單的說,你既想要學的快,又想要學的好,這是不可能的。

  下面的代碼就創建了一個學習率為0.5的SGD優化器。

const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);

  

六、定義訓練迭代器

  既然我們已經定義了損失函數和優化器,那么現在我們就可以創建一個訓練迭代器了,它會不斷地運行SGD優化器來使不斷修正、完善模型的系數來減小損失(MSE)。下面就是我們創建的訓練迭代器:

function train(xs, ys, numIterations = 75) {

  const learningRate = 0.5;
  const optimizer = tf.train.sgd(learningRate);

  for (let iter = 0; iter < numIterations; iter++) {
    optimizer.minimize(() => {
      const predsYs = predict(xs);
      return loss(predsYs, ys);
    });
  }
}

  現在,讓我們一步一步地仔細看看上面的代碼。首先,我們定義了訓練函數,並且以數據中x和y的值以及制定的迭代次數作為輸入:

function train(xs, ys, numIterations) {
...
}

  接下來,我們定義了之前討論過的學習率(learning rate)以及SGD優化器:

const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);

  

  最后,我們定義了一個for循環,這個循環會運行numIterations次訓練。在每一次迭代中,我們都調用了optimizer優化器的minimize函數,這就是見證奇跡的地方:

for (let iter = 0; iter < numIterations; iter++) {
  optimizer.minimize(() => {
    const predsYs = predict(xs);
    return loss(predsYs, ys);
  });
}

  minimize 接受了一個函數作為參數,這個函數做了下面的兩件事情:

  1. 首先它對所有的x值通過我們在之前定義的pridict函數預測了y值。
  2. 然后它通過我們之前定義的損失函數返回了這些預測的均方誤差。

    

  minimize函數之后會自動調整這些變量(即系數a、b、c、d)來使得損失函數更小。

  在運行訓練迭代器之后,a、b、c以及d就會是通過模型75次SGD迭代之后學習到的結果了。 

  

 

七、觀察結果吧!

  一旦程序運行結束,我們就可以得到最終的a、b、c和d的結果了,然后使用它們來繪制曲線,如下所示:

      

  這個結果已經比開始隨機分配系數的結果擬合的好得多了!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM