Hello World

吞风吻雨葬落日 欺山赶海踏雪径

0%

druid库的SQL解析功能

简介

SQL Parser是Druid的一个重要组成部分,Druid内置使用SQL Parser来实现防御SQL注入(WallFilter)、合并统计没有参数化的SQL(StatFilter的mergeSql)、SQL格式化、分库分表。

Druid SQL Parser的使用场景

  • MySql SQL全量统计
  • Hive/ODPS SQL执行安全审计
  • 分库分表SQL解析引擎
  • 数据库引擎的SQL Parser

Druid的sql parser是目前支持各种数据语法最完备的SQL Parser。目前对各种数据库的支持如下:

数据库 DML DDL
odps 完全支持 完全支持
mysql 完全支持 完全支持
postgresql 完全支持 完全支持
oracle 支持大部分 支持大部分
sql server 支持常用的 支持常用的ddl
db2 支持常用的 支持常用的ddl
hive 支持常用的 支持常用的ddl

Druid SQL Parser的代码结构

Druid SQL Parser分三个模块:

  • Parser
  • AST
  • Visitor

parser

parser是将输入文本转换为ast(抽象语法树),parser有包括两个部分,Parser和Lexer,其中Lexer实现词法分析,Parser实现语法分析。

AST

AST是Abstract Syntax Tree的缩写,也就是抽象语法树。AST是parser输出的结果。下面是获得抽象语法树的一个例子:

1
2
3
final String dbType = JdbcConstants.MYSQL; // 可以是ORACLE、POSTGRESQL、SQLSERVER、ODPS等
String sql = "select * from t";
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, dbType);

Visitor

Visitor是遍历AST的手段,是处理AST最方便的模式,Visitor是一个接口,有缺省什么都没做的实现VistorAdapter。

我们可以实现不同的Visitor来满足不同的需求,Druid内置提供了如下Visitor:

  • OutputVisitor用来把AST输出为字符串
  • WallVisitor 来分析SQL语意来防御SQL注入攻击
  • ParameterizedOutputVisitor用来合并未参数化的SQL进行统计
  • EvalVisitor 用来对SQL表达式求值
  • ExportParameterVisitor用来提取SQL中的变量参数
  • SchemaStatVisitor 用来统计SQL中使用的表、字段、过滤条件、排序表达式、分组表达式
  • SQL格式化 Druid内置了基于语义的SQL格式化功能

自定义Visitor

每种方言的Visitor都有一个缺省的VisitorAdapter,使得编写自定义的Visitor更方便。
https://github.com/alibaba/druid/wiki/SQL_Parser_Demo_visitor

方言

SQL-92、SQL-99等都是标准SQL,mysql/oracle/pg/sqlserver/odps等都是方言,也就是dialect。parser/ast/visitor都需要针对不同的方言进行特别处理。

SQL解析demo

场景: ODPS校验用户输入的SQL与密钥是否有效。
思路: 原ODPS的语句逻辑复杂,执行效率慢,所以转化成简单语句执行。因为ODPS的权限是到列的,所以需要保留语句中的列明。转换成 SELECT COUNT(*) FROM (SELECT [columns] FROM [table1] WHERE 1=2) UNION SELECT COUNT(*) FROM (SELECT [columns] FROM [table2] WHERE 1=2 ) UNION ... 这种模式执行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
public String convertSimplestSql(String sql){

try{
DbType dbType = JdbcConstants.ODPS;
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, dbType);
SQLStatement stmt = stmtList.get(0);
SchemaStatVisitor statVisitor = SQLUtils.createSchemaStatVisitor(dbType);
stmt.accept(statVisitor);

List<String> tables = statVisitor.getTables().keySet().stream()
.map(TableStat.Name::getName)
.collect(Collectors.toList());

if(CollectionUtils.isEmpty(tables)){
return sql;
}

//获取表对应的所有列
Map<String,List<String>> tableColumnMap = new HashMap<>();

Collection<TableStat.Column> columns = statVisitor.getColumns();
for (TableStat.Column column : columns){
String table = column.getTable();
List<String> cols = tableColumnMap.get(table);
if(cols == null){
cols = new ArrayList<>();
}
cols.add(column.getName());
tableColumnMap.put(column.getTable(),cols);
}

//check *
for (Map.Entry<String, List<String>> entry : tableColumnMap.entrySet()){
if(entry.getValue().contains("*")){
entry.getValue().remove("*");
}
}

if(tableColumnMap.isEmpty()){
return sql;
}
StringBuilder newSql = new StringBuilder();
for (int i =0 ;i<tables.size();i++){
String table = tables.get(i);

List<String> cols = tableColumnMap.get(table);
if(CollectionUtils.isEmpty(cols)){
cols.add("*");
}

newSql.append(buildCountSql(table,cols));
if(i < tables.size() - 1){
newSql.append(" UNION ");
}
}
newSql.append(";");
return newSql.toString();
}catch (Exception e){
//ignore
logger.error("parse sql exception : sql = " + sql , e);
}
return sql;
}

private String buildCountSql(String table,List<String> columns){
StringBuilder sb = new StringBuilder("SELECT COUNT(*) FROM ( SELECT ");
sb.append(columns.stream().collect(Collectors.joining(",")));
sb.append(" FROM ").append(table);
sb.append(" WHERE 1=2 ) ");
return sb.toString();
}

官网

https://github.com/alibaba/druid/wiki