Java 多線程事務回滾 ——多線程插入數據庫時事務控制()


背景

        日常項目中,經常會出現一個場景,同時批量插入數據庫數據,由於邏輯復雜或者其它原因,我們無法使用sql進行批量插入。串行效率低,耗時長,為了提高效率,這個時候我們首先想到多線程並發插入,但是如何控制事務呢 … 直接上干貨

實現效果

  •  開啟多條子線程,並發插入數據庫

  •  當其中一條線程出現異常,或者處理結果為非預期結果,則全部線程均回滾

代碼實現

@Service
public class CompanyUserBatchServiceImpl implements CompanyUserBatchService {
    private static final Logger logger = LoggerFactory.getLogger(CompanyUserBatchServiceImpl.class);

    @Autowired
    private CompanyUserService companyUserService;

    @Override
    public ReturnData addNewCurrentCompanyUsers(String params) {
        logger.info("addNewCompanyUsers 新增參保人方法");
        logger.info(">>>>>>>>>>>>參數:{}", params);
        ReturnData rd = new ReturnData();
        rd.setRetCode(CommonConstants.RETURN_CODE_FAIL);
        if (StringUtils.isBlank(params)) {
            rd.setMsg("入參為空!");
            logger.info(">>>>>>入參為空。");
            return rd;
        }

        List<CompanyUserResultVo> companyUsers;
        try {
            companyUsers = JSONObject.parseArray(params, CompanyUserResultVo.class);
        } catch (Exception e) {
            logger.info(">>>>>>>>>入參格式有誤: {}", e);
            rd.setMsg("入參格式有誤!");
            return rd;
        }


        //每條線程最小處理任務數
        int perThreadHandleCount = 1;
        //線程池的最大線程數
        int nThreads = 10;
        int taskSize = companyUsers.size();

        if (taskSize > nThreads * perThreadHandleCount) {
            perThreadHandleCount = taskSize % nThreads == 0 ? taskSize / nThreads : taskSize / nThreads + 1;
            nThreads = taskSize % perThreadHandleCount == 0 ? taskSize / perThreadHandleCount : taskSize / perThreadHandleCount + 1;
        } else {
            nThreads = taskSize;
        }

        logger.info("批量添加參保人taskSize: {}, perThreadHandleCount: {}, nThreads: {}", taskSize, perThreadHandleCount, nThreads);
        CountDownLatch mainLatch = new CountDownLatch(1);
        //監控子線程
        CountDownLatch threadLatch = new CountDownLatch(nThreads);
        //根據子線程執行結果判斷是否需要回滾
        BlockingDeque<Boolean> resultList = new LinkedBlockingDeque<>(nThreads);
        //必須要使用對象,如果使用變量會造成線程之間不可共享變量值
        RollBack rollBack = new RollBack(false);
        ExecutorService fixedThreadPool = Executors.newFixedThreadPool(nThreads);

        List<Future<List<Object>>> futures = Lists.newArrayList();
        List<Object> returnDataList = Lists.newArrayList();
        //給每個線程分配任務
        for (int i = 0; i < nThreads; i++) {
            int lastIndex = (i + 1) * perThreadHandleCount;
            List<CompanyUserResultVo> companyUserResultVos = companyUsers.subList(i * perThreadHandleCount, lastIndex >= taskSize ? taskSize : lastIndex);
            AddNewCompanyUserThread addNewCompanyUserThread = new AddNewCompanyUserThread(mainLatch, threadLatch, rollBack, resultList, companyUserResultVos);
            Future<List<Object>> future = fixedThreadPool.submit(addNewCompanyUserThread);
            futures.add(future);
        }

        /** 存放子線程返回結果. */
        List<Boolean> backUpResult = Lists.newArrayList();
        try {
            //等待所有子線程執行完畢
            boolean await = threadLatch.await(20, TimeUnit.SECONDS);
            //如果超時,直接回滾
            if (!await) {
                rollBack.setRollBack(true);
            } else {
                logger.info("創建參保人子線程執行完畢,共 {} 個線程", nThreads);
                //查看執行情況,如果有存在需要回滾的線程,則全部回滾
                for (int i = 0; i < nThreads; i++) {
                    Boolean result = resultList.take();
                    backUpResult.add(result);
                    logger.debug("子線程返回結果result: {}", result);
                    if (result) {
                        /** 有線程執行異常,需要回滾子線程. */
                        rollBack.setRollBack(true);
                    }
                }
            }
        } catch (InterruptedException e) {
            logger.error("等待所有子線程執行完畢時,出現異常");
            throw new SystemException("等待所有子線程執行完畢時,出現異常,整體回滾");
        } finally {
            //子線程再次開始執行
            mainLatch.countDown();
            logger.info("關閉線程池,釋放資源");
            fixedThreadPool.shutdown();
        }

        /** 檢查子線程是否有異常,有異常整體回滾. */
        for (int i = 0; i < nThreads; i++) {
            if (CollectionUtils.isNotEmpty(backUpResult)) {
                Boolean result = backUpResult.get(i);
                if (result) {
                    logger.info("創建參保人失敗,整體回滾");
                    throw new SystemException("創建參保人失敗");
                }
            } else {
                logger.info("創建參保人失敗,整體回滾");
                throw new SystemException("創建參保人失敗");
            }
        }

        //拼接結果
        try {
            for (Future<List<Object>> future : futures) {
                returnDataList.addAll(future.get());
            }
        } catch (Exception e) {
            logger.info("獲取子線程操作結果出現異常,創建的參保人列表: {} ,異常信息: {}", JSONObject.toJSONString(companyUsers), e);
            throw new SystemException("創建參保人子線程正常創建參保人成功,主線程出現異常,回滾失敗");
        }

        rd.setRetCode(CommonConstants.RETURN_CODE_SUCCESS);
        rd.setData(returnDataList);
        return rd;
    }

    public class AddNewCompanyUserThread implements Callable<List<Object>> {
        /**
         * 主線程監控
         */
        private CountDownLatch mainLatch;
        /**
         * 子線程監控
         */
        private CountDownLatch threadLatch;
        /**
         * 是否回滾
         */
        private RollBack rollBack;
        private BlockingDeque<Boolean> resultList;
        private List<CompanyUserResultVo> taskList;

        public AddNewCompanyUserThread(CountDownLatch mainLatch, CountDownLatch threadLatch, RollBack rollBack, BlockingDeque<Boolean> resultList, List<CompanyUserResultVo> taskList) {
            this.mainLatch = mainLatch;
            this.threadLatch = threadLatch;
            this.rollBack = rollBack;
            this.resultList = resultList;
            this.taskList = taskList;
        }

        @Override
        public List<Object> call() {
            //為了保證事務不提交,此處只能調用一個有事務的方法,spring 中事務的顆粒度是方法,只有方法不退出,事務才不會提交
            return companyUserService.addNewCompanyUsers(mainLatch, threadLatch, rollBack, resultList, taskList);
        }

    }

    public class RollBack {
        private Boolean isRollBack;

        public Boolean getRollBack() {
            return isRollBack;
        }

        public void setRollBack(Boolean rollBack) {
            isRollBack = rollBack;
        }

        public RollBack(Boolean isRollBack) {
            this.isRollBack = isRollBack;
        }
    }

public List<Object> addNewCompanyUsers(CountDownLatch mainLatch, CountDownLatch threadLatch, CompanyUserBatchServiceImpl.RollBack rollBack, BlockingDeque<Boolean> resultList, List<CompanyUserResultVo> taskList) {
        List<Object> returnDataList = Lists.newArrayList();
        Boolean result = false;
        logger.debug("線程: {}創建參保人條數 : {}", Thread.currentThread().getName(), taskList.size());
        try {
            for (CompanyUserResultVo companyUserResultVo : taskList) {
                ReturnData returnData = addSingleCompanyUser(companyUserResultVo);
                if (returnData.getRetCode() == CommonConstants.RETURN_CODE_FAIL) {
                    result = true;
                }
                returnDataList.add(returnData.getData());
            }
            //Exception 和 Error 都需要抓
        } catch (Throwable throwable) {
            throwable.printStackTrace();
            logger.info("線程: {}創建參保人出現異常: {} ", Thread.currentThread().getName(), throwable);
            result = true;
        }

        resultList.add(result);
        threadLatch.countDown();
        logger.info("子線程 {} 計算過程已經結束,等待主線程通知是否需要回滾", Thread.currentThread().getName());

        try {
            mainLatch.await();
            logger.info("子線程 {} 再次啟動", Thread.currentThread().getName());
        } catch (InterruptedException e) {
            logger.error("批量創建參保人線程InterruptedException異常");
            throw new SystemException("批量創建參保人線程InterruptedException異常");
        }

        if (rollBack.getRollBack()) {
            logger.error("批量創建參保人線程回滾, 線程: {}, 需要更新的信息taskList: {}",
                    Thread.currentThread().getName(),
                    JSONObject.toJSONString(taskList));
            logger.info("子線程 {} 執行完畢,線程退出", Thread.currentThread().getName());
            throw new SystemException("批量創建參保人線程回滾");
        }

        logger.info("子線程 {} 執行完畢,線程退出", Thread.currentThread().getName());
        return returnDataList;
    }

思想就是使用兩個CountDownWatch實現子線程的二段提交

步驟:

     1、主線程將任務分發給子線程,然后使用 boolean await = threadLatch.await(20,TimeUnit.SECONDS);阻塞主線程,等待所有子線程處理向數據庫中插入的業務

    2、使用threadLatch.countDown();釋放子線程鎖定,同時使用mainLatch.await();阻塞子線程,將程序的控制權交還給主線程。

   3、主線程檢查子線程執行插入數據庫的結果,若有非預期結果出現,主線程標記狀態告知子線程回滾,然后使用mainLatch.countDown();將程序控制權再次交給子線程,子線程檢測回滾標志,判斷是否回滾。

   4、子線程執行結束,主線程拼接處理結果,響應給請求方


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM