背景知識
最近再看一些量化交易相關的材料,偶然在網上看到了一個關於用RNN實現股票預測的文章,出於好奇心把文章中介紹的代碼在本地跑了一遍,發現可以work。於是就花了兩個晚上的時間學習了下代碼,順便把核心的內容翻譯成中文分享給大家。
首先講講對於股票預測的理解,股票是一種可以輕易用數字表現律動的交易形式。因為大數定理的存在,定義了世間所有的行為都可以通過數字表示,並且存在一定的客觀規律。股票也不例外,量化交易要做的就是通過數學模型發現股票的走勢趨勢。“趨勢”要這樣理解:對於股票的預測,不是說我知道這個股票昨天指數是多少,然后預測今天他的指數能漲到多少。而是,我們通過過去一段時間股票的跌或者漲,總結出當出現某種波動的時候股票會有相應的漲或者跌的趨勢。於是就引出了RNN的概念。
RNN是一種深度學習的網絡結構,RNN的優勢是它在訓練的過程中會考慮數據的上下文聯系,非常適合股票的場景,因為某一時刻的波動往往跟之前的走勢蘊含某種聯系。RNN是由一個個神經元cell組成,然而傳統的RNN當網絡過於復雜的時候,后方節點對於前方的感知力會下降,LSTM(Long-short Term Memory)是一種變型,從名字就可以看出來,LSTM可以增加記憶力,解決上面提到的問題。對於股票這個場景,我們就可以通過LSTM來實現股票的走勢的預測。
在股票這個場景下,通過上面這個圖可以看出來,輸入的是時間t、t+1、t+2的股票信息,可以返回t+1、t+2、t+3的股票信息,而且上下節點前后依賴,通過LSTM模型對於這樣的股票序列進行預測,所以股票預測的關鍵就是首先構建股票序列化數據,然后訓練LSTM模型,最終通過這個模型對於股票進行預測,以上就是大體的一些思路。
數據說明
本次實驗使用的是一只叫SP500的股票,可以從雅虎下載這只股從50年到現在每天的走勢情況,這里只需要關心每次收盤價格,也就是close字段即可。數據截圖:
代碼
代碼文件有以下四部分:
其中SP500的股票數據需要放在data文件夾下。依賴的庫包括,
numpy==1.13.1pandas==0.16.2 scikit-learn==0.16.1 scipy==0.19.1 tensorflow==1.2.1
在項目目錄下執行以下shell即可開始訓練:
python main.py --stock_symbol=SP500 --train --input_size=1 --lstm_size=128 --max_epoch=50
分別介紹下每個代碼:
data_model.py
這個文件是構建訓練數據,通過pandas庫去讀數據SP500.csv文件,然后只取close這個字段,將每天的close數據作為代表當天股票的市值,如下圖所示。
這里做了一次歸一化,因為股票在50年的市值是每股19塊左右,到了2017年漲到了2600多塊,分布很不均勻,於是通過把每天股票close值除以歷史股票最高值,將所有數據的定義域限定在0到1之間。接着構建預測集,涉及到兩個參數input_size和num_steps,當input_size=3 and num_steps=2時會構建以下數據集。
第0、1、2天的股票和第3、4、5天的股票為訓練集,第6、7、8天的股票是目標列,就構成了監督學習數據。以此類推,將所有數據構成訓練數據集。
model_rnn.py
構建模型的文件,通過build_graph函數去構建整個的LSTM網絡,同時定義最優化求法的optimizer。通過train函數定義數據如何在graph中訓練,包括model參數的存儲。plot_samples會在訓練過程中將測試集數據和訓練數據的比較打印成圖片輸出。
main.py
入口代碼,定義運行參數,包括epoch的輪數、learning_rate等等。
結果評估
其實,在測試的時候,整個工程就將生成的預測數據和真實數據進行比較並且在images文件夾下生成圖片。我們通過圖片直觀的可以看下隨着訓練的進行,是否真正可以模擬出股票曲線,首先是epoch=5的時候,也就是訓練第5輪的時候,我們看到綠色的predict曲線和藍色的truth曲線擬合的並不好。
再來看下又過了40多輪訓練生成的圖片:
我們看到股票的曲線擬合程度已經進步非常多,相信隨着數據和訓練輪次的增加,預測值會越來越精確。
PS:總結完了,建議大家想學習的自己跟一遍代碼,我自己看了2個晚上,加起來4個小時左右。我整理的代碼和數據下載鏈接在下面已經給出。另外誰認識北京的做量化交易相關的同學,請幫忙引薦,最近在工作之余自學量化交易相關的內容,希望可以有業內同學當面交流一下,多謝。
參考
項目地址:https://github.com/lilianweng/stock-rnn
作者寫的介紹博文,很詳細,學到很多:
https://lilianweng.github.io/lil-log/2017/07/08/predict-stock-prices-using-RNN-part-1.html
另外我基於lilianweng的工作,精簡了一部分代碼,並且修改了部分版本不兼容的第三方庫函數,並且在工程中提供了從雅虎股票下載好的數據,可以直接運行,項目地址:https://github.com/jimenbian/stock-rnn