魔改 Keras 中的 model.summary()


之前做項目時，好奇這個函數是怎麼實現的，於是把源碼看了一遍並加以魔改：刪除沒用的部分，重新封裝為一個類，還加上了可以輸出至 txt 文件的功能。
'''
class print_summary_magic_modification:
    """Print a Keras-style per-layer summary of a model to stdout or a txt file.

    A trimmed-down re-implementation of ``model.summary()``. Each layer gets
    one row showing ``name (ClassName)``, its output shape and its parameter
    count, laid out in fixed-width columns.
    """

    def __init__(self, model, file_path):
        """Remember the model to summarise and where to save the summary.

        Args:
            model: object exposing a ``layers`` sequence (e.g. a Keras model).
            file_path: text file that ``print_summary2txt`` appends to.
        """
        self.model = model
        self.file_path = file_path

    # The original listing named the constructor ``init`` (its double
    # underscores were lost), so the class could not be constructed with
    # arguments. Keep the old name as an alias for any explicit callers.
    init = __init__

    @staticmethod
    def params_nums(weights):
        """Return the total parameter count of the distinct tensors in *weights*.

        ``set()`` removes duplicate weight objects so a shared weight is only
        counted once.

        NOTE(review): ``np`` and ``K`` (presumably ``numpy`` and the Keras
        backend) are not imported in this snippet; the enclosing module must
        provide them before this helper is called.
        """
        return int(np.sum([K.count_params(p) for p in set(weights)]))

    def print_row(self, fields, positions, file=None):
        """Print *fields* as one fixed-width row.

        Each field is converted with ``str``, then the line is truncated or
        space-padded so that field ``i`` ends exactly at column
        ``positions[i]``.

        Args:
            fields: values to print, one per column.
            positions: cumulative end column of each field. It must have at
                least as many entries as *fields*.
            file: stream to write to. ``None`` means ``sys.stdout`` (the
                same default as ``print``).
        """
        line = ''
        for i, field in enumerate(fields):
            if i > 0:
                # Overwrite the previous column's last character so adjacent
                # columns are always separated by at least one space, even
                # when the previous field was truncated.
                line = line[:-1] + ' '
            line += str(field)
            # Truncate to, then pad up to, this column's end position.
            line = line[:positions[i]]
            line += ' ' * (positions[i] - len(line))
        print(line, file=file)

    def print_layer_summary(self, layer, positions, file=None):
        """Print the table row for a single *layer*.

        Args:
            layer: object with ``name``, ``count_params()`` and, optionally,
                ``output_shape``.
            positions: column end positions forwarded to ``print_row``.
            file: output stream. ``None`` means ``sys.stdout``.
        """
        try:
            output_shape = layer.output_shape
        except AttributeError:
            # A layer without a usable ``output_shape`` is reported as
            # 'multiple', as in the Keras implementation this is derived from.
            output_shape = 'multiple'
        name = layer.name
        cls_name = layer.__class__.__name__
        fields = [name + ' (' + cls_name + ')',
                  output_shape, layer.count_params()]
        self.print_row(fields, positions, file=file)

    def print_summary(self, file=None):
        """Print the summary table of ``self.model``.

        Args:
            file: output stream. ``None`` (the default) prints to stdout.
        """
        line_length = 65
        positions = [29, 55, 100]

        # Header names for the three columns.
        to_display = ['Layer (type)', 'Output Shape', 'Param #']

        print('_' * line_length, file=file)
        self.print_row(to_display, positions, file=file)
        print('=' * line_length, file=file)

        layers = self.model.layers
        last = len(layers) - 1
        for i, layer in enumerate(layers):
            self.print_layer_summary(layer, positions, file=file)
            # Rows are separated by '_' rules; the table is closed with '='.
            print(('=' if i == last else '_') * line_length, file=file)

    def print_summary2txt(self):
        """Append the summary table to ``self.file_path`` (UTF-8).

        The table is written to the file instead of stdout. The file is
        created if it does not exist.
        """
        with open(self.file_path, 'a', encoding='utf-8') as f:
            self.print_summary(file=f)

'''

下面這個方法可以直接將 model.summary() 的輸出寫入 txt 文件，是我在 Google 上搜了好久才找到的 pythonic 代碼：
'''
import os
from contextlib import redirect_stdout

# Fill these in: target directory and file name for the saved summary.
output_file_path = ''
file_name = ''

# os.path.join inserts the separator that plain "+" concatenation would omit
# when output_file_path has no trailing slash. The explicit encoding keeps the
# file independent of the platform's default locale encoding.
# `model` is assumed to be an already-built Keras model defined earlier.
with open(os.path.join(output_file_path, file_name), 'w', encoding='utf-8') as f, \
        redirect_stdout(f):
    # Everything model.summary() prints to stdout inside this block goes to f.
    model.summary()

'''


免責聲明!

本站轉載的文章為個人學習借鑒使用，本站對版權不負任何法律責任。如果侵犯了您的隱私權益，請聯繫本站郵箱 yoyou2525@163.com 刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM