Spring Data通常采用Page作为分页对象、以@PageableDefault Pageable作为入参,通过JPA查询即可自动实现分页;但在某些缓存场景中,数据会直接以List形式查询出来,需要手动排序、分页并封装成Page,于是写了个小工具方便处理。
package com.leo.boot.jpa.stream;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.IterableUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Sort.Direction;
import org.springframework.data.domain.Sort.Order;
import org.springframework.util.ReflectionUtils;
import java.beans.PropertyDescriptor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import static java.util.Comparator.comparing;
import static java.util.Comparator.nullsFirst;
import static java.util.Comparator.nullsLast;
import static java.util.stream.Collectors.toList;
public class PageUtil {
public static <T> Page<T> getPage(List<T> list, Pageable pageable) {
if (CollectionUtils.isEmpty(list)) {
return new PageImpl<>(Collections.emptyList(), pageable, 0);
}
@SuppressWarnings("unchecked")
Class<T> clazz = (Class<T>) list.get(0).getClass();
Stream<T> stream = list.stream();
if (!IterableUtils.isEmpty(pageable.getSort())) {
stream = stream.sorted(getComparator(pageable.getSort(), clazz));
}
List<T> slice = stream.skip((long) pageable.getPageNumber() * pageable.getPageSize())
.limit(pageable.getPageSize())
.collect(toList());
return new PageImpl<>(slice, pageable, list.size());
}
@SuppressWarnings("unchecked")
public static <T> Comparator<T> getComparator(Sort sort, Class<T> clazz) {
return StreamSupport.stream(sort.spliterator(), false)
.map(order -> getComparator(order, clazz))
.reduce(Comparator::thenComparing)
.orElse((Comparator<T>) Comparator.naturalOrder());
}
@SuppressWarnings({"unchecked", "rawtypes"})
public static <T> Comparator<T> getComparator(Order order, Class<T> clazz) {
PropertyDescriptor propertyDescriptor = BeanUtils.getPropertyDescriptor(clazz, order.getProperty());
if (propertyDescriptor == null || propertyDescriptor.getReadMethod() == null) {
throw new IllegalArgumentException(String.format("can't find read method for %s in %s", order.getProperty(), clazz.getName()));
}
Method readMethod = propertyDescriptor.getReadMethod();
Comparator<T> comparator = nullsLast(comparing(f -> {
try {
return (Comparable) readMethod.invoke(f);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new IllegalArgumentException(e);
}
}, nullsLast(Comparator.naturalOrder())));
return Direction.ASC.equals(order.getDirection()) ? comparator : comparator.reversed();
}
}
TestCase
package com.leo.boot.jpa.stream;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Tests for {@code PageUtil}: paging, multi-property sorting with null values and null
 * elements, and rejection of unknown sort properties.
 */
class PageUtilTest {

    List<User> users;

    /** Builds a fully populated user; keeps the fixture in {@link #setUp()} compact. */
    private static User user(int id, String name, Gender gender) {
        return User.builder().id(id).name(name).gender(gender).build();
    }

    @BeforeEach
    void setUp() {
        // Deliberately unordered. It includes a user with no fields set and a null element, to
        // exercise null handling.
        users = Arrays.asList(
                new User(),
                user(1, "张三", Gender.MALE),
                user(3, "王五", Gender.MALE),
                user(5, "田七", Gender.MALE),
                user(2, "李四", Gender.FEMALE),
                user(4, "赵六", Gender.FEMALE),
                null);
    }

    @Test
    void getPage() {
        // An unsorted request larger than the list fits everything on a single page.
        Pageable unsortedRequest = PageRequest.of(0, 10);
        Page<User> wholeList = PageUtil.getPage(users, unsortedRequest);
        Assertions.assertEquals(7, wholeList.getTotalElements());
        Assertions.assertEquals(7, wholeList.getNumberOfElements());
        Assertions.assertEquals(1, wholeList.getTotalPages());

        // Sorted by gender then id, page 1 of size 3: [1, 3, 5] | [2, 4, empty] | [null].
        Pageable sortedRequest = PageRequest.of(1, 3, Sort.Direction.ASC, "gender", "id");
        Page<User> secondPage = PageUtil.getPage(users, sortedRequest);
        Assertions.assertEquals(7, secondPage.getTotalElements());
        Assertions.assertEquals(3, secondPage.getNumberOfElements());
        Assertions.assertEquals(3, secondPage.getTotalPages());
        List<User> content = secondPage.getContent();
        Assertions.assertEquals(2, content.get(0).getId());
        Assertions.assertEquals(4, content.get(1).getId());
    }

    @Test
    void getComparator() {
        Sort genderThenIdDesc = Sort.by(Sort.Direction.ASC, "gender").and(Sort.by(Sort.Direction.DESC, "id"));
        List<User> sorted = users.stream()
                .sorted(PageUtil.getComparator(genderThenIdDesc, User.class))
                .collect(Collectors.toList());
        // MALE precedes FEMALE (enum declaration order); ids descend within each gender.
        List<Integer> leadingIds = sorted.subList(0, 5).stream()
                .map(User::getId)
                .collect(Collectors.toList());
        Assertions.assertEquals(Arrays.asList(5, 3, 1, 4, 2), leadingIds);
        // The user with null fields follows the populated ones; the null element comes last.
        Assertions.assertNull(sorted.get(5).getId());
        Assertions.assertNull(sorted.get(6));
    }

    @Test
    void illegalArgument() {
        // "illegalArgument" is not a property of User, so building the comparator must fail.
        PageRequest request = PageRequest.of(1, 3, Sort.Direction.ASC, "id", "illegalArgument");
        Assertions.assertThrows(IllegalArgumentException.class, () -> PageUtil.getPage(users, request));
    }
}
/**
 * Test fixture bean that {@code PageUtil} sorts through its Lombok-generated getters.
 *
 * <p>Field order matters: {@code @AllArgsConstructor} takes its arguments in declaration order.
 */
@AllArgsConstructor
@NoArgsConstructor
@Builder
@Data
class User {
// Boxed Integer: a user built with the no-args constructor has a null id, which the tests check.
private Integer id;
private Gender gender;
private String name;
}
/**
 * Gender of a test {@code User}.
 *
 * <p>Declaration order is significant. Enums compare by it, so {@code MALE} sorts before
 * {@code FEMALE}.
 */
enum Gender {
    MALE,
    FEMALE
}