keras可視化pydot graphviz問題


1. 安裝

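pip install graphviz  # 注意:這只安裝 Python 封裝庫,不含 Graphviz 本體;還需另外安裝 Graphviz 程式(如 apt-get install graphviz 或 brew install graphviz),並確保 dot 可執行檔在 PATH 中

pip install pydot

pip install pydot-ng  # 版本兼容需要:Keras 1.x 的 visualize_util 會優先導入 pydot_ng,找不到時才退回 pydot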

# 測試一下
from keras.utils.visualize_util import plot

 

2. 使用:

#!/usr/bin/env python
# coding=utf-8

"""
利用keras cnn進行端到端的驗證碼識別, 簡單直接暴力。
迭代100次可以達到95%的準確率,但是很容易過擬合,泛化能力糟糕,除了增加訓練數據還沒想到更好的方法。

__author__: jkmiao
__email__: miao1202@126.com
__date__: 2017-02-08

"""
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Activation, LSTM, Reshape
from keras.layers import Convolution2D, MaxPooling2D
from PIL import Image
import os, random
import numpy as np
from keras.models import model_from_json
from util import CharacterTable
from keras.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
from keras.utils.visualize_util import plot


def load_data(path='img/clearNoise/'):
    """Load every ``*jpg`` captcha image in *path* as a binarised matrix.

    The label of each image is taken from its file name: everything before
    the first ``_`` (e.g. ``ab12_003.jpg`` -> ``'ab12'``), lower-cased.

    Args:
        path: Directory that holds the image files.

    Returns:
        A tuple ``(data, label)``. ``data`` is a numpy array holding one
        ``(height, width, 1)`` 0/1 matrix per image; ``label`` is the list of
        matching label strings. Both follow the same random (shuffled) order.
    """
    # Keep the bare file names from listdir so the label can be parsed
    # directly, instead of splitting the joined path on a hard-coded '/'
    # (which fails when os.path.join uses '\\' as the separator).
    names = [name for name in os.listdir(path) if name.endswith('jpg')]
    random.shuffle(names)
    data, label = [], []
    for name in names:
        img_label = name.split('_')[0]
        # Close the underlying file deterministically rather than relying on
        # garbage collection; convert('L') yields a grayscale copy first.
        with Image.open(os.path.join(path, name)) as img:
            img_m = np.array(img.convert('L'))
        # Binarise: pixels brighter than 180 become 1, everything else 0.
        img_m = 1 * (img_m > 180)
        # Add a trailing single-channel axis for the Conv2D input.
        data.append(img_m.reshape((img_m.shape[0], img_m.shape[1], 1)))
        label.append(img_label.lower())
    return np.array(data), label

# Build the dataset: binarised images plus one fixed-width target vector per
# label, produced by CharacterTable.encode.
ctable = CharacterTable()
data, label = load_data()
# NOTE(review): 216 is assumed to equal len(ctable.encode(...)) and must stay
# in sync with the model's Dense(216) output layer; confirm against
# util.CharacterTable.
label_onehot = np.zeros((len(label), 216))
for i, lb in enumerate(label):
    label_onehot[i, :] = ctable.encode(lb)
# Single-argument print() produces the same output on Python 2 and 3.
print(data.shape)
print(label_onehot.shape)

# Hold out 10% of the samples for the final accuracy check.
x_train, x_test, y_train, y_test = train_test_split(data, label_onehot, test_size=0.1)

# When True, build a fresh (untrained) CNN; when False, load the architecture
# and weights previously saved under model/.
DEBUG = False

# Build the model.
if DEBUG:
    model = Sequential()
    # Keras 1.x signature: Convolution2D(nb_filter, nb_row, nb_col, ...).
    # NOTE(review): input_shape assumes 60x200 single-channel images; confirm
    # it matches the files read by load_data.
    model.add(Convolution2D(32, 5, 5, border_mode='valid', input_shape=(60, 200, 1), name='conv1'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(32, 3, 3, name='conv2'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    # An LSTM head (Reshape((20, 60)) followed by LSTM(32)) was tried at this
    # point by the author and left disabled.
    model.add(Dense(512))
    model.add(Activation('relu'))
    # 216 outputs must match the width of the encoded label vectors.
    model.add(Dense(216))
    model.add(Activation('softmax'))
else:
    # Read the JSON architecture inside a with-block so the file is closed.
    with open('model/ba_cnn_model2.json') as fr:
        model = model_from_json(fr.read())
    model.load_weights('model/ba_cnn_model2.h5')

# Compile. The legacy class_mode='categorical' argument was removed: Keras
# 1.x only warns that it is deprecated and discards it, and later Keras
# versions do not accept it.
# NOTE(review): 'mse' on a softmax output is unusual (categorical
# cross-entropy is the common pairing); kept as the author chose it.
model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])
model.summary()

# Plot the network architecture to model.png (requires pydot and Graphviz).
plot(model, to_file='model.png', show_shapes=True)

# Train. ModelCheckpoint keeps the weights with the best val_loss in a
# separate file.
check_pointer = ModelCheckpoint('./model/train_len_size1.h5', monitor='val_loss', verbose=1, save_best_only=True)
model.fit(x_train, y_train, batch_size=32, nb_epoch=5, validation_split=0.1, callbacks=[check_pointer])

# Save the architecture and the final-epoch weights; these are the files the
# DEBUG == False branch loads.
# NOTE(review): the saved .h5 holds last-epoch weights, not the best
# checkpoint above; confirm that is intended.
json_string = model.to_json()
with open('./model/ba_cnn_model2.json', 'w') as fw:
    fw.write(json_string)
model.save_weights('./model/ba_cnn_model2.h5')

# Evaluate on the held-out split. A sample counts as correct only when the
# decoded prediction equals the decoded target.
y_pred = model.predict(x_test, verbose=1)
cnt = 0
for i, (pred_vec, true_vec) in enumerate(zip(y_pred, y_test)):
    guess = ctable.decode(pred_vec)
    correct = ctable.decode(true_vec)
    if guess == correct:
        cnt += 1
    # Show a sample prediction every 10 test items. Single-argument print()
    # keeps the output identical on Python 2 and Python 3.
    if i % 10 == 0:
        print('%s %s' % ('--' * 10, i))
        print('y_pred %s' % (guess,))
        print('y_test %s' % (correct,))
print(cnt / float(len(y_pred)))

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM