MongoTemplate很好用，但是透過XML註冊為Bean時，只能綁定到一個database上。
遇到需要同時支援多個database，或需要動態切換database的項目時，就非常麻煩了。
解決的思路是把MongoTemplate以database名稱為鍵放在Map中緩存起來。由於每個MongoTemplate底層使用的MongoClient內部已經實現了連接池，所以不需要再自行管理連接池。
把管理容器的類聲明為Spring的組件,這樣一來就可以通過@Value引入properties文件中的屬性
使用ThreadLocal把當前要操作的database所對應的MongoTemplate綁定到調用線程上，確保各線程使用的MongoTemplate互不干擾，避免多線程並發調用時導致結果不一致。
import com.mongodb.*; import org.springframework.beans.factory.annotation.Value; import org.springframework.data.mongodb.MongoDbFactory; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.SimpleMongoDbFactory; import org.springframework.data.mongodb.core.convert.DefaultDbRefResolver; import org.springframework.data.mongodb.core.convert.DefaultMongoTypeMapper; import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; import org.springframework.stereotype.Repository; import java.net.UnknownHostException; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; /** * @author ParanoidCAT * @since JDK 1.8 */ @Repository(value = "mongoRepository") public class MongoRepository { @Value("${mongo.host}") private String host; @Value("${mongo.port}") private Integer port; @Value("${mongo.username}") private String username; @Value("${mongo.password}") private String password; @Value("${mongo.database}") private String database; @Value("${mongo.connectionsPerHost}") private Integer connectionsPerHost; @Value("${mongo.threadsAllowedToBlockForConnectionMultiplier}") private Integer threadsAllowedToBlockForConnectionMultiplier; @Value("${mongo.connectTimeout}") private Integer connectTimeout; @Value("${mongo.maxWaitTime}") private Integer maxWaitTime; @Value("${mongo.socketTimeout}") private Integer socketTimeout; @Value("${mongo.socketKeepAlive}") private Boolean socketKeepAlive; private ThreadLocal<MongoTemplate> threadLocal = new ThreadLocal<>(); private static final Map<String, MongoTemplate> MONGO_TEMPLATE_CACHE = new ConcurrentHashMap<>(16); private void changeDatabase(String databaseName) { if 
(Optional.ofNullable(threadLocal.get()).map(MongoTemplate::getDb).map(DB::getName).orElse(database).equals(databaseName)) { return; } if (MONGO_TEMPLATE_CACHE.containsKey(databaseName)) { threadLocal.remove(); threadLocal.set(MONGO_TEMPLATE_CACHE.get(databaseName)); return; } synchronized (MONGO_TEMPLATE_CACHE) { if (MONGO_TEMPLATE_CACHE.containsKey(databaseName)) { changeDatabase(databaseName); } else { threadLocal.remove(); try { threadLocal.set(createMongoTemplate(databaseName)); MONGO_TEMPLATE_CACHE.putIfAbsent(databaseName, threadLocal.get()); } catch (Exception e) { // TODO 輸出日志 System.out.println(e.toString()); } } } } private MongoTemplate createMongoTemplate(String databaseName) throws UnknownHostException { MongoClient mongoClient = new MongoClient( Collections.singletonList(new ServerAddress(host, port)), Collections.singletonList(MongoCredential.createCredential(username, database, password.toCharArray())), new MongoClientOptions .Builder() .connectionsPerHost(connectionsPerHost) .threadsAllowedToBlockForConnectionMultiplier(threadsAllowedToBlockForConnectionMultiplier) .connectTimeout(connectTimeout) .maxWaitTime(maxWaitTime) .socketTimeout(socketTimeout) .socketKeepAlive(socketKeepAlive) .cursorFinalizerEnabled(true) .build() ); MongoDbFactory mongoDbFactory = new SimpleMongoDbFactory(mongoClient, databaseName); MappingMongoConverter mappingMongoConverter = new MappingMongoConverter(new DefaultDbRefResolver(mongoDbFactory), new MongoMappingContext()); mappingMongoConverter.setTypeMapper(new DefaultMongoTypeMapper(null)); return new MongoTemplate(mongoDbFactory, mappingMongoConverter); } /** * 插入一條記錄 * * @param databaseName 數據庫名 * @param t 實例 * @param <T> 實例所屬的類 */ public <T> void insert(String databaseName, T t) { changeDatabase(databaseName); threadLocal.get().insert(t); } /** * 插入一條記錄 * * @param databaseName 數據庫名 * @param collectionName 集合名 * @param t 實例 * @param <T> 實例所屬的類 */ public <T> void insert(String databaseName, String collectionName, T t) { 
changeDatabase(databaseName); threadLocal.get().insert(t, collectionName); } /** * 插入多條記錄 * * @param databaseName 數據庫名 * @param tClass 實例的class * @param tList 實例 * @param <T> 實例所屬的類 */ public <T> void insertAll(String databaseName, Class<T> tClass, List<T> tList) { changeDatabase(databaseName); threadLocal.get().insert(tList, tClass); } /** * 插入多條記錄 * * @param databaseName 數據庫名 * @param collectionName 集合名 * @param tList 實例 * @param <T> 實例所屬的類 */ public <T> void insertAll(String databaseName, String collectionName, List<T> tList) { changeDatabase(databaseName); threadLocal.get().insert(tList, collectionName); } /** * 移除一條或多條記錄 * * @param databaseName 數據庫名 * @param tClass 實例的class * @param query 查詢條件 * @param <T> 實例所屬的類 * @return */ public <T> long remove(String databaseName, Class<T> tClass, Query query) { changeDatabase(databaseName); return threadLocal.get().remove(query, tClass).getN(); } /** * 移除一條或多條記錄 * * @param databaseName 數據庫名 * @param collectionName 集合名 * @param query 查詢條件 * @return 受影響的記錄條數 */ public long remove(String databaseName, String collectionName, Query query) { changeDatabase(databaseName); return threadLocal.get().remove(query, collectionName).getN(); } /** * 更新多條記錄 * * @param databaseName 數據庫名 * @param tClass 實例的class * @param query 查詢條件 * @param update 更新內容 * @param <T> 實例所屬的類 * @return 受影響的記錄條數 */ public <T> long updateMulti(String databaseName, Class<T> tClass, Query query, Update update) { changeDatabase(databaseName); return threadLocal.get().updateMulti(query, update, tClass).getN(); } /** * 更新多條記錄 * * @param databaseName 數據庫名 * @param collectionName 集合名 * @param query 查詢條件 * @param update 更新內容 * @return 受影響的記錄條數 */ public long updateMulti(String databaseName, String collectionName, Query query, Update update) { changeDatabase(databaseName); return threadLocal.get().updateMulti(query, update, collectionName).getN(); } /** * 查詢多條記錄 * * @param databaseName 數據庫名 * @param tClass 實例的class * @param query 查詢條件 * @param <T> 實例所屬的類 * @return 實例 */ 
public <T> List<T> find(String databaseName, Class<T> tClass, Query query) { changeDatabase(databaseName); return threadLocal.get().find(query, tClass); } /** * 查詢多條記錄 * * @param databaseName 數據庫名 * @param collectionName 集合名 * @param query 查詢條件 * @param tClass 實例的class * @param <T> 實例所屬的類 * @return 實例 */ public <T> List<T> find(String databaseName, String collectionName, Query query, Class<T> tClass) { changeDatabase(databaseName); return threadLocal.get().find(query, tClass, collectionName); } /** * 查詢第一條記錄 * * @param databaseName 數據庫名 * @param tClass 實例的class * @param query 查詢條件 * @param <T> 實例所屬的類 * @return 實例 */ public <T> T findOne(String databaseName, Class<T> tClass, Query query) { changeDatabase(databaseName); return threadLocal.get().findOne(query, tClass); } /** * 查詢第一條記錄 * * @param databaseName 數據庫名 * @param tClass 實例的class * @param collectionName 集合名 * @param query 查詢條件 * @param <T> 實例所屬的類 * @return 實例 */ public <T> T findOne(String databaseName, Class<T> tClass, String collectionName, Query query) { changeDatabase(databaseName); return threadLocal.get().findOne(query, tClass, collectionName); } }
測試類:
讓40個線程同時並發調用，測試多線程調用是否安全。
這裡我在五個database中各放入了一個名稱為"test"的collection，每個collection裡放了一個{"name":"數據庫名"}的Document，用於驗證查詢結果是否來自正確的database。
import com.mdruby.repository.MongoRepository; import com.mongodb.BasicDBObject; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.mongodb.core.query.Query; import org.springframework.stereotype.Controller; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.ResponseBody; import java.util.Optional; import java.util.concurrent.CountDownLatch; /** * @author ParanoidCAT * @since JDK 1.8 */ @Controller(value = "testController") @RequestMapping(value = {"test"}) public class TestController { @Autowired private MongoRepository mongoRepository; private static final CountDownLatch COUNT_DOWN_LATCH = new CountDownLatch(40); @RequestMapping(value = {"/run/{databaseName}"}, method = {RequestMethod.GET}, produces = {"application/json;charset=utf-8"}) @ResponseBody public void run(@PathVariable String databaseName) throws InterruptedException { // 線程計數+1 COUNT_DOWN_LATCH.countDown(); // 線程沒到40個就等等 COUNT_DOWN_LATCH.await(); // 線程如果到了40個就一起放行,每個線程執行150次query for (int i = 0; i < 150; i++) { BasicDBObject basicDBObject = mongoRepository.findOne(databaseName, BasicDBObject.class, "test", new Query()); if (!databaseName.equals(Optional.ofNullable(basicDBObject).map(basicDBObject1 -> basicDBObject1.getString("name")).orElse("testString"))) { System.out.println(Thread.currentThread().getName() + ": " + databaseName + " - " + basicDBObject); } } } }
