// This code uses recursion, Gson, and ND4J.
import com.google.gson.JsonElement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Objects;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Helpers for converting Gson JSON trees into ND4J arrays.
 */
class CommonPort {

    /**
     * Builds an {@link INDArray} from a (possibly nested) JSON array of numbers.
     *
     * <p>The shape is inferred by following the first element at each nesting level, and every
     * sub-array is then checked against that shape before its numbers are copied. A {@code depth}
     * of 1 reads a flat array of numbers; a {@code depth} of {@code k} reads {@code k} levels of
     * nesting.
     *
     * @param json_array root element; expected to be a JSON array nested at least {@code depth}
     *     levels deep with numeric leaves
     * @param depth number of nesting levels to read
     * @return {@code null} if {@code depth <= 0}; for {@code depth == 1}, the vector created by
     *     {@code Nd4j.create(n)} after {@code permute(1, 0)}; otherwise an array whose shape is the
     *     inferred per-level sizes
     * @throws NullPointerException if {@code json_array} is null and {@code depth > 0}
     * @throws IllegalArgumentException if a non-innermost level is empty (the shape cannot be
     *     inferred) or the nested arrays are ragged
     */
    static INDArray load_json_to_NDArray(JsonElement json_array, int depth) {
        if (depth <= 0) {
            return null;
        }
        Objects.requireNonNull(json_array, "json_array");
        int[] shape = get_shape(json_array, depth);

        if (depth == 1) {
            INDArray vector = Nd4j.create(shape[0]);
            for (int i = 0; i < shape[0]; i++) {
                vector.putScalar(i, json_array.getAsJsonArray().get(i).getAsDouble());
            }
            // Kept from the original: 1-D data is returned transposed.
            // NOTE(review): permute(1, 0) needs a rank-2 array, so this relies on
            // Nd4j.create(int) producing a [1, n] row vector; that is ND4J-version dependent.
            return vector.permute(1, 0);
        }

        // Allocate once and fill leaf-by-leaf instead of inserting recursively built rows,
        // so arrays of any rank are handled the same way.
        INDArray result = Nd4j.zeros(shape);
        fill_from_json(result, json_array, shape, new int[depth], 0);
        return result;
    }

    /**
     * Recursively copies the numbers under {@code node} into {@code target}.
     *
     * @param target destination array with the given {@code shape}
     * @param node JSON array at nesting level {@code level}
     * @param shape expected size of every nesting level
     * @param index scratch coordinate buffer; entries {@code [0, level)} locate {@code node}
     * @param level current nesting level (0 = root)
     * @throws IllegalArgumentException if {@code node} does not have {@code shape[level]} elements
     */
    private static void fill_from_json(
            INDArray target, JsonElement node, int[] shape, int[] index, int level) {
        int size = node.getAsJsonArray().size();
        if (size != shape[level]) {
            throw new IllegalArgumentException("Ragged JSON array at index prefix "
                    + Arrays.toString(Arrays.copyOf(index, level)) + ": expected "
                    + shape[level] + " elements but found " + size);
        }
        boolean leafLevel = level == shape.length - 1;
        for (int i = 0; i < size; i++) {
            index[level] = i;
            JsonElement child = node.getAsJsonArray().get(i);
            if (leafLevel) {
                target.putScalar(index, child.getAsDouble());
            } else {
                fill_from_json(target, child, shape, index, level + 1);
            }
        }
    }

    /**
     * Infers the array shape by descending through the first element of each level.
     *
     * @param json_array root JSON array
     * @param depth number of nesting levels to measure (must be positive)
     * @return the {@code depth} dimension sizes, outermost first
     * @throws IllegalArgumentException if a level other than the innermost one is empty, since
     *     the sizes of the levels below it cannot then be inferred
     */
    private static int[] get_shape(JsonElement json_array, int depth) {
        int[] shape = new int[depth];
        JsonElement node = json_array;
        for (int level = 0; level < depth; level++) {
            shape[level] = node.getAsJsonArray().size();
            if (level < depth - 1) {
                if (shape[level] == 0) {
                    throw new IllegalArgumentException("Cannot infer shape: empty JSON array at level "
                            + level + " of a depth-" + depth + " input");
                }
                node = node.getAsJsonArray().get(0);
            }
        }
        return shape;
    }
}