sunyj 8 лет назад
Родитель
Сommit
190a332c16

+ 18 - 104
src/main/java/com/uas/search/service/impl/IndexServiceImpl.java

@@ -14,6 +14,8 @@ import com.uas.search.jms.JmsListener;
 import com.uas.search.jms.QueueMessageParser;
 import com.uas.search.model.*;
 import com.uas.search.service.IndexService;
+import com.uas.search.support.DownloadHelper;
+import com.uas.search.support.DownloadService;
 import com.uas.search.support.IndexSearcherManager;
 import com.uas.search.support.IndexWriterManager;
 import com.uas.search.util.FileUtils;
@@ -95,7 +97,7 @@ public class IndexServiceImpl implements IndexService {
 	/**
 	 * 从数据库获取数据时的分页大小
 	 */
-	private static final int PAGE_SIZE = 1000;
+	public static final int PAGE_SIZE = 1000;
 
 	/**
 	 * 单个文件存储的最大数据数目,需是PAGE_SIZE的整数倍
@@ -490,111 +492,23 @@ public class IndexServiceImpl implements IndexService {
         if (threads > druidDBConfiguration.getMaxActive()) {
             throw new IllegalArgumentException("线程数量不可超过 " + druidDBConfiguration.getMaxActive());
         }
-        startFileIndex = startFileIndex == null || startFileIndex < 1 ? 1 : startFileIndex;
-        if (startFileIndex == 1 && endFileIndex == null) {
-            // 删除旧的文件
-            FileUtils.deleteSubFiles(new File(SearchUtils.getDataPath(tableName)));
-        }
-        endFileIndex = endFileIndex == null || endFileIndex < 1 ? 1024 * 1024 * 1024 : endFileIndex;
-        for (int i = 1; i <= threads; i++) {
-            if (tableName.equals(COMPONENT_TABLE_NAME)) {
-                new Thread(new DownloadComponentTread(i, threads, startFileIndex + i - 1, endFileIndex)).start();
-            } else if (tableName.equals(GOODS_TABLE_NAME)) {
-                new Thread(new DownloadGoodsTread(i, threads, startFileIndex + i - 1, endFileIndex)).start();
-            } else {
-                throw new IllegalArgumentException("多线程下载不支持该表:" + tableName);
-            }
-        }
-    }
-
-    /**
-     * 下载器件的线程
-     */
-    private class DownloadComponentTread implements Runnable {
-
-        /**
-         * 线程名称
-         */
-        private int id;
-
-        /**
-         * 新增文件时,文件 id 的自增步长,(即线程数量)
-         */
-        private int step;
-
-        /**
-         * 开始的文件
-         */
-        private int startFileIndex;
-
-        /**
-         * 开始的文件
-         */
-        private int endFileIndex;
-
-        public DownloadComponentTread(int id, int step, int startFileIndex, int endFileIndex) {
-            this.id = id;
-            this.step = step;
-            this.startFileIndex = startFileIndex;
-            this.endFileIndex = endFileIndex;
-        }
-
-        @Override
-        public void run() {
-			String name = "Thread-" + id;
-            try {
-				if (endFileIndex < startFileIndex) {
-                    logger.error(name + " fileIndex 不可超过 : endFileIndex=" + endFileIndex);
-                    return;
-                }
-                Long startTime = new Date().getTime();
-                logger.info(name + " 下载器件... ");
-
-                Sort sort = new Sort(Sort.Direction.ASC, "id");
-                // 分页获取数据
-                PageParams pageParams = new PageParams();
-                pageParams.setPage(startFileIndex);
-                pageParams.setSize(PAGE_SIZE);
-                PageInfo pageInfo = new PageInfo(pageParams, sort);
-                Page<Component> pageResult = componentDao.findAll(pageInfo);
-				if(id == 1){
-					logger.info(name + " 发现数据 " + pageResult.getTotalElements() + " 条");
-				}
-
-                int totalPages = pageResult.getTotalPages();
-                if (totalPages < startFileIndex) {
-                    logger.error(name + " fileIndex 不可超过 : totalPages=" + totalPages);
-                    return;
-                }
-                // 已翻页的数据数目
-                Long size = 0L;
-                String goodsDataPath = SearchUtils.getDataPath(COMPONENT_TABLE_NAME);
-                File file = new File(goodsDataPath);
-                if (!file.exists()) {
-                    file.mkdirs();
-                }
-                while (totalPages >= startFileIndex && endFileIndex >= startFileIndex) {
-                    String componentFileName = String.format("%010d", startFileIndex) + ".txt";
-                    PrintWriter printWriter = new PrintWriter(goodsDataPath + "/" + componentFileName);
-                    List<Component> content = pageResult.getContent();
-                    for (Component element : content) {
-                        printWriter.println(JSONObject.toJSONString(element));
-                    }
-                    size += content.size();
-                    logger.info(name + " " + componentFileName + " - Downloaded..................." + size);
 
-                    printWriter.flush();
-                    printWriter.close();
-                    startFileIndex += step;
-                    pageParams.setPage(startFileIndex);
-                    pageInfo = new PageInfo(pageParams, sort);
-                    pageResult = componentDao.findAll(pageInfo);
-                }
-
-                logger.info(String.format("%s 下载完成,耗时%.2fs\n ", name, (new Date().getTime() - startTime) / 1000.0));
-            } catch (Throwable e) {
-                logger.error(name + " 器件下载失败", e);
+        if (tableName.equals(COMPONENT_TABLE_NAME)) {
+            DownloadHelper<Component> downloadHelper = new DownloadHelper<>(threads, startFileIndex, endFileIndex, tableName, "id",componentDao, new DownloadService<Component>());
+            long result = downloadHelper.getResult();
+            logger.info("totalSize = " + result);
+        } else if (tableName.equals(GOODS_TABLE_NAME)) {
+            startFileIndex = startFileIndex == null || startFileIndex < 1 ? 1 : startFileIndex;
+            if (startFileIndex == 1 && endFileIndex == null) {
+                // 删除旧的文件
+                FileUtils.deleteSubFiles(new File(SearchUtils.getDataPath(tableName)));
             }
+            endFileIndex = endFileIndex == null || endFileIndex < 1 ? 1024 * 1024 * 1024 : endFileIndex;
+            for (int i = 1; i <= threads; i++) {
+                new Thread(new DownloadGoodsTread(i, threads, startFileIndex + i - 1, endFileIndex)).start();
+            }
+        } else {
+            throw new IllegalArgumentException("多线程下载不支持该表:" + tableName);
         }
     }
 

+ 186 - 0
src/main/java/com/uas/search/support/DownloadHelper.java

@@ -0,0 +1,186 @@
+package com.uas.search.support;
+
+import com.uas.search.util.FileUtils;
+import com.uas.search.util.SearchUtils;
+import org.springframework.data.jpa.repository.JpaRepository;
+
+import java.io.File;
+import java.util.concurrent.*;
+
+
+/**
+ * 下载数据的辅助类
+ *
+ * @author sunyj
+ * @since 2017/11/25 17:30
+ */
+public class DownloadHelper<T> {
+
+    /**
+     * 线程最小数量
+     */
+    private final int MIN_THREAD_SIZE = 1;
+
+    /**
+     * 线程最大数量
+     */
+    private final int MAX_THREAD_SIZE = 10000;
+
+    /**
+     * 默认开始的文件
+     */
+    private final int DEFAULT_START_FILE_INDEX = 1;
+
+    /**
+     * 默认结束的文件
+     */
+    private final int DEFAULT_END_FILE_INDEX = 1024 * 1024 * 1024;
+
+    /**
+     * 线程数量
+     */
+    private Integer threadSize;
+
+    /**
+     * 开始的文件
+     */
+    private Integer startFileIndex;
+
+    /**
+     * 结束的文件
+     */
+    private Integer endFileIndex;
+
+    /**
+     * 要下载的表
+     */
+    private String tableName;
+
+    /**
+     * 排序字段
+     */
+    private String sortField;
+
+    /**
+     * dao
+     */
+    private JpaRepository<T, Long> dao;
+
+    /**
+     * 下载的实现
+     */
+    private DownloadService<T> downloadService;
+
+    /**
+     * 线程管理
+     */
+    private ExecutorService executorService;
+
+    /**
+     * 收集执行结果
+     */
+    private CompletionService<Long> completionService;
+
+    /**
+     * 执行结果
+     */
+    private Long result;
+
+    /**
+     * @param threadSize      线程数量
+     * @param startFileIndex  开始的文件
+     * @param endFileIndex    结束的文件
+     * @param tableName       要下载的表
+     * @param sortField       排序字段
+     * @param dao             dao
+     * @param downloadService 下载的实现
+     */
+    public DownloadHelper(Integer threadSize, Integer startFileIndex, Integer endFileIndex, String tableName, String sortField, JpaRepository<T, Long> dao, DownloadService<T> downloadService) {
+        if (threadSize == null || threadSize < MIN_THREAD_SIZE || threadSize > MAX_THREAD_SIZE) {
+            throw new IllegalArgumentException("threadSize is between " + MIN_THREAD_SIZE + " and " + MAX_THREAD_SIZE);
+        }
+        if (downloadService == null) {
+            throw new IllegalArgumentException("runnable is null");
+        }
+        this.threadSize = threadSize;
+        this.downloadService = downloadService;
+        this.startFileIndex = startFileIndex == null || startFileIndex < DEFAULT_START_FILE_INDEX ? DEFAULT_START_FILE_INDEX : startFileIndex;
+        this.endFileIndex = endFileIndex == null || endFileIndex < DEFAULT_START_FILE_INDEX ? DEFAULT_END_FILE_INDEX : endFileIndex;
+        this.tableName = tableName;
+        this.sortField = sortField;
+        this.dao = dao;
+        start();
+    }
+
+
+    /**
+     * 开始下载
+     */
+    private void start() {
+        executorService = Executors.newCachedThreadPool();
+        completionService = new ExecutorCompletionService<>(executorService);
+        if (startFileIndex == DEFAULT_START_FILE_INDEX && endFileIndex == DEFAULT_END_FILE_INDEX) {
+            // 删除旧的文件
+            FileUtils.deleteSubFiles(new File(SearchUtils.getDataPath(tableName)));
+        }
+        for (int i = 0; i < threadSize; i++) {
+            completionService.submit(getTask(i, threadSize, startFileIndex + i, endFileIndex, tableName, sortField, dao));
+        }
+        waitResult();
+    }
+
+
+    /**
+     * 获取任务
+     *
+     * @param id             线程 id
+     * @param step           新增文件时,文件 id 的自增步长,(即线程数量)
+     * @param startFileIndex 开始的文件
+     * @param endFileIndex   结束的文件
+     * @param tableName      要下载的表
+     * @param sortField      排序字段
+     * @param dao            dao
+     * @return 任务
+     */
+    private Callable<Long> getTask(final int id, final int step, final int startFileIndex, final int endFileIndex, final String tableName, final String sortField, final JpaRepository<T, Long> dao) {
+        return new Callable<Long>() {
+            @Override
+            public Long call() throws Exception {
+                return downloadService.download(id, step, startFileIndex, endFileIndex, tableName, sortField, dao);
+            }
+        };
+    }
+
+    /**
+     * 等待执行结果
+     */
+    private void waitResult() {
+        if (executorService.isShutdown() || executorService.isTerminated()) {
+            throw new IllegalStateException("结果已返回,不可再次获取");
+        }
+        result = 0L;
+        for (int i = 0; i < threadSize; i++) {
+            try {
+                Future<Long> future = completionService.take();
+                Long count = future.get();
+                result += (count == null ? 0L : count);
+            } catch (InterruptedException | ExecutionException e) {
+                throw new IllegalStateException("获取下载结果失败", e);
+            }
+        }
+        executorService.shutdown();
+    }
+
+    /**
+     * 获取执行结果
+     *
+     * @return 下载的总数量
+     */
+    public long getResult() {
+        if (result == null) {
+            waitResult();
+        }
+        return result;
+    }
+
+}

+ 98 - 0
src/main/java/com/uas/search/support/DownloadService.java

@@ -0,0 +1,98 @@
+package com.uas.search.support;
+
+import com.alibaba.fastjson.JSONObject;
+import com.uas.search.constant.model.PageInfo;
+import com.uas.search.constant.model.PageParams;
+import com.uas.search.service.impl.IndexServiceImpl;
+import com.uas.search.util.SearchUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.data.domain.Page;
+import org.springframework.data.domain.Sort;
+import org.springframework.data.jpa.repository.JpaRepository;
+
+import java.io.File;
+import java.io.PrintWriter;
+import java.util.Date;
+import java.util.List;
+
+/**
+ * 下载线程
+ *
+ * @author sunyj
+ * @since 2017/11/25 18:24
+ */
+public class DownloadService<T> {
+
+    /**
+     * @param id             线程 id
+     * @param step           新增文件时,文件 id 的自增步长,(即线程数量)
+     * @param startFileIndex 开始的文件
+     * @param endFileIndex   结束的文件
+     * @param tableName      要下载的表
+     * @param sortField      排序字段
+     * @param dao            dao
+     * @return 下载的数量
+     */
+    public long download(int id, int step, int startFileIndex, int endFileIndex, String tableName, String sortField, JpaRepository<T, Long> dao) {
+        long size = 0L;
+        Logger logger = LoggerFactory.getLogger(getClass());
+        String name = "Thread-" + id;
+        try {
+            if (endFileIndex < startFileIndex) {
+                logger.error(name + " fileIndex 不可超过 : endFileIndex=" + endFileIndex);
+                return size;
+            }
+            Long startTime = new Date().getTime();
+            logger.info(name + " 下载" + tableName + "... ");
+
+            Sort sort = new Sort(Sort.Direction.ASC, sortField);
+            // 分页获取数据
+            PageParams pageParams = new PageParams();
+            pageParams.setPage(startFileIndex);
+            pageParams.setSize(IndexServiceImpl.PAGE_SIZE);
+            PageInfo pageInfo = new PageInfo(pageParams, sort);
+            Page<T> pageResult = dao.findAll(pageInfo);
+            if (id == 1) {
+                logger.info(name + " 发现数据 " + pageResult.getTotalElements() + " 条");
+            }
+
+            int totalPages = pageResult.getTotalPages();
+            if (totalPages < startFileIndex) {
+                logger.error(name + " fileIndex 不可超过 : totalPages=" + totalPages);
+                return size;
+            }
+            String dataPath = SearchUtils.getDataPath(tableName);
+            File dataDir = new File(dataPath);
+            if (!dataDir.exists()) {
+                dataDir.mkdirs();
+            }
+            while (totalPages >= startFileIndex && endFileIndex >= startFileIndex) {
+                if (Math.random() > 0.5) {
+                    throw new IllegalStateException("随机错误");
+                }
+                String fileName = String.format("%010d", startFileIndex) + ".txt";
+                PrintWriter printWriter = new PrintWriter(dataPath + "/" + fileName);
+                List<T> content = pageResult.getContent();
+                for (T element : content) {
+                    printWriter.println(JSONObject.toJSONString(element));
+                }
+                size += content.size();
+                logger.info(name + " " + fileName + " - Downloaded..................." + size);
+
+                printWriter.flush();
+                printWriter.close();
+                startFileIndex += step;
+                pageParams.setPage(startFileIndex);
+                pageInfo = new PageInfo(pageParams, sort);
+                pageResult = dao.findAll(pageInfo);
+            }
+
+            logger.info(String.format("%s 下载完成,耗时%.2fs\n ", name, (new Date().getTime() - startTime) / 1000.0));
+        } catch (Throwable e) {
+            logger.error(name + " " + tableName + "下载失败", e);
+        }
+        return size;
+    }
+
+}