大家都知道Java的servlet分get和post两种请求方式,在servlet中或者在集成了springMVC、Struts2等框架的情况下,获取请求的参数都很方便。但是有时候我们需要在拦截器(filter)中获取ServletRequest的参数,就不那么容易了。在ServletRequest中,get请求的参数或者form提交的post参数都可以通过request.getParameter("")来获取;但如果是ajax提交的、Content-Type为application/json的post请求,通过getParameter就无法获取到值了。有人会想到通过request的请求流来解析json文本,这样做是可以的,但是有个问题:如果在拦截器中调用了ServletRequest的getInputStream方法,那么在后面的servlet中或者你集成的框架的control层中,就无法再调用getInputStream方法来解析获取参数了。
有了上面的疑问,我们就有了分析和解决问题的途径。通过对HttpServletRequest的分析并结合资料,最后得出的结论是:改写ServletRequest的getInputStream方法便可以解决问题。原因在于HttpServletRequest中的stream只能被read一次,因此我们可以先在filter中调用getInputStream获取json字符串,然后用获取到的json文本生成新的stream交给ServletRequest,这样后面的control层就可以继续获取stream(这个stream是我们自己用json文本生成的)。有了这个思路,我们就来看看代码。
一.改写ServletRequest
PostServletRequest.java
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import org.apache.commons.lang3.StringUtils; public class PostServletRequest extends HttpServletRequestWrapper { private String body=null; /** * Constructs a request object wrapping the given request. * @param request * @throws IllegalArgumentException if the request is null */ public PostServletRequest(HttpServletRequest request,String body) { super(request); this.body=body; } @Override public ServletInputStream getInputStream() throws IOException { ServletInputStream inputStream = null; if(StringUtils.isNotEmpty(body)){ inputStream = new PostServletInputStream(body); } return inputStream; } @Override public BufferedReader getReader() throws IOException { String enc = getCharacterEncoding(); if(enc == null) enc = "UTF-8"; return new BufferedReader(new InputStreamReader(getInputStream(), enc)); } }
二.ServletInputStream的改写
PostServletInputStream.java
import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import javax.servlet.ServletInputStream; public class PostServletInputStream extends ServletInputStream { private InputStream inputStream; private String body ;//解析json之后的文本 public PostServletInputStream(String body) throws IOException { this.body=body; inputStream = null; } private InputStream acquireInputStream() throws IOException { if(inputStream == null) { inputStream = new ByteArrayInputStream(body.getBytes());//通过解析之后传入的文本生成inputStream以便后面control调用 } return inputStream; } public void close() throws IOException { try { if(inputStream != null) { inputStream.close(); } } catch(IOException e) { throw e; } finally { inputStream = null; } } public int read() throws IOException { return acquireInputStream().read(); } public boolean markSupported() { return false; } public synchronized void mark(int i) { throw new UnsupportedOperationException("mark not supported"); } public synchronized void reset() throws IOException { throw new IOException(new UnsupportedOperationException("reset not supported")); } }
三.在filter中的调用
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * 过滤器 */ public class UrlFilter extends AbstractWebFilter { private final static Logger LOGGER = LoggerFactory.getLogger(UrlFilter.class); private final static String MERID_WHITE_LIST = "merABC:test/login.do,test/getList.do;merDEF:test/hello.do,test/greet.do"; private static Map<String, List<String>> merIdWhiteListMap = new HashMap<String, List<String>>(); @Override public void init(FilterConfig filterConfig) throws ServletException { String[] merIdWhiteList = MERID_WHITE_LIST.split(";"); if(merIdWhiteList != null && merIdWhiteList.length > 0) { int merIdSize = merIdWhiteList.length; for(int i=0;i<merIdSize;i++) { String merIdUrls = merIdWhiteList[i]; String[] merIdUrl = merIdUrls.split(":"); if(merIdUrl != null && merIdUrl.length == 2) { String merId = merIdUrl[0]; String urls = merIdUrl[1]; String[] urlList = urls.split(","); if(urlList != null && urlList.length > 0) { List<String> lists = Arrays.asList(urlList); merIdWhiteListMap.put(merId, lists); } } } } LOGGER.info("merIdWhiteListMap:{}", JsonUtil.toJsonStr(merIdWhiteListMap)); } @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { HttpServletRequest req = (HttpServletRequest) request; HttpServletResponse rsp = (HttpServletResponse) response; String url = getRequestPath(req); if(url.endsWith(".do")) { //解析post的json参数,进一步根据请求入参校验 String 
body = getBody((HttpServletRequest)request); if(StringUtils.isNotEmpty(body)) { String headMerId = (String) ParamsReflectUtil.getFieldValueRecursive(body, "headMerId"); if(StringUtils.isNotEmpty(headMerId)) { List<String> urlList = merIdWhiteListMap.get(headMerId); LOGGER.info("urlList:{}, curretn url:{}", urlList, url); if(urlList.contains(url)) { //使用解析数据重新生成ServletRequest,供doChain调用 request = getRequest(request,body); chain.doFilter(request, response); }else { LOGGER.info("非法url请求: {},请求入参:{}", url, body); forbiddenJson(rsp); } }else { chain.doFilter(request, response); } }else { chain.doFilter(request, response); } }else { LOGGER.info("非法url请求: {}", url); forbiddenJson(rsp); } } @Override public void destroy() { } /** * 返回ajax信息 */ private void forbiddenJson(HttpServletResponse httpResponse) throws IOException { Map<String,Object> param = new HashMap<String,Object>(); param.put("error", "403"); httpResponse.setStatus(403); httpResponse.setCharacterEncoding("utf-8"); httpResponse.setContentType("application/json"); httpResponse.getWriter().print(JsonUtil.toJsonStr(param)); } private String getBody(HttpServletRequest request) throws IOException { String body = null; StringBuilder stringBuilder = new StringBuilder(); BufferedReader bufferedReader = null; try { InputStream inputStream = request.getInputStream(); if (inputStream != null) { bufferedReader = new BufferedReader(new InputStreamReader(inputStream)); char[] charBuffer = new char[128]; int bytesRead = -1; while ((bytesRead = bufferedReader.read(charBuffer)) > 0) { stringBuilder.append(charBuffer, 0, bytesRead); } } else { stringBuilder.append(""); } } catch (IOException ex) { throw ex; } finally { if (null != bufferedReader) { bufferedReader.close(); } } body = stringBuilder.toString(); return body; } /** * 将post解析过后的request进行封装改写 * @param request * @param body * @return */ private ServletRequest getRequest(ServletRequest request, String body) { String enctype = request.getContentType(); if 
(StringUtils.isNotEmpty(enctype) && enctype.contains("application/json")) { return new PostServletRequest((HttpServletRequest) request, body); } return request; } }
AbstractWebFilter.java
import java.io.IOException;

import javax.servlet.Filter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for web filters, providing helpers for the client address and the
 * request path.
 */
public abstract class AbstractWebFilter implements Filter {

    private static Logger log = LoggerFactory.getLogger(AbstractWebFilter.class);

    /**
     * Returns the client IP address.
     *
     * <p>When an {@code x-forwarded-for} header is present, its first entry is returned:
     * the header may hold a comma-separated chain of addresses, of which the first is
     * the originating client. If the header is absent or its first entry is blank, the
     * remote address of the connection is returned instead.
     *
     * <p>NOTE(review): the header is supplied by the caller; presumably it is only
     * trustworthy when set by a proxy you control. Confirm before using it for
     * access decisions.
     *
     * @param request the current request
     * @return the client address taken from the header, or the connection's remote address
     */
    protected String getRemortIP(HttpServletRequest request) {
        String forwarded = request.getHeader("x-forwarded-for");
        if (forwarded != null) {
            int comma = forwarded.indexOf(',');
            String first = (comma >= 0 ? forwarded.substring(0, comma) : forwarded).trim();
            if (first.length() > 0) {
                return first;
            }
        }
        return request.getRemoteAddr();
    }

    /**
     * Returns the servlet path of the request without a leading {@code "/"}.
     *
     * @param request the current request
     * @return the servlet path with one leading slash removed, or the servlet path
     *         unchanged (possibly {@code null}) when it does not start with a slash
     */
    protected String getRequestPath(HttpServletRequest request) {
        String requestPath = request.getServletPath();
        if (requestPath != null && requestPath.startsWith("/")) {
            return requestPath.substring(1);
        }
        return requestPath;
    }

    /**
     * Redirects to the URL configured under the {@code logoutUrl} property
     * (the login page).
     *
     * @throws IOException if the redirect cannot be sent
     */
    private void forbiddenRedirect(HttpServletResponse httpResponse) throws IOException {
        String logoutUrl = PropertiesUtils.getString("logoutUrl");
        httpResponse.sendRedirect(logoutUrl);
    }
}
ParamsReflectUtil.java
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;

/**
 * Utility for looking up a field value inside a (possibly nested) JSON document.
 */
public class ParamsReflectUtil {

    private final static Logger logger = LoggerFactory.getLogger(ParamsReflectUtil.class);

    /**
     * Finds the string value of {@code field} in the given JSON object text, searching
     * nested JSON objects as well.
     *
     * <p>A string-valued match at the current level wins over matches in nested
     * objects. Otherwise the nested objects are searched in iteration order and the
     * first match is returned. Only string values are considered, and JSON arrays are
     * not searched.
     *
     * @param jsonStr JSON object text; {@code null} or empty yields {@code null}
     * @param field   name of the field to look up
     * @return the matching string value, or {@code null} when none is found
     */
    public static Object getFieldValueRecursive(String jsonStr, String field) {
        if (StringUtils.isEmpty(jsonStr)) {
            return null;
        }
        return findStringField(JSON.parseObject(jsonStr), field);
    }

    /** Searches {@code jsonObject}, then its nested objects, for a string-valued {@code field}. */
    private static Object findStringField(JSONObject jsonObject, String field) {
        if (jsonObject == null) {
            return null;
        }
        Object direct = jsonObject.get(field);
        if (direct instanceof String) {
            return direct;
        }
        for (Map.Entry<String, Object> entry : jsonObject.entrySet()) {
            Object value = entry.getValue();
            if (value instanceof JSONObject) {
                Object nested = findStringField((JSONObject) value, field);
                // Stop at the first hit so a later sibling without the field cannot
                // overwrite the result with null.
                if (nested != null) {
                    return nested;
                }
            }
        }
        return null;
    }
}
文章来源:https://blog.csdn.net/siantbaicn/article/details/78184364