大家都知道Java的servlet分get和post兩種請求方式,在servlet中或者在集成了springMVC、Struts2等框架的情況下都可以方便地獲取請求的參數。那麼有時候我們需要在攔截器中獲取ServletRequest的參數,就不那麼容易了。因為在ServletRequest中,如果是get請求或者form提交的post請求,我們可以通過request.getParameter("")來獲取參數;但是如果是ajax提交的Content-Type為application/json的post請求,那麼通過getParameter就無法獲取到值了。有人會想:我通過request的請求流來解析json文本。這樣做是可以的,但是有個問題:如果在攔截器中調用了ServletRequest的getInputStream方法,那麼在後面的servlet中或者你集成的框架的controller層中就無法再調用getInputStream方法來解析獲取參數了。
有了上面的疑問,我們就有了分析和解決問題的途徑。通過對HttpServletRequest的分析並結合資料,最後得出結論:改寫ServletRequest的getInputStream方法便可以解決問題。分析可知,HttpServletRequest中的stream只能被read一次,那麼我們可以在filter中調用getInputStream獲取json字符串,然後用獲取到的json文本生成新的stream交給ServletRequest,後面的controller就可以繼續獲取stream(由我們自己用json文本生成)。有了這個思路,我們就來看看代碼。
一.改寫ServletRequest
PostServletRequest.java
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import org.apache.commons.lang3.StringUtils; public class PostServletRequest extends HttpServletRequestWrapper { private String body=null; /** * Constructs a request object wrapping the given request. * @param request * @throws IllegalArgumentException if the request is null */ public PostServletRequest(HttpServletRequest request,String body) { super(request); this.body=body; } @Override public ServletInputStream getInputStream() throws IOException { ServletInputStream inputStream = null; if(StringUtils.isNotEmpty(body)){ inputStream = new PostServletInputStream(body); } return inputStream; } @Override public BufferedReader getReader() throws IOException { String enc = getCharacterEncoding(); if(enc == null) enc = "UTF-8"; return new BufferedReader(new InputStreamReader(getInputStream(), enc)); } }
二.ServletInputStream的改寫
PostServletInputStream.java
import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import javax.servlet.ServletInputStream; public class PostServletInputStream extends ServletInputStream { private InputStream inputStream; private String body ;//解析json之后的文本 public PostServletInputStream(String body) throws IOException { this.body=body; inputStream = null; } private InputStream acquireInputStream() throws IOException { if(inputStream == null) { inputStream = new ByteArrayInputStream(body.getBytes());//通過解析之后傳入的文本生成inputStream以便后面control調用 } return inputStream; } public void close() throws IOException { try { if(inputStream != null) { inputStream.close(); } } catch(IOException e) { throw e; } finally { inputStream = null; } } public int read() throws IOException { return acquireInputStream().read(); } public boolean markSupported() { return false; } public synchronized void mark(int i) { throw new UnsupportedOperationException("mark not supported"); } public synchronized void reset() throws IOException { throw new IOException(new UnsupportedOperationException("reset not supported")); } }
三.在filter中的調用
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * 過濾器 */ public class UrlFilter extends AbstractWebFilter { private final static Logger LOGGER = LoggerFactory.getLogger(UrlFilter.class); private final static String MERID_WHITE_LIST = "merABC:test/login.do,test/getList.do;merDEF:test/hello.do,test/greet.do"; private static Map<String, List<String>> merIdWhiteListMap = new HashMap<String, List<String>>(); @Override public void init(FilterConfig filterConfig) throws ServletException { String[] merIdWhiteList = MERID_WHITE_LIST.split(";"); if(merIdWhiteList != null && merIdWhiteList.length > 0) { int merIdSize = merIdWhiteList.length; for(int i=0;i<merIdSize;i++) { String merIdUrls = merIdWhiteList[i]; String[] merIdUrl = merIdUrls.split(":"); if(merIdUrl != null && merIdUrl.length == 2) { String merId = merIdUrl[0]; String urls = merIdUrl[1]; String[] urlList = urls.split(","); if(urlList != null && urlList.length > 0) { List<String> lists = Arrays.asList(urlList); merIdWhiteListMap.put(merId, lists); } } } } LOGGER.info("merIdWhiteListMap:{}", JsonUtil.toJsonStr(merIdWhiteListMap)); } @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { HttpServletRequest req = (HttpServletRequest) request; HttpServletResponse rsp = (HttpServletResponse) response; String url = getRequestPath(req); if(url.endsWith(".do")) { //解析post的json參數,進一步根據請求入參校驗 String 
body = getBody((HttpServletRequest)request); if(StringUtils.isNotEmpty(body)) { String headMerId = (String) ParamsReflectUtil.getFieldValueRecursive(body, "headMerId"); if(StringUtils.isNotEmpty(headMerId)) { List<String> urlList = merIdWhiteListMap.get(headMerId); LOGGER.info("urlList:{}, curretn url:{}", urlList, url); if(urlList.contains(url)) { //使用解析數據重新生成ServletRequest,供doChain調用 request = getRequest(request,body); chain.doFilter(request, response); }else { LOGGER.info("非法url請求: {},請求入參:{}", url, body); forbiddenJson(rsp); } }else { chain.doFilter(request, response); } }else { chain.doFilter(request, response); } }else { LOGGER.info("非法url請求: {}", url); forbiddenJson(rsp); } } @Override public void destroy() { } /** * 返回ajax信息 */ private void forbiddenJson(HttpServletResponse httpResponse) throws IOException { Map<String,Object> param = new HashMap<String,Object>(); param.put("error", "403"); httpResponse.setStatus(403); httpResponse.setCharacterEncoding("utf-8"); httpResponse.setContentType("application/json"); httpResponse.getWriter().print(JsonUtil.toJsonStr(param)); } private String getBody(HttpServletRequest request) throws IOException { String body = null; StringBuilder stringBuilder = new StringBuilder(); BufferedReader bufferedReader = null; try { InputStream inputStream = request.getInputStream(); if (inputStream != null) { bufferedReader = new BufferedReader(new InputStreamReader(inputStream)); char[] charBuffer = new char[128]; int bytesRead = -1; while ((bytesRead = bufferedReader.read(charBuffer)) > 0) { stringBuilder.append(charBuffer, 0, bytesRead); } } else { stringBuilder.append(""); } } catch (IOException ex) { throw ex; } finally { if (null != bufferedReader) { bufferedReader.close(); } } body = stringBuilder.toString(); return body; } /** * 將post解析過后的request進行封裝改寫 * @param request * @param body * @return */ private ServletRequest getRequest(ServletRequest request, String body) { String enctype = request.getContentType(); if 
(StringUtils.isNotEmpty(enctype) && enctype.contains("application/json")) { return new PostServletRequest((HttpServletRequest) request, body); } return request; } }
AbstractWebFilter.java
import java.io.IOException;

import javax.servlet.Filter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for web filters, providing small request helpers to subclasses.
 */
public abstract class AbstractWebFilter implements Filter {

    private static Logger log = LoggerFactory.getLogger(AbstractWebFilter.class);

    /**
     * Returns the client IP: the raw {@code x-forwarded-for} header when the request
     * carries one, otherwise the remote address reported by the container.
     */
    protected String getRemortIP(HttpServletRequest request) {
        String forwardedFor = request.getHeader("x-forwarded-for");
        return forwardedFor != null ? forwardedFor : request.getRemoteAddr();
    }

    /**
     * Returns the servlet path of the request (no context path) with a single
     * leading "/" removed when present.
     */
    protected String getRequestPath(HttpServletRequest request) {
        String servletPath = request.getServletPath();
        if (servletPath == null || !servletPath.startsWith("/")) {
            return servletPath;
        }
        return servletPath.substring(1);
    }

    /**
     * Redirects the response to the URL stored under the {@code logoutUrl} property
     * (the original comment describes the target as the login page).
     */
    private void forbiddenRedirect(HttpServletResponse httpResponse) throws IOException {
        httpResponse.sendRedirect(PropertiesUtils.getString("logoutUrl"));
    }
}
ParamsReflectUtil.java
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;

/**
 * Utility for looking up a field value inside a JSON document.
 */
public class ParamsReflectUtil {

    private final static Logger logger = LoggerFactory.getLogger(ParamsReflectUtil.class);

    /**
     * Searches a JSON object, including nested JSON objects, for a string-valued
     * field with the given name.
     *
     * <p>A match directly on an object takes precedence over matches inside its
     * nested objects. Only values that are strings are considered; arrays are not
     * searched.
     *
     * @param jsonStr JSON text whose top level is an object; blank input yields {@code null}
     * @param field   name of the field to look for
     * @return the matching string value, or {@code null} when no such field exists
     */
    public static Object getFieldValueRecursive(String jsonStr, String field) {
        if (StringUtils.isBlank(jsonStr) || field == null) {
            return null;
        }
        return findStringField(JSON.parseObject(jsonStr), field);
    }

    /** Depth-first lookup of a string-valued field on an already parsed object. */
    private static Object findStringField(JSONObject jsonObject, String field) {
        if (jsonObject == null) {
            // parseObject yields null for the JSON literal "null".
            return null;
        }
        Object direct = jsonObject.get(field);
        if (direct instanceof String) {
            return direct;
        }
        for (Object value : jsonObject.values()) {
            if (value instanceof JSONObject) {
                Object nested = findStringField((JSONObject) value, field);
                // Stop at the first hit so a later sibling without the field
                // cannot overwrite the result with null.
                if (nested != null) {
                    return nested;
                }
            }
        }
        return null;
    }
}
文章來源:https://blog.csdn.net/siantbaicn/article/details/78184364
