在第一篇文章的后续部分,这次我们将编写一些更有用的自定义收集器:用于按给定的标准进行分组,采样输入,批处理和在固定大小的窗口上滑动。
分组(计数事件,直方图)
假设您有一些项目的集合,并且想要计算每个项目(相对于equals()
)出现在此集合中的次数。 这可以使用Apache Commons Collections的CollectionUtils.getCardinalityMap()
来实现。 此方法采用Iterable<T>
并返回Map<T, Integer>
,计算每个项目出现在集合中的次数。 但是,有时我们不使用equals()
而是按输入T
的任意属性分组。 例如,说我们有一个Person
对象列表,我们想计算男性与女性的数量(即Map<Sex, Integer>
)或年龄分布。 有一个内置的收集器Collectors.groupingBy(Function<T, K> classifier)
–但是,它从键返回一个映射到映射到该键的所有项。 看到:
import static java.util.stream.Collectors.groupingBy;
//...
final List<Person> people = //...
final Map<Sex, List<Person>> bySex = people
.stream()
.collect(groupingBy(Person::getSex));
这很有价值,但是在我们的案例中,不必要地构建了两个List<Person>
。 我只想知道人数。 没有内置的这种收集器,但是我们可以用相当简单的方式来组成它:
import static java.util.stream.Collectors.counting;
import static java.util.stream.Collectors.groupingBy;
//...
final Map<Sex, Long> bySex = people
.stream()
.collect(
groupingBy(Person::getSex, HashMap::new, counting()));
这个重载版本的groupingBy()
具有三个参数。 如前所述,第一个是键( 分类器 )功能。 第二个参数创建了一个新地图,我们很快就会看到它为什么有用的原因。 counting()
是一个嵌套的收集器,它将所有具有相同性别的人合并在一起-在我们的例子中,当他们到达时就简单地对其计数。 能够选择地图的实现方式非常有用,例如在构建年龄直方图时。 我们想知道在给定年龄下有多少人-但年龄值应排序:
final TreeMap<Integer, Long> byAge = people
.stream()
.collect(
groupingBy(Person::getAge, TreeMap::new, counting()));
byAge
.forEach((age, count) ->
System.out.println(age + ":\t" + count));
我们最终从年龄(排序)开始,使用TreeMap
来计算具有该年龄的人数。
采样,批处理和滑动窗口
Scala中的IterableLike.sliding()
方法允许通过固定大小的滑动窗口查看集合。 该窗口从开始处开始,并且在每次迭代中移动给定数量的项目。 Java 8中缺少的这种功能允许使用多种有用的运算符,例如计算移动平均值 ,将大集合分成批处理(与Guava中的Lists.partition()
比较)或对每个第n个元素进行采样。 我们将为Java 8实现具有类似行为的收集器。 让我们从单元测试开始,它应该简要描述我们想要实现的目标:
import static com.nurkiewicz.CustomCollectors.sliding
@Unroll
class CustomCollectorsSpec extends Specification {
def "Sliding window of #input with size #size and step of 1 is #output"() {
expect:
input.stream().collect(sliding(size)) == output
where:
input | size | output
[] | 5 | []
[1] | 1 | [[1]]
[1, 2] | 1 | [[1], [2]]
[1, 2] | 2 | [[1, 2]]
[1, 2] | 3 | [[1, 2]]
1..3 | 3 | [[1, 2, 3]]
1..4 | 2 | [[1, 2], [2, 3], [3, 4]]
1..4 | 3 | [[1, 2, 3], [2, 3, 4]]
1..7 | 3 | [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7]]
1..7 | 6 | [1..6, 2..7]
}
def "Sliding window of #input with size #size and no overlapping is #output"() {
expect:
input.stream().collect(sliding(size, size)) == output
where:
input | size | output
[] | 5 | []
1..3 | 2 | [[1, 2], [3]]
1..4 | 4 | [1..4]
1..4 | 5 | [1..4]
1..7 | 3 | [1..3, 4..6, [7]]
1..6 | 2 | [[1, 2], [3, 4], [5, 6]]
}
def "Sliding window of #input with size #size and some overlapping is #output"() {
expect:
input.stream().collect(sliding(size, 2)) == output
where:
input | size | output
[] | 5 | []
1..4 | 5 | [[1, 2, 3, 4]]
1..7 | 3 | [1..3, 3..5, 5..7]
1..6 | 4 | [1..4, 3..6]
1..9 | 4 | [1..4, 3..6, 5..8, 7..9]
1..10 | 4 | [1..4, 3..6, 5..8, 7..10]
1..11 | 4 | [1..4, 3..6, 5..8, 7..10, 9..11]
}
def "Sliding window of #input with size #size and gap of #gap is #output"() {
expect:
input.stream().collect(sliding(size, size + gap)) == output
where:
input | size | gap | output
[] | 5 | 1 | []
1..9 | 4 | 2 | [1..4, 7..9]
1..10 | 4 | 2 | [1..4, 7..10]
1..11 | 4 | 2 | [1..4, 7..10]
1..12 | 4 | 2 | [1..4, 7..10]
1..13 | 4 | 2 | [1..4, 7..10, [13]]
1..13 | 5 | 1 | [1..5, 7..11, [13]]
1..12 | 5 | 3 | [1..5, 9..12]
1..13 | 5 | 3 | [1..5, 9..13]
}
def "Sampling #input taking every #nth th element is #output"() {
expect:
input.stream().collect(sliding(1, nth)) == output
where:
input | nth | output
[] | 1 | []
[] | 5 | []
1..3 | 5 | [[1]]
1..6 | 2 | [[1], [3], [5]]
1..10 | 5 | [[1], [6]]
1..100 | 30 | [[1], [31], [61], [91]]
}
}
在Spock中使用数据驱动的测试,我成功地立即编写了近40个测试用例,简洁地描述了所有需求。 我希望这些对您来说都是清楚的,即使您以前从未看过这种语法。 我已经假定存在方便的工厂方法:
public class CustomCollectors {
public static <T> Collector<T, ?, List<List<T>>> sliding(int size) {
return new SlidingCollector<>(size, 1);
}
public static <T> Collector<T, ?, List<List<T>>> sliding(int size, int step) {
return new SlidingCollector<>(size, step);
}
}
收藏家接连收到物品的事实使工作变得更加困难。 当然,首先收集整个列表并在列表上滑动会比较容易,但是却很浪费。 让我们迭代构建结果。 我什至不假装通常可以并行执行此任务,因此我将使combiner()
未实现:
public class SlidingCollector<T> implements Collector<T, List<List<T>>, List<List<T>>> {
private final int size;
private final int step;
private final int window;
private final Queue<T> buffer = new ArrayDeque<>();
private int totalIn = 0;
public SlidingCollector(int size, int step) {
this.size = size;
this.step = step;
this.window = max(size, step);
}
@Override
public Supplier<List<List<T>>> supplier() {
return ArrayList::new;
}
@Override
public BiConsumer<List<List<T>>, T> accumulator() {
return (lists, t) -> {
buffer.offer(t);
++totalIn;
if (buffer.size() == window) {
dumpCurrent(lists);
shiftBy(step);
}
};
}
@Override
public Function<List<List<T>>, List<List<T>>> finisher() {
return lists -> {
if (!buffer.isEmpty()) {
final int totalOut = estimateTotalOut();
if (totalOut > lists.size()) {
dumpCurrent(lists);
}
}
return lists;
};
}
private int estimateTotalOut() {
return max(0, (totalIn + step - size - 1) / step) + 1;
}
private void dumpCurrent(List<List<T>> lists) {
final List<T> batch = buffer.stream().limit(size).collect(toList());
lists.add(batch);
}
private void shiftBy(int by) {
for (int i = 0; i < by; i++) {
buffer.remove();
}
}
@Override
public BinaryOperator<List<List<T>>> combiner() {
return (l1, l2) -> {
throw new UnsupportedOperationException("Combining not possible");
};
}
@Override
public Set<Characteristics> characteristics() {
return EnumSet.noneOf(Characteristics.class);
}
}
我花了很多时间来编写此实现,尤其是正确的finisher()
所以请不要害怕。 关键部分是一个buffer
,它可以收集项目,直到可以形成一个滑动窗口为止。 然后丢弃“最旧”的物品,并step
向前滑动窗口。 我对这种实现并不特别满意,但是测试正在通过。 sliding(N)
(与sliding(N, 1)
同义词)将允许计算N
项目的移动平均值。 sliding(N, N)
将输入分成大小为N
批次。 sliding(1, N)
获取第N个元素(样本)。 希望您会发现这个收藏家有用,喜欢!