ast模块与astunparse模块


 AST简介 

Abstract Syntax Trees即抽象语法树。Ast是python源码到字节码的一种中间产物,借助ast模块可以从语法树的角度分析源码结构。此外,我们不仅可以修改和执行语法树,还可以将Source生成的语法树unparse成python源码。因此ast给python源码检查、语法分析、修改代码以及代码调试等留下了足够的发挥空间。可以通过将ast.PyCF_ONLY_AST作为标志传递给compile()内置函数,或者使用此模块中提供的parse()帮助器生成抽象语法树。结果将是一个对象树,其类都继承自ast.AST。可以使用内置的compile()函数将抽象语法树编译成Python代码对象。

官网:https://docs.python.org/3.6/library/ast.html

Python官方提供的CPython解释器对python源码的处理过程如下:

  1. Parse source code into a parse tree (Parser/pgen.c)
  2. Transform parse tree into an Abstract Syntax Tree (Python/ast.c)
  3. Transform AST into a Control Flow Graph (Python/compile.c)
  4. Emit bytecode based on the Control Flow Graph (Python/compile.c)

即实际python代码的处理过程如下:

源代码解析 --> 语法树 --> 抽象语法树(AST) --> 控制流程图 --> 字节码

上述过程在python2.5之后被应用。python源码首先被解析成语法树,随后又转换成抽象语法树。在抽象语法树中我们可以看到源码文件中的python的语法结构。

大部分时间编程可能都不需要用到抽象语法树,但是在特定的条件和需求的情况下,AST又有其特殊的方便性。

下面是一个抽象语法的简单实例。

Module(body=[
    Print(
          dest=None,
          values=[BinOp( left=Num(n=1),op=Add(),right=Num(n=2))],
          nl=True,
 )]) 

Compile函数

先简单了解一下compile函数。compile函数将源代码编译成可以由exec()或eval()执行的代码对象。返回code类型。

compile(source, filename, mode[, flags[, dont_inherit]]) 

  • source -- 字符串或者AST(Abstract Syntax Trees)对象。一般可将整个py文件内容file.read()传入。
  • filename -- 代码文件名称,如果不是从文件读取代码则传递一些可辨认的值。
  • mode -- 指定编译代码的种类。可以指定为 exec, eval, single。
  • flags -- 变量作用域,局部命名空间,如果被提供,可以是任何映射对象。
  • flags和dont_inherit是用来控制编译源码时的标志。
func_def = \
"""
def add(x, y):
    return x + y
print(add(3, 5))
"""

cm = compile(func_def, '<string>', 'exec')
print(type(cm))
isinstance(cm, types.CodeType)

exec(func_def) #传入的类型可以是str、bytes或code。
exec(cm) #传入的类型可以是str、bytes或code。

上面func_def经过compile编译得到字节码,cm即code对象,True == isinstance(cm, types.CodeType)。

compile(source, filename, mode, ast.PyCF_ONLY_AST)  <==> ast.parse(sourcefilename='<unknown>'mode='exec')

生成ast

 除了python内置ast模块可以生成抽象语法树,还有很多第三方库,如astunparse, codegen, unparse等。这些第三方库不仅能够以更好的方式展示出ast结构,还能够将ast反向导出python source代码。

安装astunparse:pip install astunparse

astunparse官网:https://pypi.org/project/astunparse/

import ast, astunparse
func_def = \
"""
def add(x, y):
    return x + y
print(add(3, 5))
"""

r_node = ast.parse(func_def)
print(ast.dump(r_node))
print(astunparse.dump(r_node))
import ast
import astunparse
func_def = \
"""
a = 3
b = 5
def add(x, y):
    return x + y
print(add(a,b))
"""

def nodeTree(node:str):
    str2list = list(node.replace(' ', ''))
    count = 0
    for i, e in enumerate(str2list):
        if e == '(':
            count += 1
            str2list[i] = '(\n{}'.format('|   ' * count)
        elif e == ')':
            count -= 1
            str2list[i] = '\n{})'.format('|   ' * count)
        elif e == ',':
            str2list[i] = ',\n{}'.format('|   ' * count)
        elif e == '[':
            count += 1
            str2list[i] = '[\n{}'.format('|   ' * count)
        elif e == ']':
            count -= 1
            str2list[i] = '\n{}]'.format('|   ' * count)

    return ''.join(str2list)

# def nodeTree(node:str):
#     reStr = ''
#     count = 0
#     for e in node:
#         if e == '(':
#             count += 1
#             reStr += '(\n{}'.format('|   ' * count)
#         elif e == ')':
#             count -= 1
#             reStr += '\n{})'.format('|   ' * count)
#         elif e == ',':
#             reStr += ',\n{}'.format('|   ' * count)
#         elif e == '[':
#             count += 1
#             reStr += '[\n{}'.format('|   ' * count)
#         elif e == ']':
#             count -= 1
#             reStr += '\n{}]'.format('|   ' * count)
#         else:
#             reStr += e
#     return reStr

r_node = ast.parse(func_def)
cm = ast.dump(r_node)
print(nodeTree(cm))
自定义展示出ast结构的函数

通过ast的parse方法得到ast tree的根节点r_node, 我看可以通过根节点来遍历语法树,从而对python代码进行分析和修改。
ast.parse(可以直接查看ast模块的源代码)方法实际上是调用内置函数compile进行编译,源码如下所示:

def parse(source, filename='<unknown>', mode='exec'):
    """
    Parse the source into an AST node.
    Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
    """
    return compile(source, filename, mode, PyCF_ONLY_AST)

传递给compile特殊的flag = PyCF_ONLY_AST, 来通过compile返回抽象语法树。

节点类型分析

import ast
root_node = ast.parse("print('hello world')")
print(ast.dump(root_node))

输出:

Module(body=[Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Str(s='hello world')], keywords=[]))])

语法树中的每个节点都对应ast下的一种类型,根节点是ast.Moudle类型,在分析的时候可以通过isinstance函数方便的进行节点类型的判断。

import ast
root_node = ast.parse("print('hello world')")
print(ast.dump(root_node))
print(isinstance(root_node,ast.Module))
print(isinstance(root_node,ast.Expr))
print(isinstance(root_node.body[0],ast.Expr))

ast中存在的节点的所有类型可以参考:ast节点类型
比如 a = 10这样一条语句对应ast.Assign节点类型,而Assign节点类型分别有两个子节点, 分别为ast.Name类型的a和ast.Num类型的10等。
我们可以通过ast.dump(node)函数来将node格式化,并进行打印,以查看节点内容,以“a = 10”这行代码为例。


Module(body=[Assign(targets=[Name(id='a', ctx=Store())], value=Num(n=10))])
(1) root节点
Module(body=[Assign(targets=[Name(id='a', ctx=Store())], value=Num(n=10))])
root节点是Module类型,由于只有一行代码,所有root节点只有Assign这样一个子节点。

(2) 子节点
Assign(targets=[Name(id='a', ctx=Store())], value=Num(n=10))
上述的Assign节点有三个子节点,分别是Name, Store和Num.
Name(id='a', ctx=Store())
Num(n=10)
而Name有一个子节点,Store.
Store()(Store表示Name中操作时赋值, 类型的有Load,del, 具体参考节点类型的文档)
一个简单的“a = 10”的这样一行代码,我们就可以通过上述的这种ast tree去分析和修改代码结构。

语法树的遍历分析

1. visitor的定义

可以通过ast模块的提供的visitor来对语法树进行遍历。
ast.NodeVisitor是一个专门用来遍历语法树的工具,我们可以通过继承这个类来完成对语法树的遍历以及遍历过程中的处理。

import ast
import astunparse
func_def = \
"""
a = 3
b = 5
def add(x, y):
    return x + y
print(add(a,b))
"""
class CodeVisitor(ast.NodeVisitor):
    def generic_visit(self, node):
        print(type(node).__name__,end=', ')
        ast.NodeVisitor.generic_visit(self, node)

    def visit_FunctionDef(self, node):
        print(type(node).__name__,end=', ')
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Assign(self, node):
        print(type(node).__name__,end=', ')
        ast.NodeVisitor.generic_visit(self, node)
r_node = ast.parse(func_def)
visitor = CodeVisitor()
visitor.visit(r_node)
View Code
class CodeVisitor(ast.NodeVisitor):
    def generic_visit(self, node):
        print type(node).__name__
        ast.NodeVisitor.generic_visit(self, node)
 
    def visit_FunctionDef(self, node):
        print type(node).__name__
        ast.NodeVisitor.generic_visit(self, node)
 
    def visit_Assign(self, node):
        print type(node).__name__
        ast.NodeVisitor.generic_visit(self, node)

如上述代码,定义类CodeVisitor,继承自NodeVisitor,这里面主要有两种类型的函数,一种的generic_visit,一种是"visit_" + "Node类型"。
visitor首先从根节点root进行遍历,在遍历的过程中,假设节点类型为Assign,如果存在visit_Assign类型的函数,则调用visit_Assgin函数,如果不存在则调用generic_visit函数。
总的来说就是每个节点类型都有专用的类型处理函数,如果不存在,则调用通用的的处理函数generic_visit.
关于visitor进行语法树的遍历,stackoverflow上有一篇文章讲的比较详细:Simple example of how to use ast.NodeVisitor
注意:
在每个函数处理中,根据需求需要加上ast.NodeVisitor.generic_visit(self, node)这段代码,否则visitor不会继续访问当前节点的子节点。
e.g. 如果定义如下的函数:
def visit_Moudle(self, node):
     print type(node).__name__
那么,首先访问根节点root,root为Moudle类型,会调用visit_Moudle函数,由于visit_Moudle函数中没有调用NodeVisitor.generic_visit(self, node),所以此次遍历只遍历了根节点root,并没有遍历其他节点。

2. walk方式遍历

 

for node in ast.walk(tree):
    if isinstance(node, ast.FunctionDef):
        print(node.name)

 

节点的修改

ast模块同样提供了一个NodeTransfomer节点来支持对node的修改,NodeTransfomer继承自NodeVisitor,并重写了generic_visit函数。
对于NodeTransfomer的generic_visit以及visit_ + 节点类型的函数,都需要返回一个node,可以返回原始node,一个新的替代的node,或者是返回Node代表remove掉这个节点。
假设我们有如下的代码:

"""ast test code"""
a = 10
b = "test"
print(a)

我们定义一个NodeTransform的visitor如下:

class ReWriteName(ast.NodeTransformer):
    def generic_visit(self, node):
        has_lineno = getattr(node, "lineno", "None")
        col_offset = getattr(node, "col_offset", "None")
        print type(node).__name__, has_lineno, col_offset
        ast.NodeTransformer.generic_visit(self, node)
        return node
 
    def visit_Name(self, node):
        new_node = node
        if node.id == "a":
            new_node = ast.Name(id = "a_rep", ctx = node.ctx)
        return new_node
 
    def visit_Num(self, node):
        if node.n == 10:
            node.n = 100
        return node

在visit_Name中,将变量"a"替换成了变量"a_rep",执行到a = 10以及print a的时候,都会将a替换成a_rep,并返回一个新节点。
在visit_Num中,简单粗暴的将10替换成了100,返回修改后的原节点。
我们通过如下方式运用这个NodeTransfomer visitor

file = open("code.py", "r")
source = file.read()
visitor = ReWriteName()
root = ast.parse(source)
root = visitor.visit(root)
ast.fix_missing_locations(root)
 
code_object = compile(root, "<string>", "exec")
exec code_object

ast作用在python解析语法之后,编译成pyCodeObject字节码结构之前,通过NodeTransformer修改后,返回修改后的语法树,我们通过内置模块compile编译成pyCodeObject对象,交给python虚拟机执行。
执行结果:100
可以看到,我们同时将a = 10和print a两处将a名字换成了a_rep,并将10替换成了100,最后打印的结果是100,成功修改了语法树的节点。
关于节点的修改,这里有比较好的例子可以参考:https://greentreesnakes.readthedocs.org/en/latest/examples.html
注意:
修改语法树节点,尤其是删除一个语法树节点时要慎重,因为修改或者删除后有可能返回错误的语法树,直到compile或者执行的时候才会发现问题。
通过节点修改python code就可以通过上述方法进行,不过请注意,在运用visitor的代码中有ast.fix_missing_locations(root)这样一行代码,这是因为我们自己创建的节点是不包含lineno以及col_offset这些必要的属性,必须手动修改添加指定,新添加的节点代码的行位置以及偏移位置。

修复节点位置

我们可以通过相应的方法,对默认没有lineno以及col_offset的节点进行位置的修复,以方便在代码中获取每个节点的位置信息,主要有三种方法进行修复。
1)ast.fix_missing_locations(node)
函数递归的将父节点的位置信息(lineno以及col_offset)赋值给没有位置信息的子节点。
2)ast.copy_location(new_node, node)
将node的位置信息拷贝给new_node节点,并返回new_node节点。当我们将旧节点替换成一个新节点的时候,这种方法比较适用。
3)ast.increment_lineno(node, n=1)
将node节点以及其所以子节点的行号加上n。
3  分析
我们通过“三. 节点的修改"中的例子来分析location信息。
在例子中,我们只有在visit_Name的时候返回的新的节点,这时候节点是没有lineno以及col_offset属性,我们可以通过两种方式获取。
一是如上述代码中,利用ast.fix_missing_locations函数来修复,在"a = 10"以及"print a"中,Name节点a跟父节点的lineno相同,但是此时col_offset会有差异。
二是我们将visit_Name的代码修改如下:

def visit_Name(self, node):
    new_node = node
    if node.id == "a":
        new_node = ast.Name(id = "a_rep", ctx = node.ctx)
        ast.copy_location(new_node, node)
    return new_node

通过copy_location将旧节点的location信息拷贝给新节点。

AST应用

AST模块实际编程中很少用到,但是作为一种源代码辅助检查手段是非常有意义的;语法检查,调试错误,特殊字段检测等。

上面通过为函数添加调用日志的信息是一种调试python源代码的一种方式,不过实际中我们是通过parse整个python文件的方式遍历修改源码

相关代码

# -- encoding:utf-8 --
"""
Greate by ibf on 2019
"""
import ast
import astunparse
func_def = \
"""
a = 3
b = 5
def add(x, y):
    return x + y
print(add(a,b))
"""

def nodeTree(node:str):
    str2list = list(node.replace(' ', ''))
    count = 0
    for i, e in enumerate(str2list):
        if e == '(':
            count += 1
            str2list[i] = '(\n{}'.format('|   ' * count)
        elif e == ')':
            count -= 1
            str2list[i] = '\n{})'.format('|   ' * count)
        elif e == ',':
            str2list[i] = ',\n{}'.format('|   ' * count)
        elif e == '[':
            count += 1
            str2list[i] = '[\n{}'.format('|   ' * count)
        elif e == ']':
            count -= 1
            str2list[i] = '\n{}]'.format('|   ' * count)

    return ''.join(str2list)

# def nodeTree(node:str):
#     reStr = ''
#     count = 0
#     for e in node:
#         if e == '(':
#             count += 1
#             reStr += '(\n{}'.format('|   ' * count)
#         elif e == ')':
#             count -= 1
#             reStr += '\n{})'.format('|   ' * count)
#         elif e == ',':
#             reStr += ',\n{}'.format('|   ' * count)
#         elif e == '[':
#             count += 1
#             reStr += '[\n{}'.format('|   ' * count)
#         elif e == ']':
#             count -= 1
#             reStr += '\n{}]'.format('|   ' * count)
#         else:
#             reStr += e
#     return reStr

'''遍历节点'''
class CodeVisitor(ast.NodeVisitor):
    def generic_visit(self, node):
        print(type(node).__name__,end=', ')
        ast.NodeVisitor.generic_visit(self, node)

    def visit_FunctionDef(self, node):
        print(type(node).__name__,end=', ')
        ast.NodeVisitor.generic_visit(self, node)

    def visit_Assign(self, node):
        print(type(node).__name__,end=', ')
        ast.NodeVisitor.generic_visit(self, node)

'''修改节点'''
class ReWriteName(ast.NodeTransformer):
    def generic_visit(self, node):
        # has_lineno = getattr(node, "lineno", "None")
        # col_offset = getattr(node, "col_offset", "None")
        # print(type(node).__name__, has_lineno, col_offset)

        ast.NodeTransformer.generic_visit(self, node)
        return node

    def visit_Name(self, node):
        new_node = node
        if node.id == "a":
            new_node = ast.Name(id="a_rep", ctx=node.ctx)
        return new_node

    def visit_Num(self, node):
        if node.n == 3:
            node.n = 100
        return node
    def visit_BinOp(self,node):
            node.op = ast.Sub()
            return node

r_node = ast.parse(func_def)
visitor = ReWriteName()
visitor.visit(r_node)
print(astunparse.unparse(r_node))#将ast反向导出python source代码。
# ast.fix_missing_locations(r_node)
View Code

 

 

参考链接:https://www.cnblogs.com/yssjun/p/10069199.html  http://www.dalkescientific.com/writings/diary/archive/2010/02/22/instrumenting_the_ast.html  https://pycoders-weekly-chinese.readthedocs.io/en/latest/issue3/static-modification-of-python-with-python-the-ast-module.html#cpython

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM