这个代码参考了 CSDN 博客文章《分支限界_01背包问题_Java实现》（ljming的专栏，分支限界法01背包问题）。
Good.java
public class Good implements Comparable<Good> {
private int weight;
private int value;
private double unitValue;
public Good(int weight, int value) {
this.weight = weight;
this.value = value;
this.unitValue = (weight == 0) ? 0 : (double) value / weight;
}
public int getWeight() {
return weight;
}
public void setWeight(int weight) {
this.weight = weight;
}
@Override
public String toString() {
return "Good{" +
"weight=" + weight +
", value=" + value +
", unitValue=" + unitValue +
'}';
}
public int getValue() {
return value;
}
public void setValue(int value) {
this.value = value;
}
public double getUnitValue() {
return unitValue;
}
@Override
public int compareTo(Good snapsack) {
return Double.compare(unitValue, snapsack.getUnitValue());
}
}
Bfs.java
import java.util.*;
/**
 * Best-first branch-and-bound solver for the 0/1 knapsack problem.
 *
 * <p>Items are examined in order of decreasing value density. Each search node
 * fixes the take/skip decision for a prefix of that order and carries an upper
 * bound obtained from the fractional (greedy) relaxation of the remaining
 * items. The node with the largest bound is always expanded next, so the first
 * complete node taken from the queue is an optimal selection.
 *
 * <p>Weights, values and the capacity must be non-negative; the bound is not a
 * valid upper bound otherwise.
 */
public class Bfs {
    /** Private copy of the items, sorted by decreasing value density. */
    private final Good[] bags;
    /** Knapsack capacity. */
    private final int totalWeight;
    /** Number of items. */
    private final int n;
    /** Largest value of any feasible selection found so far. */
    private int bestValue;
    /** Take (1) / skip (0) decision per sorted item; {@code null} until {@link #solve()} ran. */
    private List<Integer> path;

    /**
     * Prepares a solver for the given items and capacity.
     *
     * <p>The caller's array is not modified; a sorted copy is kept internally.
     *
     * @param bags        items to choose from
     * @param totalWeight knapsack capacity
     * @throws IllegalArgumentException if the capacity, or any item's weight or
     *                                  value, is negative
     */
    public Bfs(Good[] bags, int totalWeight) {
        super();
        if (totalWeight < 0) {
            throw new IllegalArgumentException("totalWeight must not be negative: " + totalWeight);
        }
        Good[] sorted = bags.clone();
        for (Good good : sorted) {
            if (good.getWeight() < 0 || good.getValue() < 0) {
                throw new IllegalArgumentException("weight and value must not be negative: " + good);
            }
        }
        // Stable sort, highest density first. Zero-weight items with a positive
        // value rank first: if they sat behind an item that does not fit, the
        // greedy bound would miss them and stop being an upper bound.
        Arrays.sort(sorted, (a, b) -> Double.compare(density(b), density(a)));
        this.bags = sorted;
        this.totalWeight = totalWeight;
        this.n = sorted.length;
    }

    /**
     * Value per unit of weight, computed from the item's current weight and
     * value. A free item (weight 0) with positive value has infinite density;
     * a free, worthless item has density 0.
     */
    private static double density(Good good) {
        if (good.getWeight() == 0) {
            return good.getValue() > 0 ? Double.POSITIVE_INFINITY : 0;
        }
        return (double) good.getValue() / good.getWeight();
    }

    /**
     * Prints the best value followed by one line per item (in the solver's
     * sorted order) with its take (1) / skip (0) decision.
     *
     * @throws IllegalStateException if {@link #solve()} has not been run yet
     */
    public void printSolution() {
        if (path == null) {
            throw new IllegalStateException("solve() must be called before printSolution()");
        }
        System.out.println("Max value: " + bestValue);
        for (int i = 0; i < n; i++) {
            System.out.println(bags[i] + " " + path.get(i));
        }
    }

    /**
     * Runs the best-first search, stores the optimal selection and prints it
     * via {@link #printSolution()}.
     */
    public void solve() {
        bestValue = 0;
        // Max-heap on the upper bound; among equal bounds prefer the deeper node
        // so that complete solutions surface as early as possible.
        PriorityQueue<Node> frontier = new PriorityQueue<>((a, b) -> {
            int byBound = Double.compare(b.upperBound, a.upperBound);
            return byBound != 0 ? byBound : Integer.compare(b.index, a.index);
        });
        Node node = new Node(0, 0, 0, 0, new ArrayList<Integer>());
        while (node.index != n) {
            Good item = bags[node.index];
            int weightWithItem = node.currWeight + item.getWeight();
            // Left child: take the item, only if it still fits.
            if (weightWithItem <= totalWeight) {
                Node left = new Node(weightWithItem, node.currValue + item.getValue(),
                        node.index + 1, 1, node.path);
                bestValue = Math.max(left.currValue, bestValue);
                frontier.add(left);
            }
            // Right child: skip the item, unless its bound cannot reach the incumbent.
            Node right = new Node(node.currWeight, node.currValue, node.index + 1, 0, node.path);
            if (right.upperBound >= bestValue) {
                frontier.add(right);
            }
            // With non-negative input the frontier always holds a node whose value
            // equals bestValue (its bound is at least that), so poll() is non-null.
            node = frontier.poll();
        }
        path = node.path;
        printSolution();
    }

    /**
     * A partial solution: decisions for the first {@code index} sorted items.
     * Inner (non-static) because the bound reads the solver's items and capacity.
     */
    class Node {
        private final int currWeight;
        private final int currValue;
        private final double upperBound;
        private final int index;
        private final List<Integer> path;

        /**
         * @param currWeight total weight of the items taken so far
         * @param currValue  total value of the items taken so far
         * @param index      number of items already decided
         * @param side       decision that led here: 1 = taken, 0 = skipped
         * @param path       decisions of the parent node; copied, never shared
         */
        public Node(int currWeight, int currValue, int index, int side, List<Integer> path) {
            this.currWeight = currWeight;
            this.currValue = currValue;
            this.index = index;
            this.upperBound = computeUpperBound(currWeight, currValue, index);
            this.path = new ArrayList<>(path);
            // The root (index 0) has made no decision yet.
            if (index != 0) {
                this.path.add(side);
            }
        }

        /**
         * Fractional-relaxation bound: add whole items in density order while
         * they fit, then a fraction of the first one that does not.
         */
        private double computeUpperBound(int currWeight, int currValue, int index) {
            int remaining = totalWeight - currWeight;
            double bound = currValue;
            int i = index;
            while (i < n && bags[i].getWeight() <= remaining) {
                remaining -= bags[i].getWeight();
                bound += bags[i].getValue();
                i++;
            }
            if (i < n) {
                // bags[i] does not fit, so its weight is positive and its density finite.
                bound += density(bags[i]) * remaining;
            }
            return bound;
        }
    }
}
Main.java
/**
 * Demo driver: builds the sample items from the parallel {@code weight} and
 * {@code value} arrays and lets {@code Bfs} solve and print the knapsack
 * instance for capacity {@code totalWeight}.
 */
public class Main {
    /** Item weights; entry i belongs to the same item as {@code value[i]}. */
    public static int[] weight = {2, 1, 3, 2};
    /** Item values, parallel to {@code weight}. */
    public static int[] value = {12, 10, 20, 15};
    /** Number of items, taken from the length of {@code weight}. */
    public static int goodNum = weight.length;
    /** Knapsack capacity. */
    public static int totalWeight = 25;

    /**
     * Creates one {@code Good} per index and runs the solver.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Good[] items = new Good[goodNum];
        int idx = 0;
        while (idx < goodNum) {
            items[idx] = new Good(weight[idx], value[idx]);
            idx++;
        }
        new Bfs(items, totalWeight).solve();
    }
}