Parcourir la source

use druid to parse and validate the sql

sunyj il y a 8 ans
Parent
commit
bf04ccc836

+ 15 - 3
kanban-console/src/main/java/com/uas/kanban/service/impl/PanelServiceImpl.java

@@ -1,5 +1,9 @@
 package com.uas.kanban.service.impl;
 
+import com.alibaba.druid.sql.ast.SQLStatement;
+import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
+import com.alibaba.druid.sql.dialect.oracle.parser.OracleStatementParser;
+import com.alibaba.druid.sql.dialect.oracle.visitor.OracleOutputVisitor;
 import com.uas.kanban.annotation.NotEmpty;
 import com.uas.kanban.base.BaseService;
 import com.uas.kanban.dao.DataSourceDao;
@@ -97,9 +101,6 @@ public class PanelServiceImpl extends BaseService<Panel> implements PanelService
 
     @Override
     public List<Map<String, Object>> validateSQL(@NotEmpty("panelCode") String panelCode, @NotEmpty("sql") String sql, Boolean replaceParameters) throws SQLException, OperationException {
-        if (sql.toLowerCase().matches("([\\s]*?update|delete|insert)[\\s]+?[\\s\\S]+?")) {
-            throw new OperationException("不支持 update, delete, insert 操作");
-        }
         Panel panel = panelDao.checkExist(panelCode);
         // 如果需要替换 sql 中的参数
         if (replaceParameters != null && replaceParameters) {
@@ -120,6 +121,17 @@ public class PanelServiceImpl extends BaseService<Panel> implements PanelService
             }
             sql = kanbanParser.replaceParameters(sql, parameters, true);
         }
+
+        // 利用 druid 检测 SQL 语法
+        List<SQLStatement> sqlStatements = new OracleStatementParser(sql).parseStatementList();
+        for (SQLStatement sqlStatement : sqlStatements) {
+            if (!(sqlStatement instanceof SQLSelectStatement)) {
+                StringBuilder out = new StringBuilder();
+                sqlStatement.accept(new OracleOutputVisitor(out));
+                throw new SQLException("不支持的操作:" + out);
+            }
+        }
+
         NewbieJdbcSupport jdbc = dataSourceManager.getJdbc(panel.getDataSourceCode());
         kanbanParser.checkCount(jdbc.getDataSource().getConnection(), sql);
         return jdbc.listMap(sql);