用matplotlib制作的比較滿意的蠟燭圖


用matplotlib制作的比較滿意的蠟燭圖

2D圖形制作包, 功能強大, 習練了很久, 終於搞定了一個比較滿意的腳本.

特點:

  1. 使用方面要非常簡單
  2. 繪制出來的圖要非常的滿意, 具有如下的特點
    1. 時間和空間的比例尺需要固定, 就是說圖件的大小需要依據數據的長度和價格的變動幅度自動調整, 至少時間軸上應該如此.
    2. 時間軸的刻度: 對於日線圖而言, 年/月/日/星期幾 都應該一目了然.
    3. Y軸: 對數刻度, 10%等比刻度線, 刻度值的標簽應該能反應絕對的股價, 支持雙Y軸(右側的Y軸度量大盤的變化)
    4. 蠟燭非白即黑, 只要兩種顏色(包括邊界線)
    5. 分辨率要足夠高, 至少300DPI, 方便原樣(無伸縮)打印
    6. 應該支持非常方便地抽取子集, 然后制圖

版本持續升級:

2017.12 的備忘錄

在以前的函數式代碼的基礎上, OOP方式重構代碼, 方便以后擴展功能, 也讓程序運行得更健碩

結果展示

主塊代碼

繪圖模塊的代碼

結果展示:

png file from my github:

https://github.com/duanqingshan/learngit/blob/master/均勝電子_20171230_182515__468000.png

gif file from my cnblogs:

https://files.cnblogs.com/files/duan-qs/均勝電子_20171226_220616__255000.gif

主代碼塊:


# -*- coding: utf-8 -*-

u''' 研究K線形態: 從單個K線做起, 然后K線組合, 然后K線形態
# 1. 定義兩個實例 
# 2. 加載數據
# 3. 前復權處理
# 4. 計算指標
# 5. 形態研究之: 提取與顯示
# 6. 繪圖  主圖+成交量圖
'''

import amipy as ami
import plotter as pl
import pattern as pa
reload(pa)
reload(ami)

context = ami.Context('600699.SH')  # 000911
#context = ami.Context('002242.SZ')  # 000911
stk = ami.Stock(context)

stk.grab_data_tdxlday(context, num_days=None)
stk.load_tdx_qx()

stk.qfq()

stk.ma20 = ami.TTR.sma(stk.ohlc.close, 20)
stk.cyc61 = ami.TTR.sma(stk.ohlc.close, 120)

pattern = pa.Pattern(stk)
pattern.study_csyx(roc1=0.3/100)

#subset = slice(-250*3, None) # '2017-07'  '2017'
subset = slice(-120,None) # '2017-07'  '2017'
plotter = pl.Plotter(context,stk,subset,quanxi=None)
#    plotter.plot_candle_vol()
#plotter.plot_candle_vol(savefig=True)

#plotter.plot_timing(timing=pattern.csyx)    
#plotter.plot_timing(timing=pattern.szx)    
plotter.plot_timing(timing=pattern.upgap, savefig=True)    
#plotter.plot_timing(timing=pattern.dngap)  

繪圖代碼:


# -*- coding: utf-8 -*-

#import sys

import numpy as np
import pandas as pd

import datetime

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import (
        FixedLocator, 
        #MultipleLocator, 
        #LogLocator, 
        
        #NullFormatter, 
        FuncFormatter, 
        #LogFormatter 
        )
from matplotlib.font_manager import FontProperties 
from matplotlib.text import Text
myfont = FontProperties(fname=r"c:\windows\fonts\msyh.ttf")  #size可不用指定
matplotlib.rcParams['axes.unicode_minus'] = False

#import amipy as ami
import ttr as TTR

#==============================================================================
# Python中的作用域及global用法 - Summer_cool - 博客園  
# https://www.cnblogs.com/summer-cool/p/3884595.html
# 
# 函數定義了本地作用域,而模塊定義的是全局作用域。
# 如果想要在函數內定義全局作用域,需要加上global修飾符。
# 
# 變量名解析:LEGB原則
# 當在函數中使用未認證的變量名時,Python搜索4個作用域:
#     [本地作用域(L-local)(函數內部聲明但沒有使用global的變量),
#      之后是上一層結構def或者lambda的本地作用域(E-enclosure),
#      之后是全局作用域(G-global)(函數中使用global聲明的變量或在模塊層聲明的變量),
#      最后是內置作用域(B)(即python的內置類和函數等)]
#      並且在第一處能夠找到這個變量名的地方停下來。
#      如果變量名在整個的搜索過程中都沒有找到,Python就會報錯。
#      
# 補:上面的變量規則只適用於簡單對象,當出現引用對象的屬性時,則有另一套搜索規則:
#     屬性引用搜索一個或多個對象,而不是作用域,並且有可能涉及到所謂的"繼承"
# 補2:global修飾符在python里的一個獨特現象:
#     在模塊層面定義的變量(無需global修飾),
#     如果在函數中沒有再定義同名變量,可以在函數中當做全局變量使用.
#     如果在函數中要對它重新賦值的話, 則必須在本函數中事先聲明為全局變量, 否則會拋出異常.
# 
#     #先聲明全局本函數里用到的全局變量: 圖表, 上下文, 股票對象
#     #使用global語句可以清楚地表明變量是在外面的塊定義的, 而且在本函數內
#     #可以使用或者修改這些變量(前提是必須先聲明為全局變量, 以便告訴python
#     #解釋器這些變量是全局的(主塊和函數塊共有的)已經是在外部--主代碼塊里--定義好了的, 
#     # 或者是本代碼塊要傳遞到主代碼塊里的變量).
#==============================================================================

class Plotter(object):
    u'''
    Plotter class to make picture of stock's ohlcv data
    '''
    # define class var
    ptype_dict={
        'lday':u'日',
        'lc5':u'五分鍾'} # 這里聲明的變量, 不用加global修飾符, 也是全局變量

        
    def __init__(self, context, stk, subset, quanxi=None):
        self.context = context
        self.stk = stk
        self.subset = subset
        self.quanxi = quanxi
        self.fig = None
        self.ax1 = self.ax2 = self.ax3 = None
        self.candle_colors = None
        self.length = None
        self.x = None

    def plot_candle_only(self, savefig=False):
        u'''僅繪制主圖    
        '''
        self.layout(volume_bars=False)
        self.candles()
        self.primary_curves()
        self.savfig(savefig)
        #fig #在ipython console里顯示整個圖表

    def plot_candle_vol(self, savefig=False):
        u'''主圖+成交量圖
        '''
        self.layout(volume_bars=True)
        self.candles() 
        self.primary_curves() 
        self.vol_bars()
        self.savfig(savefig)
        pass
            
    def plot_timing(self, timing=None, savefig=False):
        u'''畫圖: timing之K線性形態
            candles + (MA20, MA120) + 形態標注
            volume bar
        para: 
            timing: Series, 
            note: str, {'csyx', 'szx', etc}, 長上影線, 十字星等
        
        '''
        self.layout(volume_bars=True)
        self.candles() 
        self.primary_curves() 
        self.vol_bars()
        self.annotate(timing)
        self.savfig(savefig)
        
            
    def layout(self, volume_bars=True):
        u'''
        
        '''
        if volume_bars:
            self.fig, (self.ax1, self.ax2) = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [3,1]} )
        else:
            self.fig,self.ax1 = plt.subplots(1,1)
            #res = fig, ax1
        #return res
    
    def candles(self,
                col_func=None):
        u'''
        
        subset: 
            slice object, slice(start,stop,step)
            that is:
                slice(100)
                slice(-100,None)
                slice(100,200)
                slice(-200,-100,2)
                '2011-09'
                '2017'
        '''
        
        def default_col_func(index, open1, close, low, high):
            return 'black' if open1[index] > close[index] else 'white' # r g b  cyan black white
        
        subset=self.subset
        col_func= col_func or default_col_func
        ohlc = self.stk.ohlc[subset] if self.subset else self.stk.ohlc
        open1,high,low,close = ohlc.open, ohlc.high, ohlc.low, ohlc.close
        self.length = length = len(close)
        self.x = x = np.arange(length)
        candle_colors = [col_func(i, open1, close, low, high) for i in x]
        self.candle_colors = candle_colors
        # 計算出 每日的開盤價/收盤價里的最大值和最小值
        oc_min = pd.concat([open1, close], axis=1).min(axis=1)
        oc_max = pd.concat([open1, close], axis=1).max(axis=1)
    
        #candles = ax1.bar(x, oc_max-oc_min, bottom=oc_min, color=candle_colors, linewidth=0)
        #lines = ax1.vlines(x + 0.4, low, high, color=candle_colors, linewidth=1)
        candles = self.ax1.bar(x-0.4, oc_max-oc_min, bottom=oc_min, color=candle_colors, linewidth=0.2, edgecolor='black')
        shadlines_up = self.ax1.vlines(x,    oc_max, high, color=['black']* length, linewidth=0.3)
        shadlines_dn = self.ax1.vlines(x,    low, oc_min,  color=['black']* length, linewidth=0.3)
        #print candles.__class__, shadlines_up.__class__, shadlines_dn.__class__
        isinstance(candles,      matplotlib.container.BarContainer) == True
        isinstance(shadlines_dn, matplotlib.collections.LineCollection)
        isinstance(shadlines_up, matplotlib.collections.LineCollection)
        
        self.custom_figure()
        self.custom_yaxis()
        pass
    
    def primary_curves(self): #subset=None):
        #ohlc = stk.ohlc[subset] if subset else stk.ohlc
        #close = ohlc.close
        subset = self.subset
        if (isinstance(self.stk.ma20, pd.Series) and isinstance(self.stk.cyc61, pd.Series)):
            ma20 = self.stk.ma20[subset] if subset else self.stk.ma20
            cyc61 = self.stk.cyc61[subset] if subset else self.stk.cyc61
            indicators = [ma20, cyc61]
            x=self.x
            for ind in indicators:
                self.ax1.plot(x, ind, 'o-', lw=0.1, markersize=0.7, markeredgewidth=0.1, label=ind.name) #帶圓圈標記的實線
            self.ax1.legend()
            
        self.custom_xaxis(ax=self.ax1)
        
        
    def secondary_curves(self, ax):
    #    ohlc = stk.ohlc[subset] if subset else stk.ohlc
        pass
    
    def vol_bars(self):
        u'''
        
        '''
        subset = self.subset
        ohlc = self.stk.ohlc[subset] if subset else self.stk.ohlc
        volume = ohlc['volume']
        #open1,high,low,close = ohlc.open, ohlc.high, ohlc.low, ohlc.close
        x = self.x
        
        volume_scale = None
        scaled_volume = volume
        if volume.max() > 1000*1000:
            volume_scale = u'百萬股' #'M'
            scaled_volume = volume / 1000.0/1000.0
        elif volume.max() > 1000:
            volume_scale = u'千股'
            scaled_volume = volume / 1000.0
        self.ax2.bar(x-0.4, scaled_volume, color=self.candle_colors, linewidth=0.2, edgecolor='black')
        volume_title = 'Volume'
        if volume_scale:
            volume_title = 'Volume (%s)' % volume_scale
        #ax2.set_title(volume_title) # 太難看了
        self.ax2.set_ylabel(volume_title, fontdict=None)
        self.ax2.xaxis.grid(False)
        #plt.setp(ax.get_xticklabels(minor=False), fontsize=6)
        
        self.custom_xaxis(self.ax2)
        
        pass

    def annotate(self, timing):
        u'''在主圖上標注給定的K線形態:
        param:
            timing: event of Series of k-pattern
            note: str, 對應於事件的標注文本
        example:
            >>> plotter.annotate(csyx) #長上影線
        '''
        #ax=plt.gca()
        #xx = self.action.p_DJR.index
        c = self.stk.ohlc.close[self.subset] if self.subset else self.stk.ohlc.close
        self.timing = timing[self.subset] if self.subset  else timing
        ptn_dt = c[self.timing].index # True 邏輯選擇 選出長上影線的時機(日期索引)
        note = self.note = self.timing.name[:3]
        ax = self.ax1
        xx = map(lambda dt: c.index.get_loc(dt), ptn_dt) 
        yy = c * 1.1
        #strings = self.action['value'].values.astype(str)
        #strings = self.action['bonus'].values.astype(str)
        #strings = map(lambda x: u'派'+str(x), strings)
        for i,x in enumerate(xx):
            #ax.text(x, yy[i], strings[i])
            print i, c.index[x], x, yy[x], c[x]
            ax.annotate(note, xy=(x, yy[x]*1.05/1.1), xytext=(x, yy[x]+0.0),
                        arrowprops=dict(
                                facecolor='black', 
                                color='red',
                                #shrink=0.05,
                                arrowstyle='->',
                                ),)
    
    
    
    def custom_yaxis(self):
        u'''
        #   設定 Y 軸上的刻度
        #==================================================================================================================================================
        python - Matplotlib log scale tick label number formatting - Stack Overflow  
        https://stackoverflow.com/questions/21920233/matplotlib-log-scale-tick-label-number-formatting
        每個坐標軸都有7大屬性:
            ax1.set_yscale, ylim, ylabel, yticks, yticklabels, ybound, ymargin
        '''
        #use_expo=True; 
        expbase=1.1  # 2 e 10
        yaxis= self.ax1.get_yaxis()
        isinstance(yaxis, matplotlib.axis.YAxis)
        self.ax1.set_yscale(value='log', basey=expbase)
        pass
    
    def custom_figure(self):
        u'''  '''
        # 依據繪圖數據的長度和時間軸的比例尺(比如1:16)確定圖表的長度:  
        #fig = plt.gcf()
        #fig.set_size_inches(18.5, 10.5)
        self.fig.set_size_inches(self.length/16.0, 6) # /18 /20 /16 diff time-scales
        
        title = u'%s(%s)%s周期蠟燭圖'%(self.context.name, self.context.symbol, self.ptype_dict[self.context.ptype])
        self.ax1.set_title(title)
        pass
    
    def custom_xaxis(self, ax):
        u'''
        
        '''
        subset = self.subset
        ohlc = self.stk.ohlc[subset] if subset else self.stk.ohlc
        close = ohlc.close
        length = self.length  # len(close)
        
        ax.set_xlim(-2, length+10)
        xaxis= ax.get_xaxis()
        yaxis= ax.get_yaxis()
        #   設定 X 軸上的主刻度/次刻度位置
        #==================================================================================================================================================
        mdindex, wdindex, sdindex= self.ohlc_find_idx_fdim(close) 
        xMajorLocator= FixedLocator(np.array(mdindex)) # 針對主刻度,實例化一個"固定式刻度定位"
        xMinorLocator= FixedLocator(np.array(wdindex)) # 確定 X 軸的 MinorLocator
        
        # 確定 X 軸的 MajorFormatter 和 MinorFormatter 
        # 自定義的刻度格式(應該是一個function)
        datelist = close.index.date.tolist() 
        def x_major_formatter_1(idx, pos=None): 
            u'''
            格式函數的功能: idx 是位置location, 依據位置, 返回對應的日期刻度標簽
            '''
            #return datelist[idx].strftime('%Y-%m-%d')
            return datelist[idx].strftime('%m\n%Y')
        def x_major_formatter_2(idx, pos=None):
            return datelist[idx].strftime('\n\n%m\n%Y')
     
        def x_minor_formatter_1(idx, pos=None):
            #return datelist[idx].strftime(u'一\n%d') # 周一
            return datelist[idx].strftime(u'M\n%d') # 周一
        def x_minor_formatter_2(idx, pos=None):
            return datelist[idx].strftime('%m-%d')
     
        xMajorFormatter_1 = FuncFormatter(x_major_formatter_1)
        xMajorFormatter_2 = FuncFormatter(x_major_formatter_2)
        xMinorFormatter_1 = FuncFormatter(x_minor_formatter_1)
     
        # 設定 X 軸的 Locator 和 Formatter
        xaxis.set_major_locator(xMajorLocator)
        xaxis.set_minor_locator(xMinorLocator)
    
        xaxis.set_major_formatter(xMajorFormatter_1)
        if self.ax2 is None:
            xaxis.set_major_formatter(xMajorFormatter_2)
        xaxis.set_minor_formatter(xMinorFormatter_1)
    
        if self.ax2 is None: # 僅繪制主圖
            # 設定不顯示的刻度標簽:
            if ax==self.ax1:
                plt.setp(ax.get_xticklabels(minor=False), visible=True) #主刻度標簽 可見
                plt.setp(ax.get_xticklabels(minor=True), visible=True)  #次刻度標簽 可見
        elif ((self.ax1 != None) and (self.ax2 != None)): # case of 主圖+成交量圖
            if ax==self.ax2:
                plt.setp(ax.get_xticklabels(minor=True), visible=False) #次刻度標簽 隱藏
            elif ax==self.ax1:
                plt.setp(ax.get_xticklabels(minor=False), visible=False) #主刻度標簽 隱藏
     
        # 設定 X 軸主刻度和次刻度標簽的樣式(字體大小)
        for malabel in ax.get_xticklabels(minor=False):
            malabel.set_fontsize(12) # 6號也太小了
            #malabel.set_horizontalalignment('right')
            #malabel.set_rotation('45')
     
        # if ax == ax1 or ax2:
        for milabel in ax.get_xticklabels(minor=True):
            milabel.set_fontsize(12) # 5 太小了
            #milabel.set_horizontalalignment('right')
            #milabel.set_rotation('45')
            #milabel.set_fontdict=myfont
            #milabel.set_fontproperties=myfont
            #milabel.set_prop=myfont
    
    
        #   設置兩個坐標軸上的 grid
        #==================================================================================================================================================
        #xaxis_2.grid(True, 'major', color='0.3', linestyle='solid', linewidth=0.2)
        xaxis.grid(True, 'major', color='0.3', linestyle='dotted', linewidth=0.3)
        xaxis.grid(True, 'minor', color='0.3', linestyle='dotted', linewidth=0.1)
     
        #yaxis_2.grid(True, 'major', color='0.3', linestyle='dashed', linewidth=0.2)
        yaxis.grid(True, 'major', color='0.3', linestyle='dotted', linewidth=0.1)
        yaxis.grid(True, 'minor', color='0.3', linestyle='dotted', linewidth=0.1)
    
        yaxis.get_major_ticks()[2].label = \
                Text(0,28.1024,u'28.10 $\\mathdefault{1.1^{35}}$')
 
    
    
    
    
    def ohlc_find_idx_fdim(self, ohlc):
        u'''
        功能: index of  first trading-day in month 
        ------
        - 獲取每個月的第一個交易日的下標(又稱0軸索引). 
          從數據框的時間索引里提取對應的日期, 然后檢索出下標.
        - 另外, 也獲取每個交易周的第一個交易日的下標
        
        輸入:
        - ohlc: pandas數據框
        
        返回:
        - list
        
        例子:
        -------
        >>>  mdindex, wdindex, sdindex= ohlc_find_idx_fdim(ohlc_last60)
        
        '''
        #datelist= [ datetime.date(int(ys), int(ms), int(ds)) for ys, ms, ds in [ dstr.split('-') for dstr in pdata[u'日期'] ] ]
        #last60 = ohlc[-250:]
        last60 = ohlc
        datelist = last60.index.date.tolist()
        # 確定 X 軸的 MajorLocator
        mdindex= [] # 每個月第一個交易日在所有日期列表中的 index, 月日期索引
        years= set([d.year for d in datelist])  # 所有的交易年份
         
        for y in sorted(years):     
            months= set([d.month for d in datelist if d.year == y])     # 當年所有的交易月份
            for m in sorted(months):
                monthday= min([dt for dt in datelist if dt.year==y and dt.month==m])    # 當月的第一個交易日
                mdindex.append(datelist.index(monthday))
    
        wdindex =[] # weekday index, 每周的第一個交易日的索引
        for y in sorted(years):
            weeknum= set([int(d.strftime('%U')) for d in datelist if d.year==y])
            for w in sorted(weeknum):
                wd= min([dt for dt in datelist if dt.year==y and int(dt.strftime('%U'))==w])
                wdindex.append(datelist.index(wd))
        
        #==============================================================================
        # wdindex= [] # 每周第一個交易日在所有日期列表中的 index, 每周的第一個交易日的索引
        # for d in datelist:
        #     if d.weekday() == 0: wdindex.append(datelist.index(d))
        #             
        #==============================================================================
        
        # ===  檢索每個季末交易日的下標: sdindex:  end of season day index   ===
        # 對ndarray or list  進行邏輯運輸時, 需要用np.logical_or()方法才是正確的方法:
        #filter1=  (months==3) or (months==6)
        #filter1=  (months==3).tolist() or (months==6).tolist()  
        #ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
        dt= last60.index.date # 得到ndarray of date, 
        # dti= last60.index     # 得到pd.ts.index.DtetimeIndex of date, 
        months= last60.index.month #得到ndarray of month, 取值范圍為: 1~12
        # nextbar_m= last60.index.shift(1, freq='D').month # 當移動時間下標時, 數據的頻率不能為空
        #  這樣做還是有問題的, pd的做法是: 引用未來1 Day的日期, 也就是當前的日期+1day的日期
        #   比如: 當前的日期是        2016-12-30, 2017-01-03
        #         .shift(1)的日期是: 2016-12-31, 2017-01-04
        # ==> 誤判了4季末的日期變更線坐標位置
        # 解決辦法: 應該讓freq= 'per index bar', 查詢一下pd的doc吧...   
        # 變通辦法: .drop first element value or .delete(0) the first location
        #        and then .insert one value at end, to make the same length
        # 變通辦法之: 用 freq='BQ', 來生成一個dtindex:
        # pd.date_range(start=mi[0], end=mi[-1], freq='BQ') # BQ	business quarter endfrequency
        # Time Series / Date functionality — pandas 0.19.2 documentation  
        # http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases    
        # 
        # === 還有更簡潔的辦法: 就是dti.quarter屬性直接提供了第幾個季節   ===
        i_index= last60.index.delete(0)    
        i_index= i_index.insert(-1, last60.index[-1])     # -1 表示最后一個下標位置
        nextbar_m= i_index.month # 
        endMar= np.logical_and(months==3, nextbar_m==4)
        endJun= np.logical_and(months==6, nextbar_m==7)
        endSep= np.logical_and(months==9, nextbar_m==10)
        endDec= np.logical_and(months==12, nextbar_m==1)
        
        tmp1= np.logical_or(endMar, endJun)
        tmp2= np.logical_or(endSep, endDec)
        mask= np.logical_or(tmp1, tmp2)
        sdindex= [dt.tolist().index(i) for i in dt[mask] ]
    
        #print u'\n==> 季節變更坐標線:'
        #print u'    每個季末的x軸的位置下標: %r' % sdindex
        #print u'    每個季末的x軸的位置時間: %r' % last60.index[sdindex]
     
        return mdindex, wdindex, sdindex
        
    
    def savfig(self, savefig=False):
        if savefig:
            now = datetime.datetime.now()
            now_s = now.strftime('%Y%m%d_%H%M%S_')
            microsec = str(now.microsecond)
            #fn= '%s_%s_%s.pdf' %(context.name, now_s, microsec )
            #fig.savefig(fn, dpi=300)
            #print u'\n==> 該pdf文件被創建: %s' %fn
            fn= '%s_%s_%s.png' %(self.context.name, now_s, microsec )
            self.fig.savefig(fn, dpi=300)
            print u'\n==> 該png文件被創建: %s' %fn
        pass
    
        
if __name__ == '__main__':
    pass


代碼(2017.11)

  1. 主塊代碼
  2. 繪圖模塊的代碼
  3. 結果展示

結果展示1:

結果展示2:

主塊代碼: test1_load.py


# -*- coding: utf-8 -*-

import  pandas as pd

import amipy as ami
reload(ami)
import do_plot as dp
reload(dp)

#context = ami.Context('600699.SH')
context = ami.Context('000911.SZ')
stk = ami.Stock(context) #None,None)
stk.grab_data_tdxlday(context, num_days=None)
stk.ohlc = stk.ohlc_raw

stk.ma20 = ami.TTR.sma(stk.ohlc.close, 20)
stk.cyc61 = ami.TTR.sma(stk.ohlc.close, 120)
subset = slice(-120,None) # '2017-07'  '2017'
subset = '2017' #slice(-120,None) # '2017-07'  '2017'

datas = (context, stk, subset)

# 僅繪制主圖    
#dp.plot_candle_only(datas)

# 主圖+成交量圖
dp.plot_candle_vol(datas)

繪圖模塊代碼 do_plot.py


# -*- coding: utf-8 -*-

#import sys

import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import (
        FixedLocator, 
        #MultipleLocator, 
        #LogLocator, 
        
        #NullFormatter, 
        FuncFormatter, 
        #LogFormatter 
        )
from matplotlib.font_manager import FontProperties 
myfont = FontProperties(fname=r"c:\windows\fonts\msyh.ttf")  #size可不用指定
matplotlib.rcParams['axes.unicode_minus'] = False

#import amipy as ami


#==============================================================================
# Python中的作用域及global用法 - Summer_cool - 博客園  
# https://www.cnblogs.com/summer-cool/p/3884595.html
# 
# 函數定義了本地作用域,而模塊定義的是全局作用域。
# 如果想要在函數內定義全局作用域,需要加上global修飾符。
# 
# 變量名解析:LEGB原則
# 當在函數中使用未認證的變量名時,Python搜索4個作用域:
#     [本地作用域(L-local)(函數內部聲明但沒有使用global的變量),
#      之后是上一層結構def或者lambda的本地作用域(E-enclosure),
#      之后是全局作用域(G-global)(函數中使用global聲明的變量或在模塊層聲明的變量),
#      最后是內置作用域(B)(即python的內置類和函數等)]
#      並且在第一處能夠找到這個變量名的地方停下來。
#      如果變量名在整個的搜索過程中都沒有找到,Python就會報錯。
#      
# 補:上面的變量規則只適用於簡單對象,當出現引用對象的屬性時,則有另一套搜索規則:
#     屬性引用搜索一個或多個對象,而不是作用域,並且有可能涉及到所謂的"繼承"
# 補2:global修飾符在python里的一個獨特現象:
#     在模塊層面定義的變量(無需global修飾),
#     如果在函數中沒有再定義同名變量,可以在函數中當做全局變量使用.
#     如果在函數中要對它重新賦值的話, 則必須在本函數中事先聲明為全局變量, 否則會拋出異常.
# 
#     #先聲明全局本函數里用到的全局變量: 圖表, 上下文, 股票對象
#     #使用global語句可以清楚地表明變量是在外面的塊定義的, 而且在本函數內
#     #可以使用或者修改這些變量(前提是必須先聲明為全局變量, 以便告訴python
#     #解釋器這些變量是全局的(主塊和函數塊共有的)已經是在外部--主代碼塊里--定義好了的, 
#     # 或者是本代碼塊要傳遞到主代碼塊里的變量).
#==============================================================================
global fig, ax1, ax2, ax3 # 模塊級變量名, 分別代表: 整個圖表, 子圖1/2/3
global context, stk, subset # 模塊級變量名
global candle_colors, length
ax2=ax3=None #初始化 ax2/ax3 子圖實例為None, 
             #fig和ax1可以不用初始化, 因為調用layout()后總是要返回fig和ax1的
ptype_dict={
        'lday':u'日',
        'lc5':u'五分鍾'} # 這里聲明的變量, 不用加global修飾符, 也是全局變量

def layout(volume_bars=True):
    u'''
    
    '''
    global fig, ax1, ax2, ax3
    if volume_bars:
        fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [3,1]} )
        res = fig, (ax1,ax2)
    else:
        fig,ax1 = plt.subplots(1,1)
        res = fig, ax1
    return res

def candles(
            #subset=None,
            col_func=None):
    u'''
    
    subset: 
        slice object, slice(start,stop,step)
        that is:
            slice(100)
            slice(-100,None)
            slice(100,200)
            slice(-200,-100,2)
            '2011-09'
            '2017'
    '''
    global context, stk, subset
    global candle_colors # 可能會被以后的函數所用到(比如畫成交量柱子)
    global length
    
    def default_col_func(index, open1, close, low, high):
        return 'black' if open1[index] > close[index] else 'white' # r g b  cyan black white
    
    col_func= col_func or default_col_func
    ohlc = stk.ohlc[subset] if subset else stk.ohlc
    open1,high,low,close = ohlc.open, ohlc.high, ohlc.low, ohlc.close
    length = len(close)
    x = np.arange(length)
    candle_colors = [col_func(i, open1, close, low, high) for i in x]
    # 計算出 每日的開盤價/收盤價里的最大值和最小值
    oc_min = pd.concat([open1, close], axis=1).min(axis=1)
    oc_max = pd.concat([open1, close], axis=1).max(axis=1)

    #candles = ax1.bar(x, oc_max-oc_min, bottom=oc_min, color=candle_colors, linewidth=0)
    #lines = ax1.vlines(x + 0.4, low, high, color=candle_colors, linewidth=1)
    candles = ax1.bar(x-0.4, oc_max-oc_min, bottom=oc_min, color=candle_colors, linewidth=0.2, edgecolor='black')
    shadlines_up = ax1.vlines(x,    oc_max, high, color=['black']* length, linewidth=0.3)
    shadlines_dn = ax1.vlines(x,    low, oc_min,  color=['black']* length, linewidth=0.3)
    #print candles.__class__, shadlines_up.__class__, shadlines_dn.__class__
    isinstance(candles,      matplotlib.container.BarContainer) == True
    isinstance(shadlines_dn, matplotlib.collections.LineCollection)
    isinstance(shadlines_up, matplotlib.collections.LineCollection)
    
    custom_figure()
    custom_yaxis()
    pass

def primary_curves(): #subset=None):
    #ohlc = stk.ohlc[subset] if subset else stk.ohlc
    #close = ohlc.close
    if (isinstance(stk.ma20, pd.Series) and isinstance(stk.cyc61, pd.Series)):
        ma20 = stk.ma20[subset] if subset else stk.ma20
        cyc61 = stk.cyc61[subset] if subset else stk.cyc61
        length = len(ma20)
        x = np.arange(length)
        indicators = [ma20, cyc61]
        for ind in indicators:
            ax1.plot(x, ind, 'o-', lw=0.1, markersize=0.7, markeredgewidth=0.1, label=ind.name) #帶圓圈標記的實線
        ax1.legend()
        
    custom_xaxis(ax=ax1)
    
    
def secondary_curves(ax,subset=None):
#    ohlc = stk.ohlc[subset] if subset else stk.ohlc
    pass

def vol_bars():
    u'''
    
    '''
    global stk, subset
    ohlc = stk.ohlc[subset] if subset else stk.ohlc
    volume = ohlc['volume']
    #open1,high,low,close = ohlc.open, ohlc.high, ohlc.low, ohlc.close
    x = np.arange(length)
    
    volume_scale = None
    scaled_volume = volume
    if volume.max() > 1000*1000:
        volume_scale = u'百萬股' #'M'
        scaled_volume = volume / 1000.0/1000.0
    elif volume.max() > 1000:
        volume_scale = u'千股'
        scaled_volume = volume / 1000.0
    ax2.bar(x-0.4, scaled_volume, color=candle_colors, linewidth=0.2, edgecolor='black')
    volume_title = 'Volume'
    if volume_scale:
        volume_title = 'Volume (%s)' % volume_scale
    ax2.set_title(volume_title)
    ax2.xaxis.grid(False)
    #plt.setp(ax.get_xticklabels(minor=False), fontsize=6)
    
    custom_xaxis(ax2)
    
    pass

def custom_yaxis():
    u'''
    #   設定 Y 軸上的刻度
    #==================================================================================================================================================
    python - Matplotlib log scale tick label number formatting - Stack Overflow  
    https://stackoverflow.com/questions/21920233/matplotlib-log-scale-tick-label-number-formatting
    '''
    #use_expo=True; 
    expbase=1.1  # 2 e 10
    yaxis= ax1.get_yaxis()
    isinstance(yaxis, matplotlib.axis.YAxis)
    ax1.set_yscale(value='log', basey=expbase)
    pass

def custom_figure():
    u'''  '''
    # 依據繪圖數據的長度和時間軸的比例尺(比如1:16)確定圖表的長度:  
    #fig = plt.gcf()
    #fig.set_size_inches(18.5, 10.5)
    fig.set_size_inches(length/16.0, 6) # /18 /20 /16 diff time-scales
    
    title = u'%s(%s)%s周期蠟燭圖'%(context.name, context.symbol, ptype_dict[context.ptype])
    ax1.set_title(title)
    pass

def custom_xaxis(ax):
    u'''
    
    '''
    global ax1, ax2, ax3
    ohlc = stk.ohlc[subset] if subset else stk.ohlc
    close = ohlc.close
    #length = len(close)
    
    ax.set_xlim(-2, length+10)
    xaxis= ax.get_xaxis()
    yaxis= ax.get_yaxis()
    #   設定 X 軸上的主刻度/次刻度位置
    #==================================================================================================================================================
    mdindex, wdindex, sdindex= ohlc_find_idx_fdim(close) 
    xMajorLocator= FixedLocator(np.array(mdindex)) # 針對主刻度,實例化一個"固定式刻度定位"
    xMinorLocator= FixedLocator(np.array(wdindex)) # 確定 X 軸的 MinorLocator
    
    # 確定 X 軸的 MajorFormatter 和 MinorFormatter 
    # 自定義的刻度格式(應該是一個function)
    datelist = close.index.date.tolist() 
    def x_major_formatter_1(idx, pos=None): 
        u'''
        格式函數的功能: idx 是位置location, 依據位置, 返回對應的日期刻度標簽
        '''
        #return datelist[idx].strftime('%Y-%m-%d')
        return datelist[idx].strftime('%m\n%Y')
    def x_major_formatter_2(idx, pos=None):
        return datelist[idx].strftime('\n\n%m\n%Y')
 
    def x_minor_formatter_1(idx, pos=None):
        #return datelist[idx].strftime(u'一\n%d') # 周一
        return datelist[idx].strftime(u'M\n%d') # 周一
    def x_minor_formatter_2(idx, pos=None):
        return datelist[idx].strftime('%m-%d')
 
    xMajorFormatter_1 = FuncFormatter(x_major_formatter_1)
    xMajorFormatter_2 = FuncFormatter(x_major_formatter_2)
    xMinorFormatter_1 = FuncFormatter(x_minor_formatter_1)
 
    # 設定 X 軸的 Locator 和 Formatter
    xaxis.set_major_locator(xMajorLocator)
    xaxis.set_minor_locator(xMinorLocator)

    xaxis.set_major_formatter(xMajorFormatter_1)
    if ax2 is None:
        xaxis.set_major_formatter(xMajorFormatter_2)
    xaxis.set_minor_formatter(xMinorFormatter_1)

    if ax2 is None: # 僅繪制主圖
        # 設定不顯示的刻度標簽:
        if ax==ax1:
            plt.setp(ax.get_xticklabels(minor=False), visible=True) #主刻度標簽 可見
            plt.setp(ax.get_xticklabels(minor=True), visible=True)  #次刻度標簽 可見
    elif ((ax1 != None) and (ax2 != None)): # case of 主圖+成交量圖
        if ax==ax2:
            plt.setp(ax.get_xticklabels(minor=True), visible=False) #次刻度標簽 隱藏
        elif ax==ax1:
            plt.setp(ax.get_xticklabels(minor=False), visible=False) #主刻度標簽 隱藏
 
    # 設定 X 軸主刻度和次刻度標簽的樣式(字體大小)
    for malabel in ax.get_xticklabels(minor=False):
        malabel.set_fontsize(12) # 6號也太小了
        #malabel.set_horizontalalignment('right')
        #malabel.set_rotation('45')
 
    # if ax == ax1 or ax2:
    for milabel in ax.get_xticklabels(minor=True):
        milabel.set_fontsize(12) # 5 太小了
        #milabel.set_horizontalalignment('right')
        #milabel.set_rotation('45')
        #milabel.set_fontdict=myfont
        #milabel.set_fontproperties=myfont
        #milabel.set_prop=myfont


    #   設置兩個坐標軸上的 grid
    #==================================================================================================================================================
    #xaxis_2.grid(True, 'major', color='0.3', linestyle='solid', linewidth=0.2)
    xaxis.grid(True, 'major', color='0.3', linestyle='dotted', linewidth=0.3)
    xaxis.grid(True, 'minor', color='0.3', linestyle='dotted', linewidth=0.1)
 
    #yaxis_2.grid(True, 'major', color='0.3', linestyle='dashed', linewidth=0.2)
    yaxis.grid(True, 'major', color='0.3', linestyle='dotted', linewidth=0.1)
    yaxis.grid(True, 'minor', color='0.3', linestyle='dotted', linewidth=0.1)


def ohlc_find_idx_fdim(ohlc):
    u'''
    功能: index of  first trading-day in month 
    ------
    - 獲取每個月的第一個交易日的下標(又稱0軸索引). 
      從數據框的時間索引里提取對應的日期, 然后檢索出下標.
    - 另外, 也獲取每個交易周的第一個交易日的下標
    
    輸入:
    - ohlc: pandas數據框
    
    返回:
    - list
    
    例子:
    -------
    >>>  mdindex, wdindex, sdindex= ohlc_find_idx_fdim(ohlc_last60)
    
    '''
    #datelist= [ datetime.date(int(ys), int(ms), int(ds)) for ys, ms, ds in [ dstr.split('-') for dstr in pdata[u'日期'] ] ]
    last60 = ohlc[-250:]
    datelist = last60.index.date.tolist()
    # 確定 X 軸的 MajorLocator
    mdindex= [] # 每個月第一個交易日在所有日期列表中的 index, 月日期索引
    years= set([d.year for d in datelist])  # 所有的交易年份
     
    for y in sorted(years):     
        months= set([d.month for d in datelist if d.year == y])     # 當年所有的交易月份
        for m in sorted(months):
            monthday= min([dt for dt in datelist if dt.year==y and dt.month==m])    # 當月的第一個交易日
            mdindex.append(datelist.index(monthday))

    wdindex =[] # weekday index, 每周的第一個交易日的索引
    for y in sorted(years):
        weeknum= set([int(d.strftime('%U')) for d in datelist if d.year==y])
        for w in sorted(weeknum):
            wd= min([dt for dt in datelist if dt.year==y and int(dt.strftime('%U'))==w])
            wdindex.append(datelist.index(wd))
    
    #==============================================================================
    # wdindex= [] # 每周第一個交易日在所有日期列表中的 index, 每周的第一個交易日的索引
    # for d in datelist:
    #     if d.weekday() == 0: wdindex.append(datelist.index(d))
    #             
    #==============================================================================
    
    # ===  檢索每個季末交易日的下標: sdindex:  end of season day index   ===
    # 對ndarray or list  進行邏輯運輸時, 需要用np.logical_or()方法才是正確的方法:
    #filter1=  (months==3) or (months==6)
    #filter1=  (months==3).tolist() or (months==6).tolist()  
    #ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
    dt= last60.index.date # 得到ndarray of date, 
    # dti= last60.index     # 得到pd.ts.index.DtetimeIndex of date, 
    months= last60.index.month #得到ndarray of month, 取值范圍為: 1~12
    # nextbar_m= last60.index.shift(1, freq='D').month # 當移動時間下標時, 數據的頻率不能為空
    #  這樣做還是有問題的, pd的做法是: 引用未來1 Day的日期, 也就是當前的日期+1day的日期
    #   比如: 當前的日期是        2016-12-30, 2017-01-03
    #         .shift(1)的日期是: 2016-12-31, 2017-01-04
    # ==> 誤判了4季末的日期變更線坐標位置
    # 解決辦法: 應該讓freq= 'per index bar', 查詢一下pd的doc吧...   
    # 變通辦法: .drop first element value or .delete(0) the first location
    #        and then .insert one value at end, to make the same length
    # 變通辦法之: 用 freq='BQ', 來生成一個dtindex:
    # pd.date_range(start=mi[0], end=mi[-1], freq='BQ') # BQ	business quarter endfrequency
    # Time Series / Date functionality — pandas 0.19.2 documentation  
    # http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases    
    # 
    # === 還有更簡潔的辦法: 就是dti.quarter屬性直接提供了第幾個季節   ===
    i_index= last60.index.delete(0)    
    i_index= i_index.insert(-1, last60.index[-1])     # -1 表示最后一個下標位置
    nextbar_m= i_index.month # 
    endMar= np.logical_and(months==3, nextbar_m==4)
    endJun= np.logical_and(months==6, nextbar_m==7)
    endSep= np.logical_and(months==9, nextbar_m==10)
    endDec= np.logical_and(months==12, nextbar_m==1)
    
    tmp1= np.logical_or(endMar, endJun)
    tmp2= np.logical_or(endSep, endDec)
    mask= np.logical_or(tmp1, tmp2)
    sdindex= [dt.tolist().index(i) for i in dt[mask] ]

    #print u'\n==> 季節變更坐標線:'
    #print u'    每個季末的x軸的位置下標: %r' % sdindex
    #print u'    每個季末的x軸的位置時間: %r' % last60.index[sdindex]
 

    
    return mdindex, wdindex, sdindex

def plot_candle_only(datas):
    u'''僅繪制主圖    
    '''
    global context, stk, subset
    global fig, ax1, ax2, ax3
    global candle_colors, length
    context, stk, subset = datas

    layout(volume_bars=False)
    candles()
    primary_curves()
    #fig #在ipython console里顯示整個圖表

def plot_candle_vol(datas):
    u'''主圖+成交量圖
    '''
    global context, stk, subset
    global fig, ax1, ax2, ax3
    global candle_colors, length
    context, stk, subset = datas
    
    layout(volume_bars=True)
    candles() 
    primary_curves() 
    vol_bars()
    pass

    
if __name__ == '__main__':
    pass
    


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM