之前写过一篇文章,介绍用Java多线程把shape文件解析为geojson、入库PostgreSQL并打包成exe的工具。本次更新修复了在某些错误情况下线程会卡住的bug,增加了基于x、y以及z维度的空间数据判断,并在建表时识别数据是否为三维。
老版本以及解释地址:Java解析Shape文件为Geojson并打包为.exe文件(多线程)_Hot2era的博客-CSDN博客
package com.env;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.lang.UUID;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.env.common.utils.NullUtil;
import com.env.common.utils.ShapeUtil;
import com.env.common.utils.StringUtils;
import org.opengis.referencing.FactoryException;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import springfox.documentation.spring.web.json.Json;
import java.io.*;
import java.math.BigDecimal;
import java.sql.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
/**
* 启动程序
*
* @author mzc
*/
@SpringBootApplication(exclude = {DataSourceAutoConfiguration.class})
public class EnvApplication {
// Shared statement counter: bufferAdd resets it to 0 and insert increments/prints it.
public static AtomicInteger count = new AtomicInteger(1);
// Thread pool size (used as both core and maximum size of the insert executor)
public static Integer threadPoolCapacity = 16;
// Number of queues allocated per pool thread (pgConnect creates threadPoolCapacity * queueInThread queues)
public static Integer queueInThread = 8;
// Capacity of each SQL queue; also the batch threshold
public static Integer queueCapacity = 2000;
// Schema ("mode"); appended to the JDBC URL as currentSchema
public static String mode = "";
// Database server host
public static String host = "";
// Database server port
public static String port = "";
// Database name
public static String database = "";
// Database user name
public static String username = "";
// Database password
public static String password = "";
// File extensions that identify a member file of a shapefile set (matched in loop)
public static String[] suffix = {"shp", "dbf", "sbn", "prj", "sbx"};
// JDBC URL, assembled by readConfig from host/port/database/mode
public static String url = "";
// Optional directory to scan (-filepath); when empty, main scans the working directory
public static String filePath="";
// Spatial reference id (-srid); when left unset, main tries to read it from the shapefile
public static Integer srid = 0;
// Character set handed to ShapeUtil.shp2geojson (-charset)
public static String charset = "GBK";
/**
 * Program entry point: reads the command-line configuration, collects every shapefile
 * below the chosen directory, converts each one through {@code ShapeUtil.shp2geojson}
 * and loads the result into PostgreSQL through {@link #pgConnect}.
 *
 * @param args command-line arguments, see {@link #readConfig}
 * @throws FactoryException propagated from the ShapeUtil calls
 * @throws IOException      propagated from the ShapeUtil calls
 */
public static void main(String[] args) throws FactoryException, IOException {
    // readConfig prints its own diagnostics (help, version, errors) when it returns false.
    if (!Boolean.TRUE.equals(readConfig(args))) {
        return;
    }
    // Scan only the directory that is actually used. The working directory used to be
    // walked recursively even when -filepath pointed somewhere else.
    List<File> files = NullUtil.isNotEmpty(filePath)
            ? FileUtil.loopFiles(new File(filePath))
            : FileUtil.loopFiles(System.getProperty("user.dir"));
    System.out.println("filesize:" + files.size());
    if (files.isEmpty()) {
        System.out.println("不存在shape文件");
        return;
    }
    // Collect the .shp path of every shapefile set, then drop duplicates
    // (each member file of a set contributes the same .shp path).
    List<String> urlList = new ArrayList<>();
    loop(files, urlList);
    urlList = urlList.stream().distinct().collect(Collectors.toList());

    for (String shpPath : urlList) {
        // Resolve the coordinate system once; fall back to 4326 when none can be determined.
        if (NullUtil.isEmpty(srid)) {
            srid = ShapeUtil.getSRIDFromShapefile(shpPath);
            if (NullUtil.isEmpty(srid)) {
                srid = 4326;
            }
        }
        // NOTE(review): shp2geojson is assumed to return a Map holding a "data" entry,
        // which is what analysisResult reads — confirm against ShapeUtil.
        @SuppressWarnings("unchecked")
        Map<String, Object> geoMap = (Map<String, Object>) ShapeUtil.shp2geojson(shpPath, srid, charset);
        List<Map<String, Object>> maps = analysisResult(geoMap);
        if (maps.isEmpty()) {
            // Nothing to derive columns from; skipping avoids an IndexOutOfBoundsException in bufferAdd.
            System.out.println("文件没有可入库的要素,已跳过:" + shpPath);
            continue;
        }
        // The file name without directory and extension drives the table name.
        String fileName = shpPath.substring(shpPath.lastIndexOf(File.separator) + 1);
        String baseName = fileName.substring(0, fileName.lastIndexOf("."));
        pgConnect(deriveTableName(baseName), maps);
    }
}

/**
 * Derives the table name from a shapefile base name split on {@code '_'}:
 * one part uses that part, two parts use the second, three parts use the second and
 * third concatenated, and any other count (including a name made only of underscores)
 * falls back to the full base name.
 */
private static String deriveTableName(String baseName) {
    String[] parts = baseName.split("_");
    switch (parts.length) {
        case 1:
            return parts[0];
        case 2:
            System.out.println("========" + parts[1]);
            return stripDataMarker(parts[1]);
        case 3:
            return stripDataMarker(parts[1]) + parts[2];
        default:
            return baseName;
    }
}

/** Removes the "数据" marker from a name part that ends with it; other parts are returned unchanged. */
private static String stripDataMarker(String part) {
    return part.endsWith("数据") ? part.replace("数据", "") : part;
}
/**
 * Flattens the GeoJSON features found under {@code geoMap.get("data")} into one map per
 * feature, ready for table creation and insertion.
 *
 * <p>Each returned map holds the feature's properties (keys lower-cased) plus three
 * entries: {@code geom} (the geometry as a JSON string), {@code srid} (the feature's
 * {@code srid} value) and {@code type} (the geometry type). Features without a usable
 * geometry or properties object are skipped rather than emitted as empty maps, because
 * the insert code requires {@code srid} and {@code geom} on every row.
 *
 * @param geoMap conversion result; its {@code "data"} entry must be a feature collection
 * @return one map per usable feature, possibly empty, never {@code null}
 */
public static List<Map<String, Object>> analysisResult(Map<String, Object> geoMap) {
    List<Map<String, Object>> rows = new ArrayList<>();
    JSONObject data = (JSONObject) geoMap.get("data");
    JSONArray features = data == null ? null : (JSONArray) data.get("features");
    if (features == null) {
        return rows;
    }
    int skipped = 0;
    for (Object element : features) {
        if (!(element instanceof JSONObject)) {
            skipped++;
            continue;
        }
        JSONObject feature = (JSONObject) element;
        JSONObject geometry = (JSONObject) feature.get("geometry");
        JSONObject properties = (JSONObject) feature.get("properties");
        Object type = geometry == null ? null : geometry.get("type");
        if (properties == null || type == null) {
            skipped++;
            continue;
        }
        Map<String, Object> row = new HashMap<>();
        // Lower-case with a fixed locale so column names do not depend on the machine's locale.
        for (Map.Entry<String, Object> property : properties.entrySet()) {
            row.put(property.getKey().toLowerCase(Locale.ROOT), property.getValue());
        }
        // Added after the properties, so these three win over same-named attributes.
        row.put("geom", geometry.toJSONString());
        row.put("srid", (Integer) feature.get("srid"));
        row.put("type", type.toString());
        rows.add(row);
    }
    if (skipped > 0) {
        System.out.println("跳过" + skipped + "个缺少geometry或properties的要素");
    }
    return rows;
}
/**
 * Walks the given files (descending into directories) and, for every file whose extension
 * is listed in {@link #suffix}, appends the path of the sibling {@code .shp} file to
 * {@code urlList}. Duplicates are expected, since each member of a shapefile set yields
 * the same path; the caller removes them.
 *
 * <p>The base name is cut at the LAST dot, so dotted names such as {@code roads.v2.dbf}
 * map to {@code roads.v2.shp}.
 *
 * @param originalFiles files and/or directories to inspect
 * @param urlList       output list receiving the {@code .shp} paths
 */
public static void loop(List<File> originalFiles, List<String> urlList) {
    List<String> extensions = Arrays.asList(suffix);
    for (File file : originalFiles) {
        if (file.isDirectory()) {
            loop(Arrays.asList(FileUtil.ls(file.getAbsolutePath())), urlList);
            continue;
        }
        String name = file.getName();
        int lastDot = name.lastIndexOf('.');
        if (lastDot < 0) {
            // No extension at all, so it cannot be part of a shapefile set.
            continue;
        }
        if (extensions.contains(name.substring(lastDot + 1))) {
            urlList.add(file.getParentFile() + File.separator + name.substring(0, lastDot) + ".shp");
        }
    }
}
/**
 * Creates the target table (if missing) and loads {@code datas} into it.
 *
 * <p>The table is created and committed on a short-lived connection. The rows are then
 * handed to {@link #readBufferSql}, which submits insert tasks to a fixed-size pool.
 * This method waits for those tasks to finish before returning; the pool used to be
 * cancelled with {@code shutdownNow()} right after the producer returned, which could
 * discard inserts that were still queued or running.
 *
 * @param tableName name of the table to create and fill
 * @param datas     rows as produced by {@link #analysisResult}; nothing happens when empty
 */
public static void pgConnect(String tableName, List<Map<String, Object>> datas) {
    if (datas == null || datas.isEmpty()) {
        // bufferAdd derives the columns from the first row, so there is nothing to do.
        System.out.println("表" + tableName + "没有可入库的数据,已跳过");
        return;
    }
    ThreadPoolExecutor executor = null;
    try {
        // 1. Register the driver.
        Class.forName("org.postgresql.Driver");
        // 2. Create the table inside its own transaction; resources are closed automatically.
        try (Connection connection = DriverManager.getConnection(url, username, password);
             Statement statement = connection.createStatement()) {
            connection.setAutoCommit(false);
            String createTableSql = bufferAdd(tableName, datas, statement);
            System.out.println(createTableSql);
            statement.executeUpdate(createTableSql);
            System.out.println("================创建表" + tableName + "成功,接下来添加数据==================");
            connection.commit();
        }
        // 3. Fixed-size pool for the insert tasks.
        executor = new ThreadPoolExecutor(
                threadPoolCapacity,
                threadPoolCapacity,
                0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<>());
        // A fixed number of queues per pool thread, e.g. 8 threads x 3 queues = 24 queues.
        Vector<ArrayBlockingQueue<String>> queues = new Vector<>();
        for (int i = 0; i < threadPoolCapacity * queueInThread; i++) {
            queues.add(new ArrayBlockingQueue<>(queueCapacity));
        }
        // 4. Build the insert statements and submit the insert tasks.
        try {
            readBufferSql(tableName, datas, queues, executor);
        } catch (RuntimeException e) {
            // Keep a failure while producing rows from aborting the remaining files.
            e.printStackTrace();
        }
        // 5. Stop accepting tasks, then wait for the submitted ones. shutdown() is called
        //    first, so the wait ends as soon as the already-submitted tasks have finished.
        executor.shutdown();
        executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
    } catch (ClassNotFoundException | SQLException e) {
        e.printStackTrace();
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        e.printStackTrace();
    } finally {
        // Never leave pool threads behind: idle core threads would keep the JVM alive.
        if (executor != null) {
            executor.shutdownNow();
        }
    }
}
/**
 * Builds the {@code CREATE TABLE IF NOT EXISTS} statement for {@code tableName}, using the
 * first row of {@code datas} as the column template, and resets {@link #count} to 0.
 *
 * <p>The geometry column type is the first row's {@code type}, with a {@code Z} suffix
 * when PostGIS reports three dimensions for the first geometry. The dimension check used
 * to wrap the geometry in {@code ST_Force3D}, which made every geometry look 3D and
 * forced the column to {@code MultiPolygonZ}. The check now runs through a
 * {@link PreparedStatement}, so the GeoJSON is bound rather than concatenated.
 *
 * <p>Other columns: {@code Double}/{@code Float}/{@code BigDecimal} become
 * {@code decimal}, {@code Integer} becomes {@code int4}, {@code Long} becomes
 * {@code int8}, everything else (including {@code null}) becomes {@code varchar(512)}.
 *
 * @param tableName table to create
 * @param datas     rows; must contain at least one row
 * @param statement statement whose connection is used for the dimension check
 * @return the CREATE TABLE statement
 * @throws SQLException             if the dimension check fails
 * @throws IllegalArgumentException if {@code datas} is null or empty
 */
public static String bufferAdd(String tableName, List<Map<String, Object>> datas, Statement statement) throws SQLException {
    if (datas == null || datas.isEmpty()) {
        throw new IllegalArgumentException("datas must contain at least one row to derive the table columns");
    }
    // All keys of the first row become column names.
    Map<String, Object> keyMap = datas.get(0);

    // Ask PostGIS how many dimensions the first geometry really has (no ST_Force3D).
    String zSql = "SELECT ST_NDims(ST_GeomFromGeoJSON(?))";
    System.out.println("======判断是否有z维度=======");
    System.out.println(zSql);
    boolean hasZ = false;
    try (PreparedStatement dimsQuery = statement.getConnection().prepareStatement(zSql)) {
        dimsQuery.setString(1, String.valueOf(keyMap.get("geom")));
        try (ResultSet resultSet = dimsQuery.executeQuery()) {
            if (resultSet.next()) {
                hasZ = resultSet.getInt(1) == 3;
            }
        }
    }
    System.out.println(hasZ ? "======有z维度======" : "======无=====");

    Object geometryType = keyMap.get("type");
    String baseType = geometryType == null ? "Geometry" : geometryType.toString();
    String columnType = hasZ ? baseType + "Z" : baseType;

    // A missing or zero srid falls back to 4326.
    Object sridValue = keyMap.get("srid");
    String srid = (sridValue == null || sridValue.equals(0)) ? "4326" : sridValue.toString();

    count.set(0);

    StringBuilder createTableSql = new StringBuilder("CREATE TABLE IF NOT EXISTS ");
    createTableSql.append(tableName).append("(gid serial primary key,");
    for (Map.Entry<String, Object> column : keyMap.entrySet()) {
        String columnName = column.getKey();
        if ("geom".equals(columnName)) {
            createTableSql
                    .append(" geom public.geometry('")
                    .append(columnType)
                    .append("',")
                    .append(srid)
                    .append("),");
        } else {
            createTableSql.append(columnName).append(' ').append(sqlTypeOf(column.getValue())).append(',');
        }
    }
    // Drop the trailing comma.
    createTableSql.setLength(createTableSql.length() - 1);
    createTableSql.append(")");
    return createTableSql.toString();
}

/** Maps a sample attribute value to the PostgreSQL column type used for it. */
private static String sqlTypeOf(Object value) {
    if (value instanceof Double || value instanceof Float || value instanceof BigDecimal) {
        return "decimal";
    }
    if (value instanceof Integer) {
        return "int4";
    }
    if (value instanceof Long) {
        return "int8";
    }
    return "varchar(512)";
}
/**
 * Builds one INSERT statement per row and loads them through the executor in batches.
 *
 * <p>The supplied {@code queues} act as a pool of batch buffers. The producer fills one
 * buffer and hands it to an insert task once it is full. It then takes the next idle
 * buffer, blocking while all buffers are in flight, which bounds memory use. Each task
 * returns its buffer to the pool when done. The final partial batch is always flushed.
 * The method shuts the executor down and waits for all insert tasks before reporting
 * the elapsed time.
 *
 * <p>Differences from the previous implementation, all of them fixes:
 * <ul>
 *   <li>every row is submitted exactly once; trailing rows are no longer left unflushed,
 *       and small inputs no longer spawn one task per row;</li>
 *   <li>a failed batch can no longer leave the producer blocked on a full queue;</li>
 *   <li>{@code null} attribute values are written as SQL {@code NULL} instead of throwing;</li>
 *   <li>single quotes are escaped ({@code ''}) instead of being removed from the data;</li>
 *   <li>a missing or zero {@code srid} falls back to 4326, matching {@link #bufferAdd}.</li>
 * </ul>
 *
 * @param tableName target table
 * @param datas     rows as produced by {@link #analysisResult}
 * @param queues    batch buffers; a single buffer is created when this is empty
 * @param executor  pool running the insert tasks; it is shut down by this method
 */
public static void readBufferSql(String tableName, List<Map<String, Object>> datas,
        Vector<ArrayBlockingQueue<String>> queues, ThreadPoolExecutor executor) {
    long start = System.currentTimeMillis();
    // Buffers not currently owned by the producer or by an insert task.
    BlockingQueue<ArrayBlockingQueue<String>> idleBuffers = new LinkedBlockingQueue<>(queues);
    if (idleBuffers.isEmpty()) {
        idleBuffers.add(new ArrayBlockingQueue<>(Math.max(1, queueCapacity)));
    }
    try {
        ArrayBlockingQueue<String> current = idleBuffers.take();
        for (Map<String, Object> row : datas) {
            if (row == null || row.isEmpty()) {
                // Nothing to insert for a row without columns.
                continue;
            }
            String sql = buildInsertSql(tableName, row);
            // A full buffer is handed to an insert task; continue with an idle one.
            while (!current.offer(sql)) {
                submitBatch(executor, current, idleBuffers);
                current = idleBuffers.take();
            }
        }
        if (!current.isEmpty()) {
            // Flush the final partial batch.
            submitBatch(executor, current, idleBuffers);
        }
    } catch (InterruptedException e) {
        // Stop producing; rows not yet submitted are abandoned.
        Thread.currentThread().interrupt();
        e.printStackTrace();
    } finally {
        executor.shutdown();
    }
    try {
        // Every task works on a finite batch, so this ends once the submitted batches are done.
        executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
    }
    long end = System.currentTimeMillis();
    System.out.println("====================一共耗时:" + (end - start) / 1000 + "秒=================");
}

/** Builds the INSERT statement for one row; the geometry goes through ST_GeomFromGeoJSON. */
private static String buildInsertSql(String tableName, Map<String, Object> row) {
    Object sridValue = row.get("srid");
    String srid = (sridValue == null || sridValue.equals(0)) ? "4326" : sridValue.toString();
    StringBuilder columns = new StringBuilder("( ");
    StringBuilder values = new StringBuilder(" VALUES( ");
    for (Map.Entry<String, Object> cell : row.entrySet()) {
        String column = cell.getKey();
        columns.append(column).append(',');
        // Double single quotes so the value stays intact inside the SQL string literal.
        String text = cell.getValue() == null ? "" : cell.getValue().toString().replace("'", "''");
        if (NullUtil.isEmpty(text)) {
            values.append("null,");
        } else if ("geom".equals(column)) {
            values.append("ST_SetSRID(ST_GeomFromGeoJSON('").append(text).append("'),").append(srid).append("),");
        } else {
            values.append('\'').append(text).append("',");
        }
    }
    // Drop the trailing commas.
    columns.setLength(columns.length() - 1);
    values.setLength(values.length() - 1);
    return "insert into " + tableName + columns + ")" + values + ")";
}

/** Submits a task that inserts the batch and then returns the emptied buffer to the pool. */
private static void submitBatch(ThreadPoolExecutor executor, ArrayBlockingQueue<String> batch,
        BlockingQueue<ArrayBlockingQueue<String>> idleBuffers) {
    executor.submit(() -> {
        try {
            drainBatch(batch);
        } catch (RuntimeException e) {
            e.printStackTrace();
        } finally {
            // Always hand the buffer back, otherwise the producer could wait forever.
            batch.clear();
            idleBuffers.offer(batch);
        }
    });
}

/**
 * Inserts the batch, opening a fresh connection after each failure so the statements that
 * follow a bad one are still attempted. Statements already executed in a failed
 * transaction are rolled back with it. If an attempt makes no progress at all
 * (for example the connection cannot be opened), the rest of the batch is dropped.
 */
private static void drainBatch(ArrayBlockingQueue<String> batch) {
    while (!batch.isEmpty()) {
        int before = batch.size();
        try (Connection connection = DriverManager.getConnection(url, username, password)) {
            connection.setAutoCommit(false);
            insert(connection, batch);
        } catch (SQLException e) {
            e.printStackTrace();
        }
        if (batch.size() >= before) {
            System.err.println("insert made no progress, dropping " + batch.size() + " statement(s)");
            batch.clear();
        }
    }
}
/**
 * Executes the queued SQL statements on {@code connection} until the queue is empty
 * (or an empty statement is polled), commits, and closes the connection.
 *
 * <p>Each successfully executed statement increments and prints {@link #count}. If a
 * statement fails, the transaction is rolled back and the exception is rethrown; the
 * statement and the connection are released on every path.
 *
 * @param connection open connection, expected to have auto-commit disabled
 * @param sqls       queue of SQL statements to execute; drained by this method
 * @throws SQLException if creating the statement, executing a statement or committing fails
 */
public static void insert(Connection connection, ArrayBlockingQueue<String> sqls) throws SQLException {
    try (Statement statement = connection.createStatement()) {
        String sql;
        // poll() returns null once the queue is empty, which ends the loop.
        while ((sql = sqls.poll()) != null) {
            if (NullUtil.isEmpty(sql)) {
                break;
            }
            statement.execute(sql);
            System.out.println(count.getAndIncrement());
        }
        connection.commit();
    } catch (SQLException e) {
        try {
            connection.rollback();
        } catch (SQLException rollbackFailure) {
            e.addSuppressed(rollbackFailure);
        }
        throw e;
    } finally {
        try {
            connection.close();
        } catch (SQLException closeFailure) {
            // The outcome is already decided (committed or rethrown); just report it.
            closeFailure.printStackTrace();
        }
    }
}
/**
 * Parses the command-line arguments into the static configuration fields and builds
 * the JDBC {@link #url}.
 *
 * <p>Supported forms: {@code -help}, {@code -v}, or a list of {@code key=value} pairs with
 * keys {@code -db} (mandatory, {@code host+port+database+schema+user+password}),
 * {@code -srid}, {@code -charset}, {@code -filepath} and {@code -thread}. Keys are split
 * from values at the FIRST {@code '='}. The password is everything after the fifth
 * {@code '+'}, so it may itself contain {@code '='} or {@code '+'}. Unknown keys are
 * ignored. Argument values are not echoed, because the {@code -db} value carries the
 * password.
 *
 * @param args raw command-line arguments
 * @return {@code true} when the configuration is complete and the program should
 *         continue; {@code false} after printing help, version or an error
 */
public static Boolean readConfig(String[] args) {
    if (args == null || args.length == 0) {
        System.out.println("参数错误");
        return false;
    }
    if (args.length == 1 && args[0].equals("-help")) {
        printHelp();
        return false;
    }
    if (args.length == 1 && args[0].equals("-v")) {
        System.out.println("version:1.0");
        return false;
    }
    boolean dbConfigured = false;
    for (String one : args) {
        int separator = one.indexOf('=');
        // Both a key and a value are required.
        if (separator <= 0 || separator == one.length() - 1) {
            System.out.println("参数不正确,请输入 -help 查看参数结构");
            return false;
        }
        String key = one.substring(0, separator);
        String value = one.substring(separator + 1);
        switch (key) {
            case "-db": {
                // Limit 6: anything after the fifth '+' belongs to the password.
                String[] s = value.split("\\+", 6);
                if (s.length < 6) {
                    System.out.println("数据库连接参数有误,请按照【 地址+端口号+数据库+模式+用户名+密码】输入");
                    return false;
                }
                host = s[0];
                port = s[1];
                database = s[2];
                mode = s[3];
                username = s[4];
                password = s[5];
                dbConfigured = true;
                break;
            }
            case "-srid":
                try {
                    srid = Integer.parseInt(value);
                } catch (NumberFormatException e) {
                    System.out.println("srid输入有误,请输入数字");
                    return false;
                }
                break;
            case "-charset":
                charset = value;
                break;
            case "-filepath":
                filePath = value;
                break;
            case "-thread":
                // Advertised by -help; must be a positive thread count.
                try {
                    int threads = Integer.parseInt(value);
                    if (threads < 1) {
                        throw new NumberFormatException(value);
                    }
                    threadPoolCapacity = threads;
                } catch (NumberFormatException e) {
                    System.out.println("thread输入有误,请输入正整数");
                    return false;
                }
                break;
            default:
                break;
        }
    }
    if (!dbConfigured) {
        // Without -db the JDBC URL would be built from empty parts.
        System.out.println("缺少必填参数 -db,请输入 -help 查看参数结构");
        return false;
    }
    url = "jdbc:postgresql://" + host + ":" + port + "/" + database + "?currentSchema=" + mode;
    System.out.println(url);
    return true;
}

/** Prints the usage text shown for {@code -help}. */
private static void printHelp() {
    System.out.println("************欢迎使用shape转geojson工具,请按照顺序输入对应参数,以连接数据库:***********");
    System.out.println("");
    System.out.println("-v //版本号");
    System.out.println("-charset //编码 如 UTF-8 默认为GBK");
    System.out.println("-filepath //指定读取文件的路径 默认为当前目录");
    System.out.println("-srid //坐标系 输入数字 如4326");
    System.out.println("-thread //线程数量(到2000条数据才会用多线程)");
    System.out.println("");
    System.out.println("数据库链接请输入 -db=<host>+<port>+<database>+<mode>+<username>+<password>");
    System.out.println("");
    System.out.println("tip:");
    System.out.println("请将shape文件,或是文件夹放置于exe程序同级或同级文件夹的下级");
    System.out.println("程序将递归同级文件夹以及所有子级的shape文件");
    System.out.println("如:【 shape2geojson.exe 数据1文件夹 数据2文件夹 】");
    System.out.println(" ↑这是程序 这是文件夹↑,里面包几层都无所谓 ↑ ");
    System.out.println("");
    System.out.println("请注意:-db命令是必填,除-v -help之外,所有参数命令在-db命令之后执行,值用等号连接");
    System.out.println("");
    System.out.println("例如:【shape2geojson.exe -db=127.0.0.1+5432+test+public+user+password -srid=4326 -charset=GBK】");
}
}