根据上一篇,基于语法树的SQL自动改写工具开发系列(1)-离线安装语法树解析工具antlr4-DA-技术分享-M版,先在本地部署好开发环境。
DEMO 1
写一段期望改写的原始SQL,使用pygrun进行解析
比如在原生PG中不支持ORACLE中的table()函数,但有替代的unnest改写方式
ORACLE:
SELECT A FROM TABLE(VAR) T
PG:
SELECT A FROM (SELECT * FROM UNNEST(VAR) COLUMN_VALUE) T
要实现的改写规则为,找到table函数的整个表达式,并取出它的入参,将table函数表达式的节点替换成(SELECT * FROM UNNEST(入参) COLUMN_VALUE)
语法树:
C:\antlr>pygrun PlSql sql_script --tree
SELECT A FROM TABLE(VAR) T;
^Z
(sql_script
(unit_statement
(data_manipulation_language_statements
(select_statement
(select_only_statement
(subquery
(subquery_basic_elements
(query_block SELECT
(selected_list
(select_list_elements
(expression
(logical_expression
(unary_logical_expression
(multiset_expression
(relational_expression
(compound_expression
(concatenation
(model_expression
(unary_expression
(atom
(general_element
(general_element_part
(id_expression
(regular_id A))))))))))))))))
(from_clause FROM
(table_ref_list
(table_ref
(table_ref_aux
(table_ref_aux_internal
(dml_table_expression_clause
(table_collection_expression TABLE (
(expression
(logical_expression
(unary_logical_expression
(multiset_expression
(relational_expression
(compound_expression
(concatenation
(model_expression
(unary_expression
(atom
(general_element
(general_element_part
(id_expression
(regular_id VAR)))))))))))))) ))))
(table_alias
(identifier
(id_expression
(regular_id T))))))))))))))) ; <EOF>)
C:\antlr>
以下为一个完整且可以运行的demo代码,用于说明如何进行语法改写
from antlr4 import FileStream, CommonTokenStream
from PlSqlLexer import PlSqlLexer
from PlSqlParser import PlSqlParser
from PlSqlParserVisitor import PlSqlParserVisitor
from antlr4.TokenStreamRewriter import TokenStreamRewriter
class SQLTransformer(PlSqlParserVisitor):
def __init__(self, token_stream):
super().__init__()
self.rewriter = TokenStreamRewriter(token_stream)
def visitDml_table_expression_clause(self, ctx: PlSqlParser.Dml_table_expression_clauseContext):
if ctx.table_collection_expression() and ctx.table_collection_expression().getText().startswith('table'):
argument = ctx.table_collection_expression().expression().getText()
new_text = "(select * from unnest({}) column_value)".format(argument)
self.rewriter.replace(TokenStreamRewriter.DEFAULT_PROGRAM_NAME, ctx.table_collection_expression().start.tokenIndex, ctx.table_collection_expression().stop.tokenIndex, new_text)
return self.visitChildren(ctx)
def main(input_file, output_file):
input_stream = FileStream(input_file, encoding='utf-8')
lexer = PlSqlLexer(input_stream)
stream = CommonTokenStream(lexer)
parser = PlSqlParser(stream)
tree = parser.sql_script()
transformer = SQLTransformer(stream)
transformer.visit(tree)
output_text = transformer.rewriter.getDefaultText()
with open(output_file, 'w', encoding='utf-8',newline='') as f:
f.write(output_text)
if __name__ == '__main__':
input_file = 'input.sql'
output_file = 'output.sql'
main(input_file, output_file)
其中def main
内的代码基本可以固定,我们直接看class SQLTransformer
里的def visitDml_table_expression_clause
。visitDml_table_expression_clause
其实是PlSqlParserVisitor.py
里面的一个def,这个def的名称由visit
加上节点名组成,也就是说,语法树中的每一个节点,都有一个对应的visit。而原本PlSqlParserVisitor.py里的每个visit里面都是空的,直接就return出去了:
def visitDml_table_expression_clause(self, ctx:PlSqlParser.Dml_table_expression_clauseContext):
return self.visitChildren(ctx)
我们自己写的这个visit就是实现了里面的具体内容。
如果我们需要修改某个语法,可以从语法树中,找到这个语法相关的上下文的最小节点,以本文前面输出的语法树为例,就应该是 (table_collection_expression TABLE (
,所以理论上,我们再写个visitTable_collection_expression
就好了,但本文的demo代码是从visitDml_table_expression_clause
开始,是为了说明如何引用当前节点的下级节点
if ctx.table_collection_expression() and ctx.table_collection_expression().getText().startswith('table')
这句是一个判断,作用是,判断当前节点下,是否存在table_collection_expression
这个节点,我们可以去对比语法树,如果没有使用table函数,是不会有这个节点的;第二个条件就是,获取table_collection_expression
这个节点的文本,判断它是不是使用 table
开始。这里要注意,如果不写第一个条件,在没有table_collection_expression
节点时,对它执行getText
会报错。
argument = ctx.table_collection_expression().expression().getText()
这里是取出table_collection_expression
的下一个叫expression
的节点,对照语法树可以看到,虽然里面有很多层,但这个节点实际只包含VAR
这个文本,因此这里就可以得到argument="VAR"
,即前文例子中,table函数的入参。
new_text = "(select * from unnest({}) column_value)".format(argument)
这一句很好理解,就是格式化一个字符串,把argument的值替换{}
,得到(select * from unnest(VAR) column_value)
self.rewriter.replace(TokenStreamRewriter.DEFAULT_PROGRAM_NAME, ctx.table_collection_expression().start.tokenIndex, ctx.table_collection_expression().stop.tokenIndex, new_text)
这一句就是最关键的,self.rewriter
是前面定义的TokenStreamRewriter(token_stream)
,在TokenStreamRewriter
里面,可以支持对节点的替换、删除、增加等操作。TokenStreamRewriter.replace
有4个入参,分别为程序名,开始位置、结束位置、需要替换成的文本。
程序名一般固定使用TokenStreamRewriter.DEFAULT_PROGRAM_NAME
就行,如果期望一次解析,就能做多种替换,比如同时生成支持PG和MYSQL的两种语法,就可以在这里设置程序名,针对不同的程序名写不同的规则。
开始位置和结束位置,可以使用对应节点的start.tokenIndex
和stop.tokenIndex
。
至此,我们就完成了一个改写规则的开发。
DEMO 2
如果需要在一次语法树解析中就完成多种规则的执行,可以再添加几个def visit
,比如我们再针对create type语句来进行改写。
ORACLE:
CREATE OR REPLACE TYPE TY_TEST AS OBJECT(COL1 INT,COL2 VARCHAR(20));
PG:
CREATE TYPE TY_TEST AS (COL1 INT,COL2 VARCHAR(20));
改写规则为,对于create_type
的语法节点,将create or replace
改为create
,并且删除object
语法树:
C:\antlr>pygrun PlSql sql_script --tree
CREATE OR REPLACE TYPE TY_TEST AS OBJECT(COL1 INT,COL2 VARCHAR(20));
^Z
(sql_script
(unit_statement
(create_type CREATE OR REPLACE TYPE
(type_definition
(type_name
(id_expression
(regular_id TY_TEST)))
(object_type_def
(object_as_part AS OBJECT) (
(object_member_spec
(identifier
(id_expression
(regular_id COL1)))
(type_spec
(datatype
(native_datatype_element INT)))) ,
(object_member_spec
(identifier
(id_expression
(regular_id COL2)))
(type_spec
(datatype
(native_datatype_element VARCHAR)
(precision_part (
(numeric 20) ))))) ))))) ; <EOF>)
改写代码
def visitCreate_type(self, ctx:PlSqlParser.Create_typeContext):
# 检查并修改 'create or replace' 为 'create'
if ctx.getChild(0).getText() == 'create' and ctx.getChild(1).getText() == 'or' and ctx.getChild(2).getText() == 'replace':
# print(f"Modifying: {ctx.getChild(0).getText()} {ctx.getChild(1).getText()} {ctx.getChild(2).getText()}")
self.rewriter.replace(
TokenStreamRewriter.DEFAULT_PROGRAM_NAME,
ctx.getChild(0).symbol.tokenIndex,
ctx.getChild(2).symbol.tokenIndex,
'create'
)
object_as_part = ctx.type_definition().object_type_def().object_as_part()
#print(f"ctx.object_as_part.getText(): {object_type_def.object_as_part().getChild(1).getText()}")
# 删除 object关键字
if object_as_part.getChild(1).getText()=='OBJECT':
self.rewriter.delete(
TokenStreamRewriter.DEFAULT_PROGRAM_NAME,
object_as_part.getChild(1).symbol.tokenIndex,
object_as_part.getChild(1).symbol.tokenIndex
)
return self.visitChildren(ctx)
这里可以从(create_type CREATE OR REPLACE TYPE
看到,CREATE OR REPLACE TYPE
这一串都在create_type
这个节点上,可以通过getChild(n)
来取出中间的每一部分,而这每一部分的位置,则是通过symbol.tokenIndex
获取。
然后从(object_as_part AS OBJECT) (
中可以看到,我们期望删除的object是在object_as_part
这个节点的第二个字符串,因此使用了getChild(1)
。这里需要注意,由于object_as_part
这个节点是在create_type
这个节点的下面很多层,然后后面会要多次使用这个节点,所以可以定义一个object_as_part = ctx.type_definition().object_type_def().object_as_part()
,减少冗余代码。
找到OBJECT后,执行self.rewriter.delete
,就可以把OBJECT删掉了。
DEMO 3
在ORACLE的sql脚本中,create type应该以/
结尾,但是在OG中,则不能有/
,如果一个脚本文件里混合了多种语句,就会出现有的/
要删,有的不能删,因此我们可以写一个规则,将所有create type语句后面的/
删掉。
def visitSql_script(self, ctx:PlSqlParser.Sql_scriptContext):
for i in range(ctx.getChildCount() - 1):
unit = ctx.getChild(i)
if isinstance(unit, PlSqlParser.Unit_statementContext):
create_type_stmt = unit.getChild(0)
# 对于create type语句
if isinstance(create_type_stmt, PlSqlParser.Create_typeContext):
# 遍历兄弟节点,找到 `/` 进行删除
for j in range(i + 1, ctx.getChildCount()):
sibling = ctx.getChild(j)
if sibling.getText() == '/':
#print(f"Deleting: {sibling.getText()}")
self.rewriter.delete(
TokenStreamRewriter.DEFAULT_PROGRAM_NAME,
sibling.symbol.tokenIndex,
sibling.symbol.tokenIndex
)
break
return self.visitChildren(ctx)
这里没有使用visitCreate_type的原因是,在antlr4生成的语法树中,/
这个节点并不在Create_type
这个节点的内部,所以得对Sql_script
这个根节点,找到所有的单条语句Unit_statement
,然后判断里面是不是有Create_type
,然后再回头删掉/