攢了幾天,發一個大的
這是前幾天投了一家量化分析職位,他給的題目的是寫神經網絡擇時模型,大概就是用神經網絡預測收盤價
database類:該類用於獲得新浪網中的數據,並將其放入本地數據庫。在本地數據庫中建立兩個表,分別是Data2012to2015和Data2015to2016,表中都含有日期,當日開盤價、當日收盤價、當日最高價、當日最低價。Data2012to2015為訓練數據集,Data2015to2016為測試數據集。
package it.cast;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
public class dataBase {
//創建訓練集:Data2012to2015和測試集Data2015to2016
public void createDataBase() {
try {
Connection conn = null;
Statement stmt = null;
//鏈接數據庫
Class.forName("oracle.jdbc.driver.OracleDriver");
String url = "jdbc:oracle:thin:@localhost:1521:ORCL";
String UserName = "system";
String password = "manager";
conn = DriverManager.getConnection(url, UserName, password);
stmt = conn.createStatement();
historyShare( conn, stmt);
}catch (Exception e) {
e.printStackTrace();
}
}
private void historyShare( Connection conn, Statement stmt)
throws SQLException, MalformedURLException, IOException,
UnsupportedEncodingException {
//創建表格
//表格列為:股票id號、日期、開盤價、最高價、收盤價、最低價、成交量
String sql = "create table Data2015to2016(stokeid integer not null primary key ," +
"data varchar2(20), openPrice varchar2(20), highPrice varchar2(20), overPrice varchar2(20),lowPrice varchar2(20)," +
"vol varchar2(20))";
stmt.executeUpdate(sql);
URL ur = null;
ur = new URL("http://biz.finance.sina.com.cn/stock/flash_hq/kline_data.php?&rand=random(10000)&symbol=sz000001&end_date=20161118&begin_date=20151118&type=plain");
HttpURLConnection uc = (HttpURLConnection) ur.openConnection();
BufferedReader reader = new BufferedReader(new InputStreamReader(ur.openStream(),"GBK"));
String line;
PreparedStatement stmt1 = null;
int i=1;
//插入數據
while((line = reader.readLine()) != null){
//普通股票
String sql1 = "insert into Data2015to2016 values(?,?, ?, ?, ?, ?, ?)";
stmt1 = conn.prepareStatement(sql1);
String[] data=line.split(",");
String date = data[0];
String openPrice = data[1];
String highPrice = data[2];
String overPrice = data[3];
String lowPrice = data[4];
stmt1.setInt(1, i++);
stmt1.setString(2, data[0]);
stmt1.setString(3, data[1]);
stmt1.setString(4, data[2]);
stmt1.setString(5, data[3]);
stmt1.setString(6, data[4]);
stmt1.setString(7, data[5]);
stmt1.executeUpdate();
stmt1.close();
}
}
}
Methods類:由於java沒有現成的包可以直接得出某只股票的波動率指標、短期和長期均線指標等指標,由於一些指標在網上沒有找到,例如動量和反轉指標:REVS5,就用了動量指標MTM。所以在百度百科等資料中搜集了一些公式, 分別對這些公式編寫代碼,就能觀測到的數據來說,是准確的。
最后采用了8個指標,分別是波動率指標:EMV;短期和長期均線指標:EMA5和EMA60,MA5和MA60;動量指標MTM;量能指標:MACD;能量指標:CR5.以這8個指標為自變量,收盤價為因變量建立神經網絡模型。
package it.cast;
import java.util.ArrayList;
import java.util.List;
public class Methods {
//搜狗百科:A=(今日最高+今日最低)/2;B=(前日最高+前日最低)/2;C=今日最高-今日最低;2.EM=(A-B)*C/今日成交額;3.EMV=N日內EM的累和;4.MAEMV=EMV的M日簡單移動平均.參數N為14,參數M為9
public List<Double> EMV(List<Double>highPrice,List<Double>lowPrice,List<Double>vol){
List<Double>EM = new ArrayList<Double>();
for(int i = 2;i<highPrice.size();i++){
double A = (highPrice.get(i)+lowPrice.get(i))/2;
double B = (highPrice.get(i-2)+lowPrice.get(i-2))/2;
double C = highPrice.get(i)-lowPrice.get(i);
EM.add(((A-B)*C)/vol.get(i));
}
List<Double>EMV = new ArrayList<Double>();
//取N為14,即14日的EM值之和;M為9,即9日的移動平均
int N = 14;
int M = 9;
for(int i = N;i<EM.size()+1;i++){
//14日累和
double sum = 0;
for(int j = i-N;j<i;j++){
sum += EM.get(j);
}
EMV.add(sum);
}
List<Double>MAEMV = new ArrayList<Double>();
for(int i = M;i<EMV.size()+1;i++){
//9日移動平均
double sum = 0;
for(int j = i-M;j<i;j++){
sum += EMV.get(j);
}
sum = sum/M;
MAEMV.add(sum);
}
return MAEMV;
}
//EMA=(當日或當期收盤價-上一日或上期EXPMA)/N+上一日或上期EXPMA,其中,首次上期EXPMA值為上一期收盤價,N為天數。
public List<Double> EMA5(List<Double>overPrice){
//取20121118年收盤價為初始EXPMA
List<Double>EMA5 = new ArrayList<Double>();
for(int i = 0;i<5;i++){
EMA5.add(overPrice.get(i));
}
for(int i = 5;i<overPrice.size();i++){
EMA5.add((overPrice.get(i)-EMA5.get(i-5))/5+EMA5.get(i-5));
}
return EMA5;
}
public List<Double> EMA60(List<Double>overPrice){
//取20121118年收盤價為初始EXPMA
List<Double>EMA60 = new ArrayList<Double>();
for(int i = 0;i<60;i++){
EMA60.add(overPrice.get(i));
}
for(int i = 60;i<overPrice.size();i++){
EMA60.add((overPrice.get(i)-EMA60.get(i-60))/60+EMA60.get(i-60));
}
return EMA60;
}
//5日均線
public List<Double> MA5(List<Double>overPrice){
List<Double>MA5 = new ArrayList<Double>();
for(int i = 5;i<overPrice.size()+1;i++){
double sum = 0;
for(int j = i-1;j>=i-5;j--){
sum += overPrice.get(j);
}
sum = sum/5;
MA5.add(sum);
}
return MA5;
}
//60日均線
public List<Double> MA60(List<Double>overPrice){
List<Double>MA60 = new ArrayList<Double>();
for(int i = 60;i<overPrice.size()+1;i++){
double sum = 0;
for(int j = i-1;j>=i-60;j--){
sum += overPrice.get(j);
}
sum = sum/60;
MA60.add(sum);
}
return MA60;
}
//動量指標MTM,1.MTM=當日收盤價-N日前收盤價;2.MTMMA=MTM的M日移動平均;3.參數N一般設置為12日參數M一般設置為6,表中當動量值減低或反轉增加時,應為買進或賣出時機
public List<Double> MTM(List<Double>overPrice){
List<Double>MTM = new ArrayList<Double>();
List<Double>MTMlist = new ArrayList<Double>();
int N = 12;
int M = 6;
for(int i = 12;i<overPrice.size();i++){
MTM.add(overPrice.get(i)-overPrice.get(i-12));
}
//移動平均參數為6
for(int i = 6;i<MTM.size()+1;i++){
double sum = 0;
for(int j = i-1;j>=i-6;j--){
sum += MTM.get(j);
}
sum = sum/6;
MTMlist.add(sum);
}
return MTMlist;
}
//百度百科:http://baike.baidu.com/link?url=XQf2I-JIyNR1AEM_EnMnuU90U1vmJDoXukUe1fQVsBA1Y_fqAA8dj7DoxLCoh5U-YysBkVT5aIZLXeG2g1snoK:量能指標就是通過動態分析成交量的變化,
public List<Double> MACD(List<Double>vol){
int shortN = 12;
List<Double>Short = new ArrayList<Double>();
for(int i = shortN;i<vol.size()+1;i++){
Short.add(2*vol.get(i-1)+(shortN-1)*vol.get(i-shortN));
}
int longN = 26;
List<Double>Long = new ArrayList<Double>();
for(int i = longN;i<vol.size()+1;i++){
Long.add(2*vol.get(i-1)+(longN-1)*vol.get(i-longN));
}
// 取兩個序列中較短序列的長度
int length = 0;
if(Short.size()>Long.size()){
length = Long.size();
}else{
length = Short.size();
}
List<Double>DIFF1 = new ArrayList<Double>();
for(int i = length-1;i>=0;i--){
DIFF1.add(Short.get(i)-Long.get(i));
}
List<Double>DIFF = new ArrayList<Double>();
for(int i = 0;i<DIFF1.size();i++){
DIFF.add(DIFF1.get(DIFF1.size()-i-1));
}
List<Double>DEA = new ArrayList<Double>();
for(int i = 0;i<DIFF.size()-1;i++){
DEA.add(2*DIFF.get(i+1)+(9-1)*DIFF.get(i));
}
List<Double>MACD = new ArrayList<Double>();
for(int i = 1;i<DIFF.size();i++){
MACD.add(DIFF.get(i)-DEA.get(i-1));
}
return MACD;
}
//能量指標:CR,見百度百科:http://baike.baidu.com/link?url=v5yYFep6wZioav0P-LOruuhkzjho6PqzQqfEBj5TYQLfaadLSADSQVl0njP7k1zY78KJMoBFrE4OO4wYolZXbMnRRQi7U66R0X2jeSV3ZoXKeuG2zEbqEqP4CnyiF7j6
public List<Double> CR5(List<Double>overPrice,List<Double>highPrice,List<Double>lowPrice,List<Double>openPrice){
List<Double> YM = new ArrayList<Double>();
List<Double> HYM = new ArrayList<Double>();
List<Double> YML = new ArrayList<Double>();
List<Double> CR = new ArrayList<Double>();
for(int i = 0;i<overPrice.size();i++){
YM.add((highPrice.get(i)+overPrice.get(i)+lowPrice.get(i)+openPrice.get(i))/4);
}
//p1表示5日以來多方力量總和,p2表示5日以來空方力量總和
for(int i = 6;i<highPrice.size()+1;i++){
double sum = 0;
for(int j = i-1;j>=i-5;j--){
sum += highPrice.get(j)-YM.get(j-1);
}
HYM.add(sum);
}
//p2表示5日以來空方力量總和,p2表示5日以來空方力量總和
for(int i = 6;i<lowPrice.size()+1;i++){
double sum = 0;
for(int j = i-1;j>=i-5;j--){
sum += YM.get(j-1)-lowPrice.get(j);
}
YML.add(sum);
}
for(int i = 0;i<YML.size();i++){
double temp = (double)HYM.get(i)/YML.get(i);
if(temp<0){
CR.add((double) 0);
}else{
CR.add(temp);
}
}
return CR;
}
public double[][] bpTrain(List<Double>overPrice,List<Double>highPrice,List<Double>lowPrice,List<Double>openPrice,List<Double>vol){
List<Double>EMV = EMV(highPrice, lowPrice, vol);
List<Double>EMA5 = EMA5(overPrice);
List<Double>EMA60 = EMA60(overPrice);
List<Double>MA5 = MA5(overPrice);
List<Double>MA60 = MA60(overPrice);
List<Double>MTM = MTM(overPrice);
List<Double>MACD = MACD(vol);
List<Double>CR5 = CR5(overPrice, highPrice, lowPrice, openPrice);
int length = 0;
if(EMA60.size()>MA60.size()){
length = MA60.size();
}else{
length = EMA60.size();
}
List<ArrayList<Double>>datalist = new ArrayList<ArrayList<Double>>();
for(int i = 0;i<length;i++){
ArrayList<Double>list = new ArrayList<Double>();
//list.add(EMV.get(EMV.size()-length+i));
list.add(EMA5.get(EMA5.size()-length+i));
list.add(EMA60.get(EMA60.size()-length+i));
list.add(MA5.get(MA5.size()-length+i));
list.add(MA60.get(MA60.size()-length+i));
list.add(MTM.get(MTM.size()-length+i));
// list.add(MACD.get(MACD.size()-length+i));
list.add(CR5.get(CR5.size()-length+i));
datalist.add(list);
}
double [][]data = new double[datalist.size()][6];
for(int i = 0;i<datalist.size();i++){
for(int j = 0;j<6;j++){
data[i][j] = datalist.get(i).get(j);
System.out.print(data[i][j]+" ");
}
System.out.println();
}
return data;
}
}
BPnet類:這里想建立輸入單元為8個,兩層隱含層,每個隱含層為13個單元,輸出層單元為1的神經網絡。
首先初始化輸入層到隱含層,隱含層之間,以及隱含層到輸出層的權重矩陣;
其次利用權重矩陣和輸入層分別計算出每個隱含層節點數據
之后利用計算得出的輸出層數據與真實值進行比較,並逐層調節權重;
反復上述過程直至精度達到要求或是達到迭代次數的要求;
這里設置迭代次數為5000次;
利用的測試數據集為Data2012to2015
下圖為訓練之后的模型對Data2012to2015自身進行擬合的效果:(這里由於自變量大概是10左右的數據,所以在利用激活函數1/(1+e^-ax))時,a取了0.01
package it.cast;
import java.util.Random;
public class BPnet {
public double[][] layer;//神經網絡各層節點
public double[][] layerErr;//神經網絡各節點誤差
public double[][][] layer_weight;//各層節點權重
public double[][][] layer_weight_delta;//各層節點權重動量
public double mobp;//動量系數
public double rate;//學習系數
public BPnet(int[] layernum, double rate, double mobp){
this.mobp = mobp;
this.rate = rate;
layer = new double[layernum.length][];
layerErr = new double[layernum.length][];
layer_weight = new double[layernum.length][][];
layer_weight_delta = new double[layernum.length][][];
Random random = new Random();
for(int l=0;l<layernum.length;l++){
layer[l]=new double[layernum[l]];
layerErr[l]=new double[layernum[l]];
if(l+1<layernum.length){
layer_weight[l]=new double[layernum[l]+1][layernum[l+1]];
layer_weight_delta[l]=new double[layernum[l]+1][layernum[l+1]];
for(int j=0;j<layernum[l]+1;j++)
for(int i=0;i<layernum[l+1];i++)
layer_weight[l][j][i]=random.nextDouble();//隨機初始化權重
}
}
}
//逐層向前計算輸出
public double[] computeOut(double[] in){
for(int l=1;l<layer.length;l++){
for(int j=0;j<layer[l].length;j++){
double z=layer_weight[l-1][layer[l-1].length][j];
for(int i=0;i<layer[l-1].length;i++){
layer[l-1][i]=l==1?in[i]:layer[l-1][i];
z+=layer_weight[l-1][i][j]*layer[l-1][i];
}
// System.out.println(z+"####");
layer[l][j]=1/(1+Math.exp(-0.01*z));
// System.out.println("&&**"+layer[l][j]);
}
}
//System.out.println("&&^^**"+layer[layer.length-1][0]);
return layer[layer.length-1];
}
//逐層反向計算誤差並修改權重
public void updateWeight(double[] tar){
int l=layer.length-1;
for(int j=0;j<layerErr[l].length;j++)
layerErr[l][j]=layer[l][j]*(1-layer[l][j])*(1/(1+Math.exp(-0.01*tar[j]))-layer[l][j]);
while(l-->0){
for(int j=0;j<layerErr[l].length;j++){
double z = 0.0;
for(int i=0;i<layerErr[l+1].length;i++){
z=z+l>0?layerErr[l+1][i]*layer_weight[l][j][i]:0;
layer_weight_delta[l][j][i]= mobp*layer_weight_delta[l][j][i]+rate*layerErr[l+1][i]*layer[l][j];//隱含層動量調整
layer_weight[l][j][i]+=layer_weight_delta[l][j][i];//隱含層權重調整
if(j==layerErr[l].length-1){
layer_weight_delta[l][j+1][i]= mobp*layer_weight_delta[l][j+1][i]+rate*layerErr[l+1][i];//截距動量調整
layer_weight[l][j+1][i]+=layer_weight_delta[l][j+1][i];//截距權重調整
}
}
layerErr[l][j]=z*layer[l][j]*(1-layer[l][j]);//記錄誤差
}
}
}
public void train(double[] in, double[] tar){
double[] out = computeOut(in);
updateWeight(tar);
}
}
從圖中可以看出2012年初,股市變化幅度很大時,模型擬合效果稍差,但總體擬合效果較好。(紅線表示擬合曲線,藍線表示真實收盤價)
測試數據集采用的是Data2015to2016,即2015年至2016年數據,擬合擬合效果如下:
從圖中可以看出曲線可以擬合大致趨勢,但是不能很好的擬合波動,可能是由於對訓練數據集過渡擬合的原因。
BackProce類:該類計算了如果按照神經網絡模型對該股票進行操作的結果,采用的策略是,如果下一天的預測值高於當天的收盤價,就買入,低於就賣出,設置初始賬戶金額為10000.
可得到最后的收益率為0.18364521221914928,賬戶金額為:11836.452122191493。
累計收益率如下圖:
累計收益率明顯呈現上升趨勢。
package it.cast;
import java.util.ArrayList;
import java.util.List;
public class BackProce {
public List<ArrayList<Double>> selectChance(List<ArrayList<Double>>result,double account){
double accountF = account;
System.out.println("初始賬戶為: "+account);
ArrayList<Double>expect = new ArrayList<Double>();
ArrayList<Double>target = new ArrayList<Double>();
for(int i = 0;i<result.size();i++){
expect.add(result.get(i).get(0));
target.add(result.get(i).get(1));
}
List<ArrayList<Double>>chance = new ArrayList<ArrayList<Double>>();
for(int i = 1;i<expect.size();i++){
if(expect.get(i)>target.get(i-1)){
//買入
account += account*(target.get(i)-target.get(i-1))/target.get(i-1);
}
ArrayList<Double>list = new ArrayList<Double>();
list.add((account-accountF)/accountF);
list.add((double) i);
chance.add(list);
}
System.out.println("期末賬戶為: "+account);
System.out.println("年化收益率為: "+(account-accountF)/accountF);
return chance;
}
}
輔助類Graph:該類借助了jfree包,用於繪制圖像
package it.cast;
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Font;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import javax.swing.JPanel;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.CategoryPlot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.renderer.category.LineAndShapeRenderer;
import org.jfree.chart.title.TextTitle;
import org.jfree.data.category.CategoryDataset;
import org.jfree.data.category.DefaultCategoryDataset;
import org.jfree.ui.ApplicationFrame;
import org.jfree.ui.HorizontalAlignment;
import org.jfree.ui.RectangleEdge;
public class Graph extends ApplicationFrame{
ChartPanel frame1;
private static final long serialVersionUID = 1L;
public Graph(String s , List<ArrayList<Double>> excel) {
super(s);
setContentPane(createDemoLine(excel));
}
public static DefaultCategoryDataset createDataset(List<ArrayList<Double>> excel) {
DefaultCategoryDataset linedataset = new DefaultCategoryDataset();
for (int i=0; i <excel.size(); i++) {
linedataset.addValue(excel.get(i).get(0), "expect", excel.get(i).get(1));
//linedataset.addValue(excel.get(i).get(1), "target", Integer.toString(i+1));
}
return linedataset;
}
public static JPanel createDemoLine(List<ArrayList<Double>> excel) {
JFreeChart jfreechart = createChart(createDataset(excel));
return new ChartPanel(jfreechart);
}
// 生成圖表主對象JFreeChart
public static JFreeChart createChart(DefaultCategoryDataset linedataset) {
// 定義圖表對象
JFreeChart chart = ChartFactory.createLineChart("Cumulative rate of return", //折線圖名稱
"time", // 橫坐標名稱
"Value", // 縱坐標名稱
linedataset, // 數據
PlotOrientation.VERTICAL, // 水平顯示圖像
true, // include legend
false, // tooltips
false // urls
);
// chart.setBackgroundPaint(Color.red);
CategoryPlot plot = chart.getCategoryPlot();
// plot.setDomainGridlinePaint(Color.red);
plot.setDomainGridlinesVisible(true);
// 5,設置水平網格線顏色
// plot.setRangeGridlinePaint(Color.blue);
// 6,設置是否顯示水平網格線
plot.setRangeGridlinesVisible(true);
plot.setRangeGridlinesVisible(true); //是否顯示格子線
//plot.setBackgroundAlpha(f); //設置背景透明度
NumberAxis rangeAxis = (NumberAxis)plot.getRangeAxis();
rangeAxis.setStandardTickUnits(NumberAxis.createIntegerTickUnits());
rangeAxis.setAutoRangeIncludesZero(true);
rangeAxis.setUpperMargin(0.20);
rangeAxis.setLabelAngle(Math.PI / 2.0);
rangeAxis.setAutoRange(false);
FileOutputStream fos_jpg=null;
try{
fos_jpg=new FileOutputStream("D:\\ok_bing.jpg");
/*
* 第二個參數如果為100,會報異常:
* java.lang.IllegalArgumentException: The 'quality' must be in the range 0.0f to 1.0f
* 限制quality必須小於等於1,把100改成 0.1f。
*/
// ChartUtilities.writeChartAsJPEG(fos_jpg, 0.99f, chart, 600, 300, null);
ChartUtilities.writeChartAsJPEG(fos_jpg, chart, 900, 400);
}catch(Exception e){
System.out.println("[e]"+e);
}finally{
try{
fos_jpg.close();
}catch(Exception e){
}
}
return chart;
}
}
主函數類testClass:
package it.cast;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class testClass {
public static void main(String[] args) {
dataBase data = new dataBase();
// data.createDataBase();
try{
Connection conn = null;
Statement stmt = null;
//鏈接數據庫
Class.forName("oracle.jdbc.driver.OracleDriver");
String url = "jdbc:oracle:thin:@localhost:1521:ORCL";
String UserName = "system";
String password = "manager";
conn = DriverManager.getConnection(url, UserName, password);
stmt = conn.createStatement();
String sql2="select * from Data2012to2015";
ResultSet rs = stmt.executeQuery(sql2);
//創建序列
List<Double> openPrice = new ArrayList<Double>();
List<Double> highPrice = new ArrayList<Double>();
List<Double> overPrice = new ArrayList<Double>();
List<Double> lowPrice = new ArrayList<Double>();
List<Double> vol = new ArrayList<Double>();
while (rs.next()){
openPrice.add(Double.parseDouble(rs.getString("OPENPRICE")));
highPrice.add(Double.parseDouble(rs.getString("HIGHPRICE")));
overPrice.add(Double.parseDouble(rs.getString("OVERPRICE")));
lowPrice.add(Double.parseDouble(rs.getString("LOWPRICE")));
vol.add(Double.parseDouble(rs.getString("VOL")));
}
Methods m = new Methods();
double [][]dataset = m.bpTrain(overPrice, highPrice, lowPrice, openPrice, vol);
double [][]target = new double[dataset.length][];
for(int i = 0;i<dataset.length;i++){
target[i] = new double[1];
target[i][0] = overPrice.get(overPrice.size()-dataset.length+i);
}
String sql3="select * from Data2015to2016";
ResultSet rs2 = stmt.executeQuery(sql3);
//創建序列
List<Double> openPrice2 = new ArrayList<Double>();
List<Double> highPrice2 = new ArrayList<Double>();
List<Double> overPrice2 = new ArrayList<Double>();
List<Double> lowPrice2 = new ArrayList<Double>();
List<Double> vol2 = new ArrayList<Double>();
while (rs2.next()){
openPrice2.add(Double.parseDouble(rs.getString("OPENPRICE")));
highPrice2.add(Double.parseDouble(rs.getString("HIGHPRICE")));
overPrice2.add(Double.parseDouble(rs.getString("OVERPRICE")));
lowPrice2.add(Double.parseDouble(rs.getString("LOWPRICE")));
vol2.add(Double.parseDouble(rs.getString("VOL")));
}
Methods m2 = new Methods();
double [][]dataset2 = m2.bpTrain(overPrice2, highPrice2, lowPrice2, openPrice2, vol2);
double [][]target2 = new double[dataset2.length][];
for(int i = 0;i<dataset2.length;i++){
target2[i] = new double[1];
target2[i][0] = overPrice2.get(overPrice2.size()-dataset2.length+i);
}
BPnet bp = new BPnet(new int[]{6,13,13,1}, 0.15, 0.8);
//迭代訓練5000次
for(int n=0;n<50000;n++)
for(int i=0;i<dataset.length;i++)
bp.train(dataset[i], target[i]);
//測試數據集
double []result = new double[dataset2.length];
List<ArrayList<Double>>resultList = new ArrayList<ArrayList<Double>>();
for(int j=0;j<dataset2.length;j++){
double []a = bp.computeOut(dataset2[j]);
ArrayList<Double>list = new ArrayList<Double>();
result[j] = 100*(-Math.log(1/a[0]-1));
list.add(result[j]);
list.add(target2[j][0]);
resultList.add(list);
System.out.println(Arrays.toString(dataset2[j])+":"+result[j]+" real:"+target2[j][0]);
}
//new Graph("1",resultList);
BackProce b = new BackProce();
double account = 10000;
List<ArrayList<Double>>chance = b.selectChance(resultList,account);
new Graph("1",chance);
}catch (Exception e) {
e.printStackTrace();
// TODO: handle exception
}
System.out.println("End");
}
}
缺點:1、只能繪制基本圖像,沒有找到方法將特殊點標出,如:能夠獲取在什么時間點買入,但是不知怎么在特定點用其他顏色標出。
2、神經網絡模型對訓練數據擬合很好,但是對測試數據擬合效果不佳,猜測原因可能是過擬合或是有些其他主要的變量因素沒有考慮進去。