Visualizing data with TensorBoard


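Note: this code was downloaded from the web, but I can no longer find the original source. I will remove it if it infringes anyone's rights.

First, write the visualizer class: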

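import os

import numpy as np
import tensorflow as tf  # uses the TensorFlow 1.x API (InteractiveSession, Saver, FileWriter)
# Assumed import path for TensorFlow 1.x; the original post does not show its imports.
from tensorflow.contrib.tensorboard.plugins import projector


class TF_visualizer(object):
    def __init__(self, dimension, vecs_file, metadata_file, output_path):
        self.dimension = dimension          # length of each vector
        self.vecs_file = vecs_file          # text file with one vector per line
        self.metadata_file = metadata_file  # labels, one row per vector
        self.output_path = output_path      # TensorBoard log directory

        # Read one vector per line, skipping blank lines.
        # A blank line read from a file is '\n', not '', so strip before testing.
        self.vecs = []
        with open(self.vecs_file, 'r') as vecs:
        #with open(self.vecs_file, 'rb') as vecs:
            for line in vecs:
                line = line.strip()
                if line:
                    self.vecs.append(line)

    def visualize(self):
        # adding into projector
        config = projector.ProjectorConfig()

        # Parse each line into one row of a (num_vectors, dimension) array.
        # Values are expected to be comma-separated, despite the .tsv extension.
        placeholder = np.zeros((len(self.vecs), self.dimension))
        for i, line in enumerate(self.vecs):
            placeholder[i] = np.fromstring(line, sep=',')
        #for i,line in enumerate(self.vecs):
        #    placeholder[i] = np.fromstring(line)

        embedding_var = tf.Variable(placeholder, trainable=False, name='amazon')

        # Register the embedding and link it to its metadata file.
        embed = config.embeddings.add()
        embed.tensor_name = embedding_var.name
        embed.metadata_path = self.metadata_file

        # define the model without training
        sess = tf.InteractiveSession()

        tf.global_variables_initializer().run()
        saver = tf.train.Saver()

        # Save a checkpoint so that TensorBoard can load the embedding values.
        saver.save(sess, os.path.join(self.output_path, 'w2x_metadata.ckpt'))

        # Write the projector config next to the checkpoint.
        writer = tf.summary.FileWriter(self.output_path, sess.graph)
        projector.visualize_embeddings(writer, config)
        sess.close()
        print('Run `tensorboard --logdir={0}` to run visualize result on tensorboard'.format(self.output_path))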
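
Because visualize() copies each parsed line into a fixed-width row, a line with the wrong number of values will raise a shape error there. It can help to check the vector file before loading it into TensorFlow. Below is a minimal sketch using only the standard library; `parse_vector_lines` is a hypothetical helper, not part of the original code, and it assumes comma-separated values as in the class above.

```python
def parse_vector_lines(lines, dimension, sep=","):
    """Parse text lines into lists of floats, skipping blank lines.

    Hypothetical helper for checking a vector file; not part of TF_visualizer.
    """
    vectors = []
    for line_no, line in enumerate(lines, start=1):
        text = line.strip()
        if not text:
            # A blank line read from a file is "\n", not "", so strip first.
            continue
        values = [float(v) for v in text.split(sep)]
        if len(values) != dimension:
            raise ValueError(
                "line {0}: expected {1} values, got {2}".format(
                    line_no, dimension, len(values)))
        vectors.append(values)
    return vectors


# Small in-memory sample standing in for the lines of a vector file.
sample = ["0.1,0.2,0.3\n", "\n", "1.0,2.0,3.0\n"]
parsed = parse_vector_lines(sample, dimension=3)
print(len(parsed))
```

To check a real file, pass the open file object as `lines`, for example `parse_vector_lines(open(path), dimension=768)`.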

Then call the class:

output = '/home/xx'

# create a new tensor board visualizer
visualizer = TF_visualizer(dimension = 768,
                           vecs_file = os.path.join(output, 'amazon_vec.tsv'),
                           #vecs_file = os.path.join(output, 'mnist_10k_784d_tensors.bytes'),
                           metadata_file = os.path.join(output, 'amazon.tsv'),
                           output_path = output)
visualizer.visualize()

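Here, amazon_vec.tsv stores the vectors (word vectors, sentence vectors, and so on), and amazon.tsv stores the original data in the format id, label, title. The id and title can be defined freely, while label identifies the corresponding vector. The two files correspond line by line: the first row of data in amazon_vec.tsv matches the first row of data in amazon.tsv.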
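
The two-file layout described above can be sketched as follows. This is only an illustration under stated assumptions: `write_projector_files` is a hypothetical helper, the vectors are written comma-separated to match `np.fromstring(line, sep=',')` in the class, and the metadata is written tab-separated with a header row, since TensorBoard's projector treats the first row of a multi-column metadata file as column names rather than data.

```python
import os
import tempfile


def clean(field):
    """Remove tabs and newlines so a field cannot break the TSV layout."""
    return str(field).replace("\t", " ").replace("\n", " ")


def write_projector_files(vectors, rows, vec_path, meta_path):
    """Write vectors and metadata so that data row i of each file matches.

    Hypothetical helper: `rows` holds (id, label, title) tuples.
    """
    if len(vectors) != len(rows):
        raise ValueError("vectors and metadata rows must have the same length")

    with open(vec_path, "w") as vec_file:
        for vector in vectors:
            # Comma-separated, to match np.fromstring(line, sep=',').
            vec_file.write(",".join(repr(float(v)) for v in vector) + "\n")

    with open(meta_path, "w") as meta_file:
        # Header row for the three metadata columns (assumed column names).
        meta_file.write("id\tlabel\ttitle\n")
        for row in rows:
            meta_file.write("\t".join(clean(f) for f in row) + "\n")


# Tiny made-up example written to a temporary directory.
out_dir = tempfile.mkdtemp()
vec_path = os.path.join(out_dir, "amazon_vec.tsv")
meta_path = os.path.join(out_dir, "amazon.tsv")
write_projector_files(
    vectors=[[0.1, 0.2], [0.3, 0.4]],
    rows=[(1, "books", "A first title"), (2, "music", "A second title")],
    vec_path=vec_path,
    meta_path=meta_path,
)
```

The length check up front is the important part: if the two files drift out of step, every point after the mismatch would be shown with the wrong label.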

Finally, run the following on the command line:

tensorboard --logdir=/home/xx

Open http://xx-desktop:6006 in a browser to see the visualized data (6006 is the default port).

