基于语法树的SQL自动改写工具开发系列(2)-使用PYTHON进行简单SQL改写的开发实战-二、实战

时间:2024-11-17 09:02:30

根据上一篇,基于语法树的SQL自动改写工具开发系列(1)-离线安装语法树解析工具antlr4-DA-技术分享-M版,先在本地部署好开发环境。

DEMO 1

写一段期望改写的原始SQL,使用pygrun进行解析
比如在原生PG中不支持ORACLE中的table()函数,但有替代的unnest改写方式
ORACLE:

SELECT A FROM TABLE(VAR) T

PG:

SELECT A FROM (SELECT * FROM UNNEST(VAR) COLUMN_VALUE) T

要实现的改写规则为,找到table函数的整个表达式,并取出它的入参,将table函数表达式的节点替换成(SELECT * FROM UNNEST(入参) COLUMN_VALUE)
语法树:

C:\antlr>pygrun PlSql sql_script  --tree
SELECT A FROM TABLE(VAR) T;
^Z
(sql_script
   (unit_statement
      (data_manipulation_language_statements
         (select_statement
            (select_only_statement
               (subquery
                  (subquery_basic_elements
                     (query_block SELECT
                        (selected_list
                           (select_list_elements
                              (expression
                                 (logical_expression
                                    (unary_logical_expression
                                       (multiset_expression
                                          (relational_expression
                                             (compound_expression
                                                (concatenation
                                                   (model_expression
                                                      (unary_expression
                                                         (atom
                                                            (general_element
                                                               (general_element_part
                                                                  (id_expression
                                                                     (regular_id A))))))))))))))))
                        (from_clause FROM
                           (table_ref_list
                              (table_ref
                                 (table_ref_aux
                                    (table_ref_aux_internal
                                       (dml_table_expression_clause
                                          (table_collection_expression TABLE (
                                             (expression
                                                (logical_expression
                                                   (unary_logical_expression
                                                      (multiset_expression
                                                         (relational_expression
                                                            (compound_expression
                                                               (concatenation
                                                                  (model_expression
                                                                     (unary_expression
                                                                        (atom
                                                                           (general_element
                                                                              (general_element_part
                                                                                 (id_expression
                                                                                    (regular_id VAR)))))))))))))) ))))
                                 (table_alias
                                    (identifier
                                       (id_expression
                                          (regular_id T))))))))))))))) ; <EOF>)

C:\antlr>

以下为一个完整且可以运行的demo代码,用于说明如何进行语法改写

from antlr4 import FileStream, CommonTokenStream
from PlSqlLexer import PlSqlLexer
from PlSqlParser import PlSqlParser
from PlSqlParserVisitor import PlSqlParserVisitor
from antlr4.TokenStreamRewriter import TokenStreamRewriter

class SQLTransformer(PlSqlParserVisitor):
    def __init__(self, token_stream):
        super().__init__()
        self.rewriter = TokenStreamRewriter(token_stream)

    def visitDml_table_expression_clause(self, ctx: PlSqlParser.Dml_table_expression_clauseContext):
        if ctx.table_collection_expression() and ctx.table_collection_expression().getText().startswith('table'):
            argument = ctx.table_collection_expression().expression().getText()
            new_text = "(select * from unnest({}) column_value)".format(argument)
            self.rewriter.replace(TokenStreamRewriter.DEFAULT_PROGRAM_NAME, ctx.table_collection_expression().start.tokenIndex, ctx.table_collection_expression().stop.tokenIndex, new_text)
        return self.visitChildren(ctx)

def main(input_file, output_file):
    input_stream = FileStream(input_file, encoding='utf-8')
    lexer = PlSqlLexer(input_stream)
    stream = CommonTokenStream(lexer)
    parser = PlSqlParser(stream)
    tree = parser.sql_script()

    transformer = SQLTransformer(stream)
    transformer.visit(tree)
    output_text = transformer.rewriter.getDefaultText()
    with open(output_file, 'w', encoding='utf-8',newline='') as f:
        f.write(output_text)

if __name__ == '__main__':
    input_file = 'input.sql'
    output_file = 'output.sql'
    main(input_file, output_file)

其中def main内的代码基本可以固定,我们直接看class SQLTransformer里的def visitDml_table_expression_clause
visitDml_table_expression_clause其实是PlSqlParserVisitor.py里面的一个def,这个def的名称由visit加上节点名组成,也就是说,语法树中的每一个节点,都有一个对应的visit。而原本PlSqlParserVisitor.py里的每个visit里面都是空的,直接就return出去了:

def  visitDml_table_expression_clause(self, ctx:PlSqlParser.Dml_table_expression_clauseContext):

return  self.visitChildren(ctx)

我们自己写的这个visit就是实现了里面的具体内容。

如果我们需要修改某个语法,可以从语法树中,找到这个语法相关的上下文的最小节点,以本文前面输出的语法树为例,就应该是 (table_collection_expression TABLE (,所以理论上,我们再写个visitTable_collection_expression就好了,但本文的demo代码是从visitDml_table_expression_clause开始,是为了说明如何引用当前节点的下级节点

if ctx.table_collection_expression() and ctx.table_collection_expression().getText().startswith('table')

这句是一个判断,作用是,判断当前节点下,是否存在table_collection_expression这个节点,我们可以去对比语法树,如果没有使用table函数,是不会有这个节点的;第二个条件就是,获取table_collection_expression这个节点的文本,判断它是不是使用 table开始。这里要注意,如果不写第一个条件,在没有table_collection_expression节点时,对它执行getText会报错。

argument = ctx.table_collection_expression().expression().getText()

这里是取出table_collection_expression的下一个叫expression的节点,对照语法树可以看到,虽然里面有很多层,但这个节点实际只包含VAR这个文本,因此这里就可以得到argument="VAR",即前文例子中,table函数的入参。

new_text = "(select * from unnest({}) column_value)".format(argument)

这一句很好理解,就是格式化一个字符串,把argument的值替换{},得到(select * from unnest(VAR) column_value)

self.rewriter.replace(TokenStreamRewriter.DEFAULT_PROGRAM_NAME, ctx.table_collection_expression().start.tokenIndex, ctx.table_collection_expression().stop.tokenIndex, new_text)

这一句就是最关键的,self.rewriter是前面定义的TokenStreamRewriter(token_stream),在TokenStreamRewriter里面,可以支持对节点的替换、删除、增加等操作。
TokenStreamRewriter.replace有4个入参,分别为程序名,开始位置、结束位置、需要替换成的文本。
程序名一般固定使用TokenStreamRewriter.DEFAULT_PROGRAM_NAME就行,如果期望一次解析,就能做多种替换,比如同时生成支持PG和MYSQL的两种语法,就可以在这里设置程序名,针对不同的程序名写不同的规则。
开始位置和结束位置,可以使用对应节点的start.tokenIndexstop.tokenIndex
至此,我们就完成了一个改写规则的开发。

DEMO 2

如果需要在一次语法树解析中就完成多种规则的执行,可以再添加几个def visit,比如我们再针对create type语句来进行改写。

ORACLE:

CREATE OR REPLACE TYPE TY_TEST AS OBJECT(COL1 INT,COL2 VARCHAR(20));

PG:

CREATE TYPE TY_TEST AS (COL1 INT,COL2 VARCHAR(20));

改写规则为,对于create_type的语法节点,将create or replace 改为create,并且删除object

语法树:

C:\antlr>pygrun PlSql sql_script  --tree
CREATE OR REPLACE TYPE TY_TEST AS OBJECT(COL1 INT,COL2 VARCHAR(20));
^Z
(sql_script
   (unit_statement
      (create_type CREATE OR REPLACE TYPE
         (type_definition
            (type_name
               (id_expression
                  (regular_id TY_TEST)))
            (object_type_def
               (object_as_part AS OBJECT) (
               (object_member_spec
                  (identifier
                     (id_expression
                        (regular_id COL1)))
                  (type_spec
                     (datatype
                        (native_datatype_element INT)))) ,
               (object_member_spec
                  (identifier
                     (id_expression
                        (regular_id COL2)))
                  (type_spec
                     (datatype
                        (native_datatype_element VARCHAR)
                        (precision_part (
                           (numeric 20) ))))) ))))) ; <EOF>)

改写代码

    def visitCreate_type(self, ctx:PlSqlParser.Create_typeContext):
        # 检查并修改 'create or replace' 为 'create'
        if ctx.getChild(0).getText() == 'create' and ctx.getChild(1).getText() == 'or' and ctx.getChild(2).getText() == 'replace':
           # print(f"Modifying: {ctx.getChild(0).getText()} {ctx.getChild(1).getText()} {ctx.getChild(2).getText()}")
            self.rewriter.replace(
                TokenStreamRewriter.DEFAULT_PROGRAM_NAME,
                ctx.getChild(0).symbol.tokenIndex,
                ctx.getChild(2).symbol.tokenIndex,
                'create'
            )   
        object_as_part = ctx.type_definition().object_type_def().object_as_part()
        #print(f"ctx.object_as_part.getText(): {object_type_def.object_as_part().getChild(1).getText()}")
        # 删除 object关键字
        if object_as_part.getChild(1).getText()=='OBJECT':
            self.rewriter.delete(
                TokenStreamRewriter.DEFAULT_PROGRAM_NAME,
                object_as_part.getChild(1).symbol.tokenIndex,
                object_as_part.getChild(1).symbol.tokenIndex
            )
        return self.visitChildren(ctx)

这里可以从(create_type CREATE OR REPLACE TYPE看到,CREATE OR REPLACE TYPE这一串都在create_type这个节点上,可以通过getChild(n)来取出中间的每一部分,而这每一部分的位置,则是通过symbol.tokenIndex获取。
然后从(object_as_part AS OBJECT) (中可以看到,我们期望删除的object是在object_as_part这个节点的第二个字符串,因此使用了getChild(1)。这里需要注意,由于object_as_part这个节点是在create_type这个节点的下面很多层,然后后面会要多次使用这个节点,所以可以定义一个object_as_part = ctx.type_definition().object_type_def().object_as_part(),减少冗余代码。
找到OBJECT后,执行self.rewriter.delete,就可以把OBJECT删掉了。

DEMO 3

在ORACLE的sql脚本中,create type应该以/结尾,但是在OG中,则不能有/,如果一个脚本文件里混合了多种语句,就会出现有的/要删,有的不能删,因此我们可以写一个规则,将所有create type语句后面的/删掉。

    def visitSql_script(self, ctx:PlSqlParser.Sql_scriptContext):
        for i in range(ctx.getChildCount() - 1):
            unit = ctx.getChild(i)
            if isinstance(unit, PlSqlParser.Unit_statementContext):
                create_type_stmt = unit.getChild(0)
                # 对于create type语句
                if isinstance(create_type_stmt, PlSqlParser.Create_typeContext):
                    # 遍历兄弟节点,找到 `/` 进行删除
                    for j in range(i + 1, ctx.getChildCount()):
                        sibling = ctx.getChild(j)
                        if sibling.getText() == '/':
                            #print(f"Deleting: {sibling.getText()}")
                            self.rewriter.delete(
                                TokenStreamRewriter.DEFAULT_PROGRAM_NAME,
                                sibling.symbol.tokenIndex,
                                sibling.symbol.tokenIndex
                            )
                            break
        return self.visitChildren(ctx)

这里没有使用visitCreate_type的原因是,在antlr4生成的语法树中,/这个节点并不在Create_type这个节点的内部,所以得对Sql_script这个根节点,找到所有的单条语句Unit_statement,然后判断里面是不是有Create_type,然后再回头删掉/