先来讲述一下commons-fileupload实现上传的流程。
1.添加依赖。
<!-- commons 文件操作 -->
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.5</version>
</dependency>
<!-- FILE UPLOAD begin -->
<dependency>
<groupId>commons-fileupload</groupId>
<artifactId>commons-fileupload</artifactId>
<version>1.3.1</version>
</dependency>
1.创建一个文件上传解析器
我们知道上传有一些限制条件,比如上传文件大小,显示的上传进度等,如果是多文件上传,还有总文件大小的限制条件。
由于是个组件,这些参数都需要开发者自己定义,因此,在创建文件上传解析器时,我才用了builder模式。
public class UploadBuilder {
public ServletFileUpload upload = new ServletFileUpload();
public UploadBuilder(Builder builder){
upload.setSizeMax(builder.sizeMax);
upload.setFileSizeMax(builder.fileSizeMax);
upload.setHeaderEncoding(builder.headerEncoding);
upload.setFileItemFactory(builder.factory);
upload.setProgressListener(builder.listener);
}
public static final class Builder {
//总文件大小上限,默认无限大
private long sizeMax = -1;
//单个文件的大小上限,默认无限大
private long fileSizeMax = -1;
//编码方式
private String headerEncoding;
private DiskFileItemFactory factory;
//进度条Listener
private UploadProgressListener listener;
public Builder(DiskFileItemFactory factory) {
this.factory = factory;
}
public Builder sizeMax(long value){
this.sizeMax = value;
return this;
}
public Builder fileSizeMax(long value){
this.fileSizeMax = value;
return this;
}
public Builder headerEncoding(String value){
this.headerEncoding = value;
return this;
}
public Builder listener(UploadProgressListener listener){
this.listener = listener;
return this;
}
public ServletFileUpload build() {
return new UploadBuilder(this).upload;
}
}
}
由于需求要展示上传文件的进度条,因此在这里设置了进度条Listener
public class UploadProgressListener implements ProgressListener {
private HttpSession session;
public UploadProgressListener(HttpServletRequest request) {
this.session = request.getSession();
}
@Override
public void update(long pBytesRead, long pContentLength, int pItems) {
/**
* 进度条信息,自行定义
*/
//进度条回写到session中
session.setAttribute("processInfo", info);
}
}
2.使用ServletFileUpload解析器解析上传数据
public default ParseUploadInfoModel parseFileItems(ServletFileUpload upload,HttpServletRequest request){
List<FileItem> items = null;
if (!ServletFileUpload.isMultipartContent(request)) {
return new ParseUploadInfoModel(UploadFileCodeEnmu.NOT_MULTIPART_CONTENT,null);
}
try {
items = upload.parseRequest(request);
} catch (FileUploadBase.SizeLimitExceededException e){
return new ParseUploadInfoModel(UploadFileCodeEnmu.Q_EXCEED_SIZE_LIMIT,null);
} catch (FileUploadBase.FileSizeLimitExceededException e) {
return new ParseUploadInfoModel(UploadFileCodeEnmu.F_EXCEED_SIZE,null);
}catch (FileUploadException e) {
}
return new ParseUploadInfoModel(UploadFileCodeEnmu.SUCCESS,items);
}
这里获取到了解析数据,但是由于框架使用的是SpringBoot,其中spring mvc在处理请求的时候,有一个判断是否是上传请求,如果是上传请求,则进行解析转为MultipartHttpServletRequest请求。这样一来,我通过以上方式去获取请求参数时,就没法获取到了。源码分析如下:
springmvc判断是否是上传请求:
protected HttpServletRequest checkMultipart(HttpServletRequest request) throws MultipartException {
if (this.multipartResolver != null && this.multipartResolver.isMultipart(request)) {
if (WebUtils.getNativeRequest(request, MultipartHttpServletRequest.class) != null) {
logger.debug("Request is already a MultipartHttpServletRequest - if not in a forward, " +
"this typically results from an additional MultipartFilter in web.xml");
}
else if (hasMultipartException(request) ) {
logger.debug("Multipart resolution failed for current request before - " +
"skipping re-resolution for undisturbed error rendering");
}
else {
try {
return this.multipartResolver.resolveMultipart(request);
}
catch (MultipartException ex) {
if (request.getAttribute(WebUtils.ERROR_EXCEPTION_ATTRIBUTE) != null) {
logger.debug("Multipart resolution failed for error dispatch", ex);
// Keep processing error dispatch with regular request handle below
}
else {
throw ex;
}
}
}
}
// If not returned before: return original request.
return request;
}
主要通过this.multipartResolver.isMultipart(request)来判断的,这里MultipartResolver.class接口有两个实现类:CommonsMultipartResolver.class和StandardServletMultipartResolver.class。我们看看CommonsMultipartResolver.class的实现。
@Override
public boolean isMultipart(HttpServletRequest request) {
return ServletFileUpload.isMultipartContent(request);
}
ServletFileUpload.isMultipartContent(request);
public static final boolean isMultipartContent(
HttpServletRequest request) {
if (!POST_METHOD.equalsIgnoreCase(request.getMethod())) {
return false;
}
return FileUploadBase.isMultipartContent(new ServletRequestContext(request));
}
public static final boolean isMultipartContent(RequestContext ctx) {
String contentType = ctx.getContentType();
if (contentType == null) {
return false;
}
if (contentType.toLowerCase(Locale.ENGLISH).startsWith(MULTIPART)) {
return true;
}
return false;
}
其实就是根据两个条件:1.一个是请求方式是POST。2.请求的contentType是multipart/,这个是前端页面定义上传的时候就必须定义好的。
那我们这里通过这种方式判断是否是上传请求,如果是上传请求,就进行解析并封装成MultipartHttpServletRequest.class。
@Override
public MultipartHttpServletRequest resolveMultipart(final HttpServletRequest request) throws MultipartException {
Assert.notNull(request, "Request must not be null");
if (this.resolveLazily) {
return new DefaultMultipartHttpServletRequest(request) {
@Override
protected void initializeMultipart() {
MultipartParsingResult parsingResult = parseRequest(request);
setMultipartFiles(parsingResult.getMultipartFiles());
setMultipartParameters(parsingResult.getMultipartParameters());
setMultipartParameterContentTypes(parsingResult.getMultipartParameterContentTypes());
}
};
}
else {
MultipartParsingResult parsingResult = parseRequest(request);
return new DefaultMultipartHttpServletRequest(request, parsingResult.getMultipartFiles(),
parsingResult.getMultipartParameters(), parsingResult.getMultipartParameterContentTypes());
}
}
protected MultipartParsingResult parseRequest(HttpServletRequest request) throws MultipartException {
String encoding = determineEncoding(request);
FileUpload fileUpload = prepareFileUpload(encoding);
try {
List<FileItem> fileItems = ((ServletFileUpload) fileUpload).parseRequest(request);
return parseFileItems(fileItems, encoding);
}
catch (FileUploadBase.SizeLimitExceededException ex) {
throw new MaxUploadSizeExceededException(fileUpload.getSizeMax(), ex);
}
catch (FileUploadBase.FileSizeLimitExceededException ex) {
throw new MaxUploadSizeExceededException(fileUpload.getFileSizeMax(), ex);
}
catch (FileUploadException ex) {
throw new MultipartException("Failed to parse multipart servlet request", ex);
}
}
这样一来,确实没法通过commons-fileupload获取到上传信息了。但是,我们找到了事情的根源,就是如果这里判断是上传请求了,再解析的。如果我们判断不是上传请求,是不需要将请求转化并解析的。
解决办法
这里定义一个CommonsMultipartResolverExtension 继承CommonsMultipartResolver ,然后注入到spingmvc容器中,让springmvc在处理请求的时候,走我自己这个定义的解析器。这样就可以避开判断提前将请求进行解析了。
@Configuration
public class CommonsMultipartResolverExtension extends CommonsMultipartResolver {
@Override
public boolean isMultipart(HttpServletRequest request) {
if (request.getRequestURI().contains("//file/upload")) {
return false;
}
return super.isMultipart(request);
}
@Override
public MultipartHttpServletRequest resolveMultipart(final HttpServletRequest request){
return super.resolveMultipart(request);
}
}
注入到容器中:
public class WebMvcConfig implements WebMvcConfigurer {
@Bean(name = "multipartResolver")
public MultipartResolver multipartResolver() {
CommonsMultipartResolver multipartResolver = new CommonsMultipartResolverExtension();
return multipartResolver;
}
}
这样我们就解决了springmvc提前解析了上传请求的问题。
3.开始真正意义的上传了。
这里,我们要知道,上传文件到指定地方,这个指定地方,有可能是本地某个文件夹下,也有可能是fastDFS,或者其他服务器地址上。既然是组件,那么就需要具有很好的扩展性。为了使这个组件具有很好的扩展性,我才用SPI机制来实现。
定义一个扩展点加载器
/**
* 扩展点加载器,扩展点的查找,校验,加载等核心逻辑的实现类
*/
public class ExtensionLoader<T> {
// =====================静态变量或常量============================
private static final String SERVICES_DIRECTORY = "META-INF/services/";
private static final Pattern NAME_SEPARATOR = Pattern.compile("\\s*[,]+\\s*");
private static final ConcurrentMap<Class<?>, ExtensionLoader<?>> EXTENSION_LOADERS = new ConcurrentHashMap<Class<?>, ExtensionLoader<?>>();
// 实例化缓存
private static final ConcurrentMap<Class<?>, Object> EXTENSION_INSTANCES = new ConcurrentHashMap<Class<?>, Object>();
// =======================ExtensionLoader的属性==========================
private final Class<?> type;
private final ExtensionFactory objectFactory;
// ========================常用工具========================
private final ConcurrentMap<Class<?>, String> cachedNames = new ConcurrentHashMap<Class<?>, String>();
private final Holder<Map<String, Class<?>>> cachedClasses = new Holder<Map<String, Class<?>>>();
// 实例缓存
private final ConcurrentMap<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<String, Holder<Object>>();
// 私有构造器,单例模式设计
private ExtensionLoader(Class<?> type) {
this.type = type;
// 需要创建扩展点实例
this.objectFactory = (type == ExtensionFactory.class ? null : ExtensionLoader.getExtensionLoader(ExtensionFactory.class).getExtension());
}
/**
* 获取ExtensionFactory
* @return
*/
@SuppressWarnings("unchecked")
public T getExtension() {
return (T) new SpringExtensionFactory();
}
/**
*
* @param type
* 必须是接口,而且必须有@SPI注解
* @return
*/
@SuppressWarnings("unchecked")
public static <T> ExtensionLoader<T> getExtensionLoader(Class<T> type) {
if (type == null)
throw new IllegalArgumentException("Extension type == null");
if (!type.isInterface()) {
throw new IllegalArgumentException("Extension type(" + type + ") is not interface!");
}
// 只处理带有@SPI的接口
if (!withExtensionAnnotation(type)) {
throw new IllegalArgumentException("Extension type(" + type
+ ") is not extension, because WITHOUT @"
+ SPI.class.getSimpleName() + " Annotation!");
}
// 先从静态缓存中取,缓存没有则创建,然后放入缓存,从缓存中获取返回
ExtensionLoader<T> loader = (ExtensionLoader<T>) EXTENSION_LOADERS.get(type);
if (loader == null) {
EXTENSION_LOADERS.putIfAbsent(type, new ExtensionLoader<T>(type));
loader = (ExtensionLoader<T>) EXTENSION_LOADERS.get(type);
}
return loader;
}
/**
* 判断class是否含有@SPI
*
* @param type
* @return
*/
private static <T> boolean withExtensionAnnotation(Class<T> type) {
return type.isAnnotationPresent(SPI.class);
}
/**
* 根据名称获取Extension
*
* @param name
* @return
*/
@SuppressWarnings("unchecked")
public T getExtension(String name) {
// 先在缓存中找,缓存中没有,再创建extension.
if (name == null || name.length() == 0)
throw new IllegalArgumentException("Extension name == null");
Holder<Object> holder = cachedInstances.get(name);
if (holder == null) {
cachedInstances.putIfAbsent(name, new Holder<Object>());
holder = cachedInstances.get(name);
}
Object instance = holder.get();
if (instance == null) {
synchronized (holder) {
instance = holder.get();
instance = createExtension(name);
holder.set(instance);
}
}
return (T) instance;
}
/**
* 根据名称创建Extension
*
* @param name
* @return
*/
@SuppressWarnings("unchecked")
private T createExtension(String name) {
//加载所需的class,然后进行实例化
Map<String, Class<?>> maps = getExtensionClasses();
Class<?> clazz = maps.get(name);
if (clazz == null) {
throw new IllegalStateException();
}
try {
T instance = (T) EXTENSION_INSTANCES.get(clazz);
if (instance == null) {
EXTENSION_INSTANCES.putIfAbsent(clazz, (T) clazz.newInstance());
instance = (T) EXTENSION_INSTANCES.get(clazz);
}
injectExtension(instance);//注入依赖
return instance;
} catch (Throwable t) {
throw new IllegalStateException("Extension instance(name: " + name + ", class: " +
type + ") could not be instantiated: " + t.getMessage(), t);
}
}
/**
* 注入依赖,循环依赖的
* @param instance
* @return
*/
private T injectExtension(T instance) {
try {
if (objectFactory != null) {
for (Method method : instance.getClass().getMethods()) {
if (method.getName().startsWith("set")
&& method.getParameterTypes().length == 1
&& Modifier.isPublic(method.getModifiers())) {
Class<?> pt = method.getParameterTypes()[0];
try {
String property = method.getName().length() > 3 ? method.getName().substring(3, 4).toLowerCase() + method.getName().substring(4) : "";
Object object = objectFactory.getExtension(pt, property);
if (object != null) {
method.invoke(instance, object);
}
} catch (Exception e) {
}
}
}
}
} catch (Exception e) {
}
return instance;
}
/**
* 加载class
*
* @return
*/
private Map<String, Class<?>> getExtensionClasses() {
Map<String, Class<?>> classes = cachedClasses.get();
if (classes == null) {
synchronized (cachedClasses) {
classes = cachedClasses.get();
if (classes == null) {
classes = loadExtensionClasses();
cachedClasses.set(classes);
}
}
}
return classes;
}
private Map<String, Class<?>> loadExtensionClasses() {
final SPI defaultAnnotation = type.getAnnotation(SPI.class);
if(defaultAnnotation != null) {
String value = defaultAnnotation.value();
if(value != null && (value = value.trim()).length() > 0) {
String[] names = NAME_SEPARATOR.split(value);
if(names.length > 1) {
throw new IllegalStateException("more than 1 default extension name on extension " + type.getName()
+ ": " + Arrays.toString(names));
}
}
}
Map<String, Class<?>> extensionClasses = new HashMap<String, Class<?>>();
loadFile(extensionClasses, SERVICES_DIRECTORY);
return extensionClasses;
}
/**
* 从指定目录加载文件
* @param extensionClasses
* @param dir
*/
private void loadFile(Map<String, Class<?>> extensionClasses, String dir) {
String fileName = dir + type.getName();
try {
Enumeration<java.net.URL> urls;
ClassLoader classLoader = ExtensionLoader.class.getClassLoader();
if (classLoader != null) {
urls = classLoader.getResources(fileName);
} else {
urls = ClassLoader.getSystemResources(fileName);
}
if (urls != null) {
while (urls.hasMoreElements()) {
java.net.URL url = urls.nextElement();
try {
BufferedReader reader = new BufferedReader(
new InputStreamReader(url.openStream(), "utf-8"));
try {
// 1. 逐行读取配置文件,提取出扩展名或扩展类路径
String line = null;
while ((line = reader.readLine()) != null) {
final int ci = line.indexOf('#');
if (ci >= 0)
line = line.substring(0, ci);
line = line.trim();
if (line.length() > 0) {
try {
String name = null;
int i = line.indexOf('=');
if (i > 0) {
name = line.substring(0, i).trim();
line = line.substring(i + 1).trim();
}
// 2. 利用Class.forName方法进行类加载
if (line.length() > 0) {
Class<?> clazz = Class.forName(line, true, classLoader);
clazz.getConstructor();
if (name == null || name.length() == 0) {
if (clazz.getSimpleName().length() > type.getSimpleName().length()
&& clazz.getSimpleName().endsWith(type.getSimpleName())) {
name = clazz.getSimpleName()
.substring(0,clazz.getSimpleName().length()- type.getSimpleName().length())
.toLowerCase();
} else {
throw new IllegalStateException(
"No such extension name for the class "
+ clazz.getName()
+ " in the config "
+ url);
}
}
String[] names = NAME_SEPARATOR.split(name);
if (names != null && names.length > 0) {
for (String n : names) {
if (!cachedNames.containsKey(clazz)) {
cachedNames.put(clazz,n);
}
Class<?> c = extensionClasses.get(n);
if (c == null) {
extensionClasses.put(n,clazz);
} else if (c != clazz) {
throw new IllegalStateException(
"Duplicate extension "
+ type.getName()
+ " name "
+ n
+ " on "
+ c.getName()
+ " and "
+ clazz.getName());
}
}
}
}
} catch (Throwable t) {
}
}
} // end of while read lines
} finally {
reader.close();
}
} catch (Throwable t) {
}
} // end of while urls
}
} catch (Throwable t) {
}
}
@Override
public String toString() {
return this.getClass().getName() + "[" + type.getName() + "]";
}
}
采用SpringExtensionFactory作为扩展工厂
public class SpringExtensionFactory implements ExtensionFactory {
private static final Set<ApplicationContext> contexts = new HashSet<ApplicationContext>();
public static void addApplicationContext(ApplicationContext context) {
contexts.add(context);
}
public static void removeApplicationContext(ApplicationContext context) {
contexts.remove(context);
}
@SuppressWarnings("unchecked")
public <T> T getExtension(Class<T> type, String name) {
for (ApplicationContext context : contexts) {
if (context.containsBean(name)) {
Object bean = context.getBean(name);
if (type.isInstance(bean)) {
return (T) bean;
}
}
}
return null;
}
}
@SPI
public interface ExtensionFactory {
/**
* 获取extension
* @param type
* @param name
* @return
*/
<T> T getExtension(Class<T> type, String name);
}
我们根据扩展点加载器知道,只处理@SPI注解的接口。因此我们这些扩展点的接口必须加上@SPI,比如我这里的上传接口:
/**
* 上传文件处理器
*/
@SPI
public interface FileUploadHandler {
/**
* 处理 普通数据项
* @param item
* @param data
* @return
*/
public <T> T handleFormField(FileItem item,T data );
/**
* 处理 文件数据项
* @param item
* @param data
* @return
*/
public <T> T handleFileField(FileItem item,T data);
}
而我真正的实现类有两种方式,一种是上传到本地,一种是上传到fastdfs上。
文件上传到fastdfs上
/**
-
文件上传到fastdfs上
*/
public class FastdfsUploadHandler implements FileUploadHandler {@Override
public T handleFormField(FileItem item, T data) {
// TODO Auto-generated method stub
return null;
}@Override
public T handleFileField(FileItem item, T data) {
System.out.println(“我传到fastdfs上了哦!”);
return null;
}
}
文件上传到本地存储内存
/**
* 文件上传到本地存储内存
*/
public class FileUploadToLocalHandler implements FileUploadHandler {
public final String SPLIT_FILE_NAME = "\\";
@Value("fileUploadPath")
private String fileUploadPath;
@Override
public <T> T handleFormField(FileItem item, T data) {
//处理表单数据
return null;
}
@SuppressWarnings("unchecked")
@Override
public <T> T handleFileField(FileItem item, T data) {
String fileName = item.getName();
if (!StringUtils.isNotBlank(fileName)) {
return (T) new UploadResultModel(UploadFileCodeEnmu.Q_TYPE_DENIED, null);
}
//注意:不同的浏览器提交的文件名是不一样的,有些浏览器提交上来的文件名是带有路径的,如: c:\a\b\1.txt,而有些只是单纯的文件名,如:1.txt
//处理获取到的上传文件的文件名的路径部分,只保留文件名部分
fileName.substring(fileName.lastIndexOf(SPLIT_FILE_NAME) + 1);
if(!FileUtils.isFile(fileName)){
return (T) new UploadResultModel(UploadFileCodeEnmu.Q_TYPE_DENIED, null);
}
// 获取item中的上传文件的输入流
InputStream in;
String pathName = null;
try {
in = item.getInputStream();
//创建一个文件输出流
String newFileName = UUID.randomUUID().toString() + "_" + fileName;
String rootPath = fileUploadPath;
/**
* 文件路径不存在则需要创建文件路径
*/
File filePath = new File(rootPath);
if (!filePath.exists()) {
filePath.mkdirs();
}
pathName = fileUploadPath + File.separator + newFileName;
FileOutputStream out = new FileOutputStream(pathName);
byte buffer[] = new byte[1024];
int len = 0;
while ((len = in.read(buffer))>0) {
out.write(buffer,0,len);
}
in.close();
out.close();
item.delete();
} catch (IOException e1) {
e1.printStackTrace();
}
return (T) new UploadResultModel(UploadFileCodeEnmu.SUCCESS, pathName);
}
}
实现是有了,可是如何找到开发者真正想要使用的上传方式对应的class呢?
这里我们在扩展点加载器中已经定义了查找文件的地址就是:META-INF/services/。扫描这个文件夹下的文件,命名方式就是借口的全路径,比如:com.wang.tool.file.upload.handler.FileUploadHandler
local=com.wang.tool.file.upload.handler.impl.FileUploadToLocalHandler
fastdfs=com.wang.tool.file.upload.handler.impl.FastdfsUploadHandler
其中,这里的格式是name属性=具体的实现类
这样一来,我们只需要为扩展点加载器传入具体的接口类型以及com.wang.tool.file.upload.handler.FileUploadHandler文件中定义的name属性就可以了。
上传
@Override
public UploadResultModel upload(List<FileItem> list, Object obj,String serviceName) {
UploadResultModel result = null;
//文件上传到本地
FileUploadHandler localHandler = ExtensionLoader.getExtensionLoader(FileUploadHandler.class).getExtension(serviceName);
for (FileItem item : list) {
if (item.isFormField()) {
localHandler.handleFormField(item, obj);
} else {
result = (UploadResultModel) localHandler.handleFileField(item, obj);
}
}
return result;
}
这里我们传入的serviceName名称,就需要开发者自己在com.wang.tool.file.upload.handler.FileUploadHandler文件中定义的name属性,比如:如果serviceName=“local”,便是开发者需要使用com.wang.tool.file.upload.handler.impl.FileUploadToLocalHandler将文件上传到本地。
如果serviceName=“fastdfs”,便是开发者需要使用com.wang.tool.file.upload.handler.impl.FastdfsUploadHandler将文件上传到fastdfs上。
这样服务端代码就实现了上传操作,而且具有很好的可扩展性。
但是真正在上传的时候,又遇到一个问题,就是文件大小超过2M,前端接受不到我后端返回的错误信息。原来springboot集成了tomcat,而tomcat默认设置上传文件大小是2M,如果超过2M,连接就被中断,所以前端是收不到任何响应的。
private int maxSwallowSize = 2 * 1024 * 1024;
public int getMaxSwallowSize() { return maxSwallowSize; }
public void setMaxSwallowSize(int maxSwallowSize) {
this.maxSwallowSize = maxSwallowSize;
}
因此,我们需要自己设置tomcat关于上传文件大小的限制。代码如下所示:
@Bean
public TomcatServletWebServerFactory tomcatEmbedded() {
TomcatServletWebServerFactory tomcat = new TomcatServletWebServerFactory();
tomcat.addConnectorCustomizers((TomcatConnectorCustomizer) connector -> {
if ((connector.getProtocolHandler() instanceof AbstractHttp11Protocol<?>)) {
// -1 means unlimited
((AbstractHttp11Protocol<?>) connector.getProtocolHandler())
.setMaxSwallowSize(-1);
}
});
return tomcat;
}
好了,springboot框架下基于基于commons-fileupload实现上传,下载组件就实现了,这里有三个需要注意的地方:
上传文件请求在springmvc处理的时候已经经过解析,所以根据commons-fileupload去解析无法获取到上传数据。
springboot的内置tomcat对上传文件大小的限制默认为2M
SPI机制能够实现高扩展