在平时的项目中,我们或多或少会遇到批量写入大量数据的场景,一般能想到的提速方式就是使用多线程。比如要插入10w条数据,就创建10个线程,每个线程负责插入1w条。理想情况下,这比单线程快接近10倍(实际提升会受数据库连接池大小、锁竞争等因素限制)。我曾经试图封装这样一个工具类,也参考了网上很多封装的例子,但最后还是失败了。
最近在群里看到一位大佬的帖子,发现他也封装了这种批量提交的工具类。我体验了一下,速度很快,所以想拿出来和大家分享一下,在此也十分感谢茶佬的支持。
添加必要依赖
因为工具类中用到了hutool的异常工具类`ExceptionUtil`,所以需要添加hutool依赖;如果你的项目中没有且无法引入hutool,把`ExceptionUtil.wrapAndThrow(e)`改成手动抛出异常即可。
<!--
  Dependencies used by this demo. The Spring Boot starters are listed without <version>;
  NOTE(review): that assumes the project inherits from spring-boot-starter-parent (or imports
  the Spring Boot BOM) so their versions are managed - confirm against your pom.
-->
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>druid-spring-boot-starter</artifactId>
        <version>1.1.10</version>
    </dependency>
    <dependency>
        <groupId>org.mybatis.spring.boot</groupId>
        <artifactId>mybatis-spring-boot-starter</artifactId>
        <version>2.1.2</version>
    </dependency>
    <dependency>
        <groupId>mysql</groupId>
        <artifactId>mysql-connector-java</artifactId>
        <version>5.1.47</version>
        <scope>runtime</scope>
    </dependency>
    <!-- Provides ExceptionUtil, used by the original BatchInsertProcessor listing. -->
    <dependency>
        <groupId>cn.hutool</groupId>
        <artifactId>hutool-all</artifactId>
        <version>5.7.16</version>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
</dependencies>
工具类
import java.lang.reflect.UndeclaredThrowableException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Multi-threaded producer/consumer helper for bulk inserts.
 *
 * <p>A producer enqueues entities with {@link #put(Object)} into a bounded queue. A fixed
 * number of worker threads drain the queue and hand the entities to a {@link StorageConsumer}.
 * They deliver either one entity at a time or batches of up to {@code maxItemCount}.
 *
 * <p>When production is finished, call {@link #await()}. It waits for the workers and flushes
 * any partial batches. Alternatively, call {@link #stop()} to abandon processing and discard
 * queued data.
 *
 * <p>An instance can be started only once.
 *
 * @param <T> entity type
 */
public class BatchInsertProcessor<T> {

    /**
     * How long an idle worker, or a producer blocked on a full queue, waits before re-checking
     * whether it should give up. This bounds the extra latency of {@link #await()}.
     */
    private static final long POLL_TIMEOUT_MILLIS = 100;

    private String threadNamePrefix = "reportInsert-";
    // Whether start() has been called; instances are single-use.
    private volatile boolean started;
    // Counted down once by every worker when it exits, normally or by failure.
    // await() waits on it.
    private volatile CountDownLatch cdl;
    // Whether more data may still be produced. Workers exit once this is false
    // and the queue is empty.
    private volatile boolean isProduceData = true;
    // Bounded hand-off queue. While it is full, put() blocks, which throttles the producer.
    private final ArrayBlockingQueue<T> queue;
    // First exception thrown by any worker; reported by await() and put().
    private final AtomicReference<Throwable> failure = new AtomicReference<>();
    // Receives single entities or batches from the workers.
    private StorageConsumer<T> consumer;
    // Maximum batch size. A value <= 0 means entities are delivered one at a time.
    private int maxItemCount;
    private List<WorkThread> workThreadList;

    /** Creates a processor whose queue holds at most 1000 pending entities. */
    public BatchInsertProcessor() {
        this(1000);
    }

    /**
     * @param capacity maximum number of pending entities; {@link #put(Object)} blocks while full
     */
    public BatchInsertProcessor(int capacity) {
        queue = new ArrayBlockingQueue<>(capacity);
    }

    /**
     * Sets the prefix used to name worker threads; must be called before {@code start}.
     *
     * @throws RuntimeException if processing has already started
     */
    public synchronized void setThreadNamePrefix(String threadNamePrefix) {
        if (started) {
            throw new RuntimeException("已经开始处理,不能再设置线程名称前缀");
        }
        this.threadNamePrefix = threadNamePrefix;
    }

    /** Starts 4 workers in single-item mode. */
    public void start(StorageConsumer<T> consumer) {
        this.start(consumer, 4);
    }

    /** Starts {@code workThreadCount} workers in single-item mode. */
    public void start(StorageConsumer<T> consumer, int workThreadCount) {
        this.start(consumer, workThreadCount, 0);
    }

    /**
     * Starts the worker threads.
     *
     * @param consumer        callback that persists entities
     * @param workThreadCount number of worker threads, must be positive
     * @param maxItemCount    batch size; if {@code <= 0}, entities are delivered one at a time
     * @throws RuntimeException         if already started
     * @throws IllegalArgumentException if {@code workThreadCount <= 0}
     * @throws NullPointerException     if {@code consumer} is null
     */
    public synchronized void start(StorageConsumer<T> consumer, int workThreadCount, int maxItemCount) {
        if (started) {
            throw new RuntimeException("处理中");
        }
        Objects.requireNonNull(consumer, "consumer");
        if (workThreadCount <= 0) {
            // With no workers, put() would block forever once the queue filled up.
            throw new IllegalArgumentException("工作线程数必须大于0: " + workThreadCount);
        }
        started = true;
        this.consumer = consumer;
        this.maxItemCount = maxItemCount;
        this.cdl = new CountDownLatch(workThreadCount);
        workThreadList = IntStream.range(0, workThreadCount)
                .mapToObj(i -> {
                    final WorkThread workThread = new WorkThread(threadNamePrefix + i, maxItemCount);
                    workThread.start();
                    return workThread;
                })
                .collect(Collectors.toList());
    }

    /**
     * Enqueues an entity, blocking while the queue is full.
     *
     * @throws IllegalStateException        if the queue is full and every worker has already exited
     *                                      (e.g. they all failed), so waiting could never succeed
     * @throws UndeclaredThrowableException wrapping {@link InterruptedException} if interrupted
     *                                      (the interrupt flag is restored)
     */
    public void put(T entity) {
        try {
            while (!queue.offer(entity, POLL_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS)) {
                final CountDownLatch latch = cdl;
                if (latch != null && latch.getCount() == 0) {
                    throw new IllegalStateException("所有处理线程均已退出,无法继续写入", failure.get());
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new UndeclaredThrowableException(e);
        }
    }

    /**
     * Signals that no more data will be produced, waits for the workers to drain the queue, then
     * flushes every worker's partial batch on the calling thread.
     *
     * @throws RuntimeException             if not started
     * @throws IllegalStateException        if any worker failed; partial batches are then not flushed
     * @throws UndeclaredThrowableException wrapping {@link InterruptedException} if interrupted
     */
    public void await() {
        if (!started) {
            throw new RuntimeException("还未运行");
        }
        isProduceData = false;
        try {
            cdl.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new UndeclaredThrowableException(e);
        }
        final Throwable cause = failure.get();
        if (cause != null) {
            throw new IllegalStateException("处理线程执行失败,部分数据可能未入库", cause);
        }
        // The latch's countDown/await gives a happens-before edge, so the workers'
        // batch containers are safely visible here.
        for (WorkThread workThread : workThreadList) {
            workThread.clearEntity();
        }
    }

    /**
     * Stops processing.
     *
     * <p>Queued entities are discarded, and so are batches that workers have not yet flushed.
     * Workers exit on their next poll; this method does not wait for them.
     *
     * @throws RuntimeException if not started
     */
    public void stop() {
        if (!started) {
            throw new RuntimeException("还未运行");
        }
        isProduceData = false;
        queue.clear();
    }

    /** Worker that drains the shared queue until production has ended and the queue is empty. */
    private class WorkThread extends Thread {
        // Entities buffered for the next batch; only used when maxItemCount > 0.
        private List<T> batchCacheContainer;
        private final int maxItemCount;

        WorkThread(String name, int maxItemCount) {
            super(name);
            this.maxItemCount = maxItemCount;
            if (maxItemCount > 0) {
                batchCacheContainer = new ArrayList<>(maxItemCount);
            }
        }

        @Override
        public void run() {
            try {
                while (isProduceData || !queue.isEmpty()) {
                    final T entity = queue.poll(POLL_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
                    if (entity == null) {
                        continue;
                    }
                    if (maxItemCount > 0) {
                        batchCacheContainer.add(entity);
                        if (batchCacheContainer.size() >= maxItemCount) {
                            flushBatch();
                        }
                    } else {
                        consumer.accept(entity, null);
                    }
                }
            } catch (InterruptedException e) {
                failure.compareAndSet(null, e);
                Thread.currentThread().interrupt();
                throw new UndeclaredThrowableException(e);
            } catch (RuntimeException | Error e) {
                // Record the failure so await()/put() can report it, then let the thread die
                // as before (the default handler still logs the stack trace).
                failure.compareAndSet(null, e);
                throw e;
            } finally {
                // Always count down, otherwise a failing consumer would make await() hang forever.
                cdl.countDown();
            }
        }

        /** Flushes the remaining partial batch, if any; called by await() after this thread exits. */
        public void clearEntity() {
            if (maxItemCount > 0 && !batchCacheContainer.isEmpty()) {
                flushBatch();
            }
        }

        /** Hands the current batch to the consumer and starts a fresh one (the old list is never reused). */
        private void flushBatch() {
            final List<T> batch = batchCacheContainer;
            batchCacheContainer = new ArrayList<>(maxItemCount);
            consumer.accept(null, batch);
        }
    }

    /**
     * Callback that persists entities.
     *
     * @param <T> entity type
     */
    @FunctionalInterface
    public interface StorageConsumer<T> {
        /**
         * @param t  the entity in single-item mode ({@code maxItemCount <= 0}), otherwise {@code null}
         * @param ts the batch in batch mode, otherwise {@code null}; the processor does not reuse
         *           this list after the call
         */
        void accept(T t, List<T> ts);
    }
}
封装实体类:这里我自己写了一个用于测试的实体
/**
 * Simple test entity for the batch-insert demo.
 *
 * <p>Its {@code id} and {@code name} properties are read by the mapper statement
 * ({@code #{item.id}}, {@code #{item.name}}) and stored in the {@code batch_save} table.
 */
public class DemoEntity {
    private int id;
    private String name;

    public DemoEntity(int id, String name) {
        this.id = id;
        this.name = name;
    }

    public int getId() {
        return id;
    }

    public void setId(int id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    /** Returns e.g. {@code DemoEntity{id=1, name='a'}}. */
    @Override
    public String toString() {
        // '\'' must be escaped; the unescaped ''' in the original listing does not compile.
        return "DemoEntity{" +
                "id=" + id +
                ", name='" + name + '\'' +
                '}';
    }
}
建表SQL
-- Target table for the batch-insert demo; id is the primary key, name is optional.
CREATE TABLE `batch_save` (
  `id`   INT(11)      NOT NULL,
  `name` VARCHAR(255) DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;
mapper.xml
<!--
  Multi-row INSERT used by batchSaveDao.saveList(ts): emits one VALUES tuple per element of the batch.
  NOTE(review): the <insert>/<foreach> markup was lost from the original listing and is reconstructed
  here. collection="list" assumes saveList takes a single List parameter without @Param; if the DAO
  annotates it (e.g. @Param("ts")), use that name instead - confirm against BatchSaveDao.
-->
<insert id="saveList">
    insert into batch_save(id, name) values
    <foreach collection="list" item="item" separator=",">
        (#{item.id}, #{item.name})
    </foreach>
</insert>
测试类
package com.zhbcm.save;
import com.zhbcm.save.dao.BatchSaveDao;
import com.zhbcm.save.entity.DemoEntity;
import com.zhbcm.save.util.BatchInsertProcessor;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import javax.annotation.Resource;
@SpringBootTest
class BatchInsertApplicationTests {
@Resource
private BatchSaveDao batchSaveDao;
@Test
public void batchInsert() {
final BatchInsertProcessor work = new BatchInsertProcessor<>();
work.start((t, ts) -> {
//单条数据t就有数据,多条数据ts就会有数据
//这里我是插了2w条数据,所以用ts
batchSaveDao.saveList(ts);
}, 10, 2000);
// 模拟生产数据
try {
for (int i = 0; i < 20000; i++) {
work.put(new DemoEntity(i, i + " name"));
}
// 等待入库完成
work.await();
} catch (Exception e) {
// 如果生产过程中有异常,立即停止掉处理器,不再入库
work.stop();
}
}
实际效果:插入2w条数据约耗时1s。这里开了10个工作线程,每批最多插入2000条(`maxItemCount = 2000`限制的是单次批量插入的条数,而不是每个线程的总承担量)。
码云传送门
原文茶佬笔记地址:https://www.yuque.com/mrcode.cn/note-combat/qd8bo3



