获得一个无向图的所有最小生成树 Java实现 附思路及代码详解

500块接的活
本来以为用克鲁斯卡尔或者普利姆可以轻易的做出来,结果从晚上九点干到凌晨四点没做出来,放弃了。虽然拒了这个活,但是第二天激起了好胜心,决定把这个题做出来。
先说思路:

第一版的思路是采用排列组合的方式,比如一个无向图里一共有6个点,11条边,边的权重分别为:11122222333,采用克鲁斯卡尔获得一个最小生成树的边的权值分别为:12223
那么根据克鲁斯卡尔的思路,得知:所有最小生成树的边的组合一定都是按照12223的权值组合的五条边。
那么1有 C 3 1 C_3^1 C31= 3种选择情况,2有 C 5 3 C_5^3 C53种选择情况,3有 C 3 1 C_3^1 C31种选择情况。但是需要注意的是,并不是所有的选择情况都是合理的,因此需要去掉这些不合理的情况。
对于原理而言很好理解,但是对于写代码而言,要在一段代码里同时实现选择和去掉不对的选择的功能。我没有想出来如何完美写出这种代码的办法。故放弃此法。

第二天想到了第二版的思路:
不用克鲁斯卡尔算法的思路,而采用普利姆算法的思路,普利姆每次从队列里选择一条权值最小的边,判断其是否合理,如果合理那么就将这条边放入(合理指:这条边的另一个端点尚未被visited过,为什么只考虑一个点呢?因为另一个点必然已经被访问过,不然这条边是怎么放在队列里的~)
如果不合理,那么就继续找到优先队列里权值最小的边。
于是思路就来了:我取出队列里的一条边后,不管其是否合理,都要继续考虑取出下一条边的情况。也就是说,我要获得所有的生成树!
那么获得所有生成树的代码的方法就是:递归回溯,也可以理解为dfs,即深度优先遍历的思想。
获得所有生成树后,判断它们的权值是否为最小生成树。(通过克鲁斯卡尔算法获得一个最小生成树的权值,不用普利姆是因为prim没有考虑到两条边同一权值的问题)
至此,获得了所有的最小生成树。
但是这里还有一个问题:题目给的是求无向图的所有最小生成树,但我们知道,无向图的代码都是通过有向图来实现的,也就是说,通过上述的第二版的思路,最后得到的最小生成树的答案的个数,一定是正确答案的两倍。
如:
1 -> 2, 2->3和2->1,3->2这两种答案,本质都是相同的一颗最小生成树。因此最后一步是对所有最小生成树去重。
我还没有想到比较好的去重方法,也希望大佬们能在评论区中指出。

下面为第二版思路的java代码具体实现:

首先是MSTAllHW类,它实现了要解决问题的接口,其中定义了add一个点的方法和addE一条边的方法,每次递归回溯时使用这两个方法来求所有的生成树。

public class MSTAllHW implements MSTAll {
    @Override
    public List<SpanningTree> getMinimumSpanningTrees(Graph graph) {
        PriorityQueue<Edge> queue = new PriorityQueue<>();
        SpanningTree tree = new SpanningTree();
        Set<Integer> visited = new HashSet<>();
        SpanningTree testTree = new MSTKruskal().getMinimumSpanningTree(graph);
        double minWeight = testTree.getTotalWeight();
        //****test*****
//        List<Edge> testEdges = testTree.getEdges();
//        HashSet<Integer> s = new HashSet();
//        for (Edge e : testEdges) {
//            s.add(e.getTarget());
//            s.add(e.getSource());
//            System.out.print(e.toString() + " ");
//        }
//        System.out.println();
//        System.out.println("----");
        //输出s里的所有东西
//        for(int a : s){
//            System.out.print(a+" ");
//        }
        //判断s的个数是否等于graph.size
//        System.out.println(graph.size()==s.size());
        //********
        System.out.println("weight:" + minWeight);

        System.out.println("------Start------");
        //list存放所有的生成树
        List<SpanningTree> list = new ArrayList<>();
        // add all connecting vertices from start vertex to the queue
        add(queue, visited, graph, 0);
        addE(queue, visited, tree, graph, list);
        //得到list后,要去掉3种情况:
        //1.totalWeight不对的
        //2.***边被重复用到的***
        //3.形成回路的(所有顶点没被用到的)
        ArrayList<String> arSum = new ArrayList<>();
        List<SpanningTree> resList = new ArrayList<>();
        for (SpanningTree sptree : list) {
            //1.判断总权重
            if (sptree.getTotalWeight() != minWeight)
                continue;
            //3.判断是否所有顶点都在
            HashSet<Integer> hs = new HashSet();
            StringBuilder sum = new StringBuilder();
            for (Edge e : sptree.getEdges()) {
                hs.add(e.getTarget());
                hs.add(e.getSource());
                sum.append(e.getTarget());
                sum.append(e.getSource());
            }
            if (hs.size() != graph.size())
                continue;
            //2.边被重复用到的 <==> 对所有点所对应的数求和,若和相同则说明已存在
            String str = new String(sum);
            boolean tt = false;
            for(String ss : arSum){
                if (ss.equals(str)){
                    tt=true;
                    break;
                }
            }
            if(tt)
                continue;
            else{
                arSum.add(str);
            }
            resList.add(sptree);
        }
        return resList;
    }

    private void addE(PriorityQueue<Edge> queue, Set<Integer> visited, SpanningTree tree, Graph graph, List<SpanningTree> list) {
        while (!queue.isEmpty()) {
            Edge[] edges = new Edge[graph.size() - 1];
//            edges[0] = queue.poll();
            //这里需要判断一下edges[0]的两个点是不是已经都用了,是的话需要放置新的edges[0]
            Edge p = queue.poll();
            int source = p.getSource();
            int target = p.getTarget();
            //如果source和target均存在于visited中,则取queue里的下一个数
            while (visited.contains(source) && visited.contains(target) && !queue.isEmpty()) {
                p = queue.poll();
                source = p.getSource();
                target = p.getTarget();
            }
            edges[0] = p;


            int i = 1;
            PriorityQueue<Edge> newQueue = new PriorityQueue<>(queue);
            PriorityQueue<Edge> linshiQ = new PriorityQueue<>(queue);
            while (!linshiQ.isEmpty()) {
                Edge poll = linshiQ.poll();
                double w = poll.getWeight();
                if (w == edges[0].getWeight())
                    edges[i++] = poll;
            }
            //edges里存了当前所有可以取的边的可能
            int j = 0;
            for (j = 0; j < i; j++) {
                //分别将可以取的边放进去,对所有情况分别讨论
                SpanningTree newTree = new SpanningTree(tree);
                Set<Integer> newVisited = new HashSet<>(visited);
                if (!visited.contains(edges[j].getSource())) {
                    newTree.addEdge(edges[j]);
                    // if a spanning tree is found, break.
                    if (newTree.size() + 1 == graph.size()) {
                        list.add(newTree);
                        continue;
                    }
                    // add all connecting vertices from current vertex to the queue
                    add(newQueue, newVisited, graph, edges[j].getSource());
                    addE(newQueue, newVisited, newTree, graph, list);
                    //System.out.println("为啥tree变成newTree了...");
                }
            }
        }
    }


    private void add(PriorityQueue<Edge> queue, Set<Integer> visited, Graph graph, int target) {
        visited.add(target);
        for (Edge edge : graph.getIncomingEdges(target)) {
            if (!visited.contains(edge.getSource()))
                queue.add(edge);
        }
    }
}

接口为:

public interface MSTAll {
    /**
     * @param graph an undirected graph containing zero to many spanning trees.
     * @return list of all minimum spanning trees.
     */
    public List<SpanningTree> getMinimumSpanningTrees(Graph graph);
}
其中,生成树的定义为:
public class SpanningTree implements Comparable<SpanningTree> {
    private final List<Edge> edges;
    private double total_weight;

    public SpanningTree() {
        edges = new ArrayList<>();
    }

    public SpanningTree(SpanningTree tree) {
        edges = new ArrayList<>(tree.getEdges());
        total_weight = tree.getTotalWeight();
    }

    public List<Edge> getEdges() {
        return edges;
    }

    public double getTotalWeight() {
        return total_weight;
    }

    public int size() {
        return edges.size();
    }

    public void addEdge(Edge edge) {
        edges.add(edge);
        total_weight += edge.getWeight();
    }

    @Override
    public int compareTo(SpanningTree tree) {
        double diff = total_weight - tree.total_weight;
        if (diff > 0) return 1;
        else if (diff < 0) return -1;
        else return 0;
    }

    @Override
    public String toString() {
        StringBuilder build = new StringBuilder();

        for (Edge edge : edges)
            build.append(String.format("\n%d <- %d : %f", edge.getTarget(), edge.getSource(), edge.getWeight()));

        return build.length() > 0 ? build.substring(1) : "";
    }

    public String getUndirectedSequence() {
        int i, size = size(), min, max;
        int[] array = new int[size];
        Edge edge;

        for (i = 0; i < size; i++) {
            edge = edges.get(i);

            if (edge.getSource() < edge.getTarget()) {
                min = edge.getSource();
                max = edge.getTarget();
            }
            else {
                min = edge.getTarget();
                max = edge.getSource();
            }

            array[i] = min * (size + 1) + max;
        }

        Arrays.sort(array);
        return Arrays.toString(array);
    }

    // ========================= For MSTEdmonds.java =========================

    public Set<Integer> getTargets() {
        Set<Integer> set = new HashSet<>();

        for (Edge edge : edges)
            set.add(edge.getTarget());

        return set;
    }

    public List<List<Edge>> getCycles() {
        Map<Integer, List<Edge>> edgeMap = getEdgeMap();
        List<List<Edge>> cycles = new ArrayList<>();
        getCyclesAux(cycles, edgeMap, new ArrayList<>(), new HashSet<>(), getAnyEdge(edgeMap));
        return cycles;
    }

    /** @return Map whose keys are source vertices and keys are the edges. */
    private Map<Integer, List<Edge>> getEdgeMap() {
        Map<Integer, List<Edge>> map = new HashMap<>();
        List<Edge> tmp;

        for (Edge edge : edges) {
            tmp = map.computeIfAbsent(edge.getSource(), k -> new ArrayList<>());
            tmp.add(edge);
        }

        return map;
    }

    private Edge getAnyEdge(Map<Integer, List<Edge>> map) {
        for (List<Edge> edge : map.values())
            return edge.get(0);
        return null;
    }

    private void getCyclesAux(List<List<Edge>> cycles, Map<Integer, List<Edge>> edgeMap, List<Edge> cycle, Set<Integer> set, Edge curr) {
        if (edgeMap.isEmpty()) return;
        set.add(curr.getSource());
        cycle.add(curr);

        if (set.contains(curr.getTarget()))        // cycle
        {
            removeAll(edgeMap, set, cycle);
            cycles.add(cycle);
            getCyclesAux(cycles, edgeMap, new ArrayList<>(), new HashSet<>(), getAnyEdge(edgeMap));
        }
        else {
            List<Edge> tmp = edgeMap.get(curr.getTarget());

            if (tmp == null) {
                removeAll(edgeMap, set, cycle);
                getCyclesAux(cycles, edgeMap, new ArrayList<>(), new HashSet<>(), getAnyEdge(edgeMap));
            }
            else {
                for (Edge edge : new ArrayList<>(tmp))
                    getCyclesAux(cycles, edgeMap, new ArrayList<>(cycle), new HashSet<>(set), edge);
            }
        }
    }

    private void removeAll(Map<Integer, List<Edge>> map, Set<Integer> set, List<Edge> cycle) {
        List<Edge> tmp;

        for (int source : set) {
            tmp = map.get(source);

            if (tmp != null) {
                tmp.removeAll(cycle);
                if (tmp.isEmpty()) map.remove(source);
            }
        }
    }
}
使用克鲁斯卡尔算法求一颗最小生成树的代码(使用不同的方法求一颗最小生成树的结果一般都不相同)
public class MSTKruskal implements MST {
    @Override
    public SpanningTree getMinimumSpanningTree(Graph graph) {
        PriorityQueue<Edge> queue = new PriorityQueue<>(graph.getAllEdges());
        DisjointSet forest = new DisjointSet(graph.size());
        SpanningTree tree = new SpanningTree();

        while (!queue.isEmpty()) {
            Edge edge = queue.poll();

            if (!forest.inSameSet(edge.getTarget(), edge.getSource())) {
                tree.addEdge(edge);

                // a spanning tree is found
                if (tree.size() + 1 == graph.size()) break;
                // merge forests
                forest.union(edge.getTarget(), edge.getSource());
            }
        }

        return tree;
    }
}
图的定义:
public class Graph {
    /**
     * A list of edge lists where each dimension of the outer list indicates a target vertex and
     * the inner list corresponds to the list of incoming edges to that target vertex.
     */
    private final List<List<Edge>> incoming_edges;

    public Graph(int size) {
        incoming_edges = Stream.generate(ArrayList<Edge>::new).limit(size).collect(Collectors.toList());
    }

    public Graph(Graph g) {
        incoming_edges = g.incoming_edges.stream().map(ArrayList::new).collect(Collectors.toList());
    }

    public int size() {
        return incoming_edges.size();
    }

    public List<Edge> getIncomingEdges(int target) {
        return incoming_edges.get(target);
    }

    public List<Edge> getAllEdges() {
        return incoming_edges.stream().flatMap(List::stream).collect(Collectors.toList());
    }

    public Deque<Integer> getVerticesWithNoIncomingEdges() {
        return IntStream.range(0, size()).filter(i -> getIncomingEdges(i).isEmpty()).boxed().collect(Collectors.toCollection(ArrayDeque::new));
    }

    /**
     * @return a list of edge deque where each dimension in the outer list represents the deque of outgoing edges for
     * the corresponding source vertex.
     */
    public List<Deque<Edge>> getOutgoingEdges() {
        List<Deque<Edge>> outgoing_edges = Stream.generate(ArrayDeque<Edge>::new).limit(size()).collect(Collectors.toList());

        for (int target = 0; target < size(); target++) {
            for (Edge incoming_edge : getIncomingEdges(target))
                outgoing_edges.get(incoming_edge.getSource()).add(incoming_edge);
        }

        return outgoing_edges;
    }

    public Edge setDirectedEdge(int source, int target, double weight) {
        List<Edge> edges = getIncomingEdges(target);
        Edge edge = new Edge(source, target, weight);
        edges.add(edge);
        return edge;
    }

    public void setUndirectedEdge(int source, int target, double weight) {
        setDirectedEdge(source, target, weight);
        setDirectedEdge(target, source, weight);
    }

    public boolean containsCycle() {
        Deque<Integer> notVisited = IntStream.range(0, size()).boxed().collect(Collectors.toCollection(ArrayDeque::new));

        while (!notVisited.isEmpty()) {
            if (containsCycleAux(notVisited.poll(), notVisited, new HashSet<>()))
                return true;
        }

        return false;
    }

    private boolean containsCycleAux(int target, Deque<Integer> notVisited, Set<Integer> visited) {
        notVisited.remove(target);
        visited.add(target);

        for (Edge edge : getIncomingEdges(target)) {
            if (visited.contains(edge.getSource()))
                return true;

            if (containsCycleAux(edge.getSource(), notVisited, new HashSet<>(visited)))
                return true;
        }

        return false;
    }

    public List<Integer> topological_sort(boolean depth_first) {
        Deque<Integer> global = getVerticesWithNoIncomingEdges();
        List<Deque<Edge>> outgoingEdgesAll = getOutgoingEdges();
        List<Integer> order = new ArrayList<>();

        while (!global.isEmpty()) {
            Deque<Integer> local = new ArrayDeque<>();

            // add vertex to the path
            int vertex = global.poll();
            order.add(vertex);
            Deque<Edge> outgoingEdges = outgoingEdgesAll.get(vertex);

            while (!outgoingEdges.isEmpty()) {
                Edge edge = outgoingEdges.poll();

                // remove one outgoing edge at a time
                List<Edge> incomingEdges = getIncomingEdges(edge.getTarget());
                incomingEdges.remove(edge);

                // if the target vertex has no incoming edges, add it to the local queue awaited to be added to the global deque
                if (incomingEdges.isEmpty())
                    local.add(edge.getTarget());
            }

            //Transfer all vertices in local to global
            while (!local.isEmpty()) {
                if (depth_first) global.addFirst(local.removeLast());
                else global.addLast(local.removeFirst());
            }
        }

        if (!hasNoEdge()) throw new IllegalArgumentException("Cyclic graph.");
        return order;
    }

    public boolean hasNoEdge() {
        return IntStream.range(0, size()).allMatch(i -> getIncomingEdges(i).isEmpty());
    }

    public String toString() {
        StringBuilder build = new StringBuilder();

        for (int i = 0; i < incoming_edges.size(); i++) {
            build.append(i);
            build.append(" <- ");
            build.append(incoming_edges.get(i).toString());
            build.append("\n");
        }

        return build.toString();
    }
}

边的定义:
public class Edge implements Comparable<Edge> {
    private int source;
    private int target;
    private double weight;

    public Edge(int source, int target, double weight) {
        init(source, target, weight);
    }

    public Edge(int source, int target) {
        this(source, target, 0);
    }

    public Edge(Edge edge) {
        this(edge.getSource(), edge.getTarget(), edge.getWeight());
    }

    private void init(int source, int target, double weight) {
        setSource(source);
        setTarget(target);
        setWeight(weight);
    }

    public int getSource() {
        return source;
    }

    public int getTarget() {
        return target;
    }

    public double getWeight() {
        return weight;
    }

    public void setSource(int vertex) {
        source = vertex;
    }

    public void setTarget(int vertex) {
        target = vertex;
    }

    public void setWeight(double weight) {
        this.weight = weight;
    }

    public void addWeight(double weight) {
        this.weight += weight;
    }

    @Override
    public int compareTo(Edge edge) {
        double diff = weight - edge.weight;
        if (diff > 0) return 1;
        else if (diff < 0) return -1;
        else return 0;
    }

    public String toString() {
        return String.format("%d <- %d : %f", getTarget(), getSource(), getWeight());
    }
}

可以看到这些定义里面,使用了优先队列和实现了Comparable接口,这都是很好的编程技巧。

好啦,让我们来跑一个案例吧:

图的结构如下:

public static void main(String[] args) {
        MSTAllHW test = new MSTAllHW();
        Graph graph = new Graph(4);
        graph.setUndirectedEdge(0, 1, 1);
        graph.setUndirectedEdge(0, 2, 1);
        graph.setUndirectedEdge(0, 3, 1);
        graph.setUndirectedEdge(1, 2, 2);
        graph.setUndirectedEdge(1, 3, 1);
        graph.setUndirectedEdge(2, 3, 1);

        List<SpanningTree> minimumSpanningTreesList = test.getMinimumSpanningTrees(graph);
        int count = 0;
        for (SpanningTree tree : minimumSpanningTreesList) {
            System.out.println("Total weight = " + tree.getTotalWeight() + " Count:" + (++count));
            System.out.println(tree.toString());
            System.out.println("---next---");
        }
    }

运行结果为:
在这里插入图片描述

在这里插入图片描述

这道题的正确答案应该一共八个不同的最小生成树,但是输出了九个,原因是我最后去重没去好(第六个和第九个是同一颗最小生成树),如果你有什么好的去重办法,还请留言。

  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
以下是一个使用邻接矩阵存储无向图实现Prim算法的C++代码: ```c++ #include <iostream> #include <climits> using namespace std; #define MAX_SIZE 100 class Graph { private: int num_vertex; // 图顶点的个数 int adj_matrix[MAX_SIZE][MAX_SIZE]; // 邻接矩阵 int parent[MAX_SIZE]; // 最小生成树中每个节点的父节点 int key[MAX_SIZE]; // 用于Prim算法的关键字数组 bool visited[MAX_SIZE]; // 节点是否已经被访问 public: Graph(int num_vertex) { this->num_vertex = num_vertex; for (int i = 0; i < num_vertex; i++) { for (int j = 0; j < num_vertex; j++) { adj_matrix[i][j] = 0; } visited[i] = false; key[i] = INT_MAX; } } void add_edge(int i, int j, int weight) { adj_matrix[i][j] = weight; adj_matrix[j][i] = weight; } void prim() { key[0] = 0; parent[0] = -1; for (int i = 0; i < num_vertex - 1; i++) { // 找到未被访问的关键字最小的节点 int min_key = INT_MAX; int min_index = -1; for (int j = 0; j < num_vertex; j++) { if (!visited[j] && key[j] < min_key) { min_key = key[j]; min_index = j; } } visited[min_index] = true; // 更新相邻节点的关键字和父节点 for (int j = 0; j < num_vertex; j++) { if (adj_matrix[min_index][j] && !visited[j] && adj_matrix[min_index][j] < key[j]) { key[j] = adj_matrix[min_index][j]; parent[j] = min_index; } } } } void print_tree() { cout << "Edge\tWeight" << endl; for (int i = 1; i < num_vertex; i++) { cout << parent[i] << " - " << i << "\t" << adj_matrix[i][parent[i]] << endl; } } }; int main() { Graph g(7); g.add_edge(0, 1, 7); g.add_edge(0, 3, 5); g.add_edge(1, 2, 8); g.add_edge(1, 3, 9); g.add_edge(1, 4, 7); g.add_edge(2, 4, 5); g.add_edge(3, 4, 15); g.add_edge(3, 5, 6); g.add_edge(4, 5, 8); g.add_edge(4, 6, 9); g.add_edge(5, 6, 11); g.prim(); g.print_tree(); return 0; } ``` 以上代码中,我们在Graph类中增加了三个成员变量:parent、key和visited,分别用于存储最小生成树中每个节点的父节点、Prim算法的关键字数组和节点是否已经被访问。在prim函数中,我们首先将节点0标记为已访问,并将其关键字设为0,然后依次找到未被访问的关键字最小的节点,并将其标记为已访问。接着,我们更新未被访问的相邻节点的关键字和父节点。 最后,我们在print_tree函数中输出最小生成树的边和权重。在main函数中,我们创建了一个Graph对象,并向其中添加了一些边,然后调用prim函数求解最小生成树,并调用print_tree函数输出结果。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值