Browse Source

限制 sql 查询结果不可超过 10000行

sunyj 8 years ago
parent
commit
b575551fe5

+ 2 - 2
kanban-console/src/main/java/com/uas/kanban/service/impl/KanbanInstanceServiceImpl.java

@@ -38,7 +38,7 @@ import com.uas.kanban.util.CollectionUtils;
 import com.uas.kanban.util.ObjectUtils;
 import com.uas.kanban.util.StringUtils;
 
-import me.chyxion.jdbc.NewbieJdbc;
+import me.chyxion.jdbc.NewbieJdbcSupport;
 
 /**
  * 看板实例
@@ -378,7 +378,7 @@ public class KanbanInstanceServiceImpl extends BaseService<KanbanInstance> imple
 		// 解析模版
 		String templateContent = null;
 		try {
-			NewbieJdbc jdbc = dataSourceManager.getJdbc(template.getDataSourceCode());
+			NewbieJdbcSupport jdbc = dataSourceManager.getJdbc(template.getDataSourceCode());
 			templateContent = templateParser.parseXml(content, title, jdbc);
 		} catch (DocumentException e) {
 			throw new IllegalStateException("xml 解析出错", e);

+ 1 - 2
kanban-console/src/main/java/com/uas/kanban/support/DataSourceManager.java

@@ -13,7 +13,6 @@ import com.uas.kanban.annotation.NotEmpty;
 import com.uas.kanban.base.BaseDao;
 import com.uas.kanban.model.DataSource;
 
-import me.chyxion.jdbc.NewbieJdbc;
 import me.chyxion.jdbc.NewbieJdbcSupport;
 
 /**
@@ -41,7 +40,7 @@ public class DataSourceManager {
 	 * @return NewbieJdbc 对象
 	 * @throws SQLException
 	 */
-	public NewbieJdbc getJdbc(@NotEmpty("dataSourceCode") String dataSourceCode) throws SQLException {
+	public NewbieJdbcSupport getJdbc(@NotEmpty("dataSourceCode") String dataSourceCode) throws SQLException {
 		return getCache(dataSourceCode).getJdbc();
 	}
 

+ 120 - 16
kanban-console/src/main/java/com/uas/kanban/support/TemplateParser.java

@@ -3,6 +3,10 @@ package com.uas.kanban.support;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.StringReader;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
 import java.util.ArrayList;
 import java.util.Date;
 import java.util.HashMap;
@@ -21,6 +25,8 @@ import org.dom4j.DocumentException;
 import org.dom4j.Element;
 import org.dom4j.io.SAXReader;
 import org.dom4j.tree.DefaultElement;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import org.springframework.stereotype.Component;
 
 import com.alibaba.druid.util.StringUtils;
@@ -29,7 +35,7 @@ import com.uas.kanban.model.GlobalParameter;
 import com.uas.kanban.util.ArrayUtils;
 import com.uas.kanban.util.CollectionUtils;
 
-import me.chyxion.jdbc.NewbieJdbc;
+import me.chyxion.jdbc.NewbieJdbcSupport;
 
 /**
  * @author sunyj
@@ -38,6 +44,13 @@ import me.chyxion.jdbc.NewbieJdbc;
 @Component
 public class TemplateParser {
 
+	/**
+	 * 查询条件的结果数目限制
+	 */
+	private static final int MAX_RECORD_SIZE = 10000;
+
+	private Logger logger = LoggerFactory.getLogger(getClass());
+
 	/**
 	 * 替换模版中的参数为实际值
 	 * 
@@ -92,14 +105,16 @@ public class TemplateParser {
 	 * @param title
 	 *            标题
 	 * @param jdbc
-	 *            NewbieJdbc对象
+	 *            NewbieJdbcSupport 对象
 	 * @return 解析后的 json 数据
 	 * @throws DocumentException
 	 * @throws IOException
 	 * @throws TransformerException
+	 * @throws SQLException
+	 * @throws IllegalStateException
 	 */
-	public String parseXml(@NotEmpty("content") String content, String title, @NotEmpty("jdbc") NewbieJdbc jdbc)
-			throws DocumentException, TransformerException, IOException {
+	public String parseXml(@NotEmpty("content") String content, String title, @NotEmpty("jdbc") NewbieJdbcSupport jdbc)
+			throws DocumentException, TransformerException, IOException, IllegalStateException, SQLException {
 		content = processSql(content);
 		content = processForm(content, jdbc);
 		content = processGrid(content, jdbc);
@@ -137,13 +152,15 @@ public class TemplateParser {
 	 * @param content
 	 *            模版内容
 	 * @param jdbc
-	 *            NewbieJdbc对象
+	 *            NewbieJdbcSupport 对象
 	 * @return 解析后的模版内容
 	 * @throws DocumentException
+	 * @throws SQLException
+	 * @throws IllegalStateException
 	 */
 	@SuppressWarnings("unchecked")
-	private String processForm(@NotEmpty("content") String content, @NotEmpty("jdbc") NewbieJdbc jdbc)
-			throws DocumentException {
+	private String processForm(@NotEmpty("content") String content, @NotEmpty("jdbc") NewbieJdbcSupport jdbc)
+			throws DocumentException, IllegalStateException, SQLException {
 		Document document = getDocument(content);
 		// 获取 form 组件
 		List<Element> elements = document.selectNodes("//form");
@@ -155,6 +172,7 @@ public class TemplateParser {
 			if (CollectionUtils.isEmpty(fieldElements)) {
 				continue;
 			}
+			checkCount(jdbc.getDataSource().getConnection(), sql);
 			Map<String, Object> map = jdbc.findMap(sql);
 			if (CollectionUtils.isEmpty(map)) {
 				continue;
@@ -175,13 +193,15 @@ public class TemplateParser {
 	 * @param content
 	 *            模版内容
 	 * @param jdbc
-	 *            NewbieJdbc对象
+	 *            NewbieJdbcSupport 对象
 	 * @return 解析后的模版内容
 	 * @throws DocumentException
+	 * @throws SQLException
+	 * @throws IllegalStateException
 	 */
 	@SuppressWarnings("unchecked")
-	private String processGrid(@NotEmpty("content") String content, @NotEmpty("jdbc") NewbieJdbc jdbc)
-			throws DocumentException {
+	private String processGrid(@NotEmpty("content") String content, @NotEmpty("jdbc") NewbieJdbcSupport jdbc)
+			throws DocumentException, IllegalStateException, SQLException {
 		Document document = getDocument(content);
 		// 获取 form 组件
 		List<Element> elements = document.selectNodes("//grid");
@@ -202,6 +222,7 @@ public class TemplateParser {
 					fieldNames.add(fieldName);
 				}
 			}
+			checkCount(jdbc.getDataSource().getConnection(), sql);
 			List<Map<String, Object>> listMap = jdbc.listMap(sql);
 			if (CollectionUtils.isEmpty(listMap)) {
 				continue;
@@ -227,19 +248,22 @@ public class TemplateParser {
 	 * @param content
 	 *            模版内容
 	 * @param jdbc
-	 *            NewbieJdbc对象
+	 *            NewbieJdbcSupport 对象
 	 * @return 解析后的模版内容
 	 * @throws DocumentException
+	 * @throws SQLException
+	 * @throws IllegalStateException
 	 */
 	@SuppressWarnings("unchecked")
-	private String processBarAndLine(@NotEmpty("content") String content, @NotEmpty("jdbc") NewbieJdbc jdbc)
-			throws DocumentException {
+	private String processBarAndLine(@NotEmpty("content") String content, @NotEmpty("jdbc") NewbieJdbcSupport jdbc)
+			throws DocumentException, IllegalStateException, SQLException {
 		Document document = getDocument(content);
 		// 获取 bar 和 line 组件
 		List<Element> elements = document.selectNodes("//bar | //line");
 		for (Element element : elements) {
 			// 执行 sql ,获取数据
 			String sql = element.attribute("sql").getText();
+			checkCount(jdbc.getDataSource().getConnection(), sql);
 			Map<String, List<Object>> map = convert(jdbc.listMap(sql));
 			if (CollectionUtils.isEmpty(map)) {
 				continue;
@@ -285,19 +309,22 @@ public class TemplateParser {
 	 * @param content
 	 *            模版内容
 	 * @param jdbc
-	 *            NewbieJdbc对象
+	 *            NewbieJdbcSupport 对象
 	 * @return 解析后的模版内容
 	 * @throws DocumentException
+	 * @throws SQLException
+	 * @throws IllegalStateException
 	 */
 	@SuppressWarnings("unchecked")
-	private String processPie(@NotEmpty("content") String content, @NotEmpty("jdbc") NewbieJdbc jdbc)
-			throws DocumentException {
+	private String processPie(@NotEmpty("content") String content, @NotEmpty("jdbc") NewbieJdbcSupport jdbc)
+			throws DocumentException, IllegalStateException, SQLException {
 		Document document = getDocument(content);
 		// 获取 bar 和 line 组件
 		List<Element> elements = document.selectNodes("//pie");
 		for (Element element : elements) {
 			// 执行 sql ,获取数据
 			String sql = element.attribute("sql").getText();
+			checkCount(jdbc.getDataSource().getConnection(), sql);
 			List<Map<String, Object>> listMap = jdbc.listMap(sql);
 			if (CollectionUtils.isEmpty(listMap)) {
 				continue;
@@ -328,6 +355,83 @@ public class TemplateParser {
 		return new SAXReader().read(new StringReader(xml));
 	}
 
+	/**
+	 * 检查当前条件下的结果数目是否超出限制
+	 * 
+	 * @param connection
+	 * @param sql
+	 * @throws SQLException
+	 * @throws IllegalStateException
+	 */
+	private void checkCount(@NotEmpty("connection") Connection connection, @NotEmpty("sql") String sql)
+			throws SQLException, IllegalStateException {
+		int count = getCount(connection, sql);
+		if (count > MAX_RECORD_SIZE) {
+			String message = "查询条件的结果数目超出限制:" + count;
+			throw new IllegalStateException(message, new IllegalStateException("sql : " + sql));
+		}
+	}
+
+	/**
+	 * 获取当前查询语句的结果数目
+	 * 
+	 * @param connection
+	 * @param sql
+	 * @return
+	 * @throws SQLException
+	 */
+	private int getCount(@NotEmpty("connection") Connection connection, @NotEmpty("sql") String sql)
+			throws SQLException {
+		PreparedStatement preparedStatement = null;
+		ResultSet resultSet = null;
+		try {
+			// 如果直接在 sql 外用 count(1) 统计数目,当关联表多时,可能会出现错误
+			// ORA-01792: 表或视图中的最大列数为 1000
+			// 报错主要发生在 select * 的情况下,但是不能这样判断,因为可能存在 select tt.*, pi_id from
+			// purchase t left join 这样的情况,很难区分
+			// 因此 1. 对于普通 sql ,将 select 后的字段改为 count(1)
+			// 2. 而最外层含有 group by 的 sql ,直接改为 count(1) 可能得到不止一行,结果也并非实际行数。再加上
+			// group
+			// by 的结果列数一般很小,所以可以在外面使用 count(1) ,一般不会超出 1000 行
+			String lowerSql = sql.toLowerCase();
+			if (!lowerSql.matches("[\\s\\S]+?group[\\s]+?by[\\s]+?[^)]+?")) {
+				String regex = "([\\s\\S]+?from)[\\s]+?[^,]+?";
+				Pattern pattern = Pattern.compile(regex);
+				Matcher matcher = pattern.matcher(lowerSql);
+				if (matcher.find()) {
+					int start = matcher.start(1);
+					int end = matcher.end(1);
+					sql = sql.substring(0, start) + "select count(1) from" + sql.substring(end);
+				} else {
+					throw new IllegalStateException("sql 解析错误:未发现第一个 from");
+				}
+			} else {
+				sql = "select count(1) from (" + sql + ")";
+			}
+			preparedStatement = connection.prepareStatement(sql);
+			resultSet = preparedStatement.executeQuery();
+			resultSet.next();
+			int count = resultSet.getInt(1);
+			return count;
+		} finally {
+			if (resultSet != null) {
+				try {
+					resultSet.close();
+				} catch (SQLException e) {
+					logger.error("", e);
+				}
+			}
+			if (preparedStatement != null) {
+				try {
+					preparedStatement.close();
+				} catch (SQLException e) {
+					logger.error("", e);
+				}
+			}
+			connection.close();
+		}
+	}
+
 	/**
 	 * 将 List<Map<String, Object>> 转为 Map<String, List<Object>>
 	 *