利用Filter和HttpServletRequestWrapper實現請求體中token校驗


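  先說一下項目的背景:系統傳參為json格式,token為其中一個必傳參數。此時如果在過濾器中直接讀取request的輸入流,則後續controller中通過@RequestBody注解封裝請求參數時會報stream closed異常,因為InputStream是一個基礎流,只能被讀取一次。解決辦法是在過濾器中用HttpServletRequestWrapper包裝request:先把請求體讀出來緩存,之後getInputStream()/getReader()都從緩存中重新讀取。代碼如下: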

package com.hellobike.scm.filter;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Strings;
import com.hellobike.basic.model.sys.ScmUser;
import com.hellobike.basic.util.Utils;
import com.hellobike.scm.conf.UserThread;
import com.hellobike.scm.service.RedisCacheService;
import com.hellobike.scm.service.RedisService;
import com.hellobike.scm.util.LoginUserUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.context.support.SpringBeanAutowiringSupport;
import org.springframework.web.multipart.MultipartHttpServletRequest;
import org.springframework.web.multipart.commons.CommonsMultipartResolver;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Map;
import java.util.Set;

/**
 * Servlet filter that authenticates each request by its {@code token}, then applies the
 * system-upgrade gate and the per-URL permission check before passing the request on.
 *
 * <p>The token is looked up, in order, in: the request parameters, the JSON body (only for
 * {@code application/json} requests) and a multipart form field. Reading the JSON body consumes
 * the servlet input stream, so the request is wrapped in a {@code RequestWrapper} that buffers the
 * body; the wrapped request is what gets forwarded and passed down the chain, so the controller's
 * {@code @RequestBody} binding can read the body again.
 */
public class AuthFilter implements Filter {

    private static final Logger logger = LoggerFactory.getLogger(AuthFilter.class);

    /** Forward target for a missing/invalid token (whitelisted by the "/execption/" rule). */
    private static final String TOKEN_ERROR_PATH = "/admin/execption/tokenError";

    /** Forward target when the user has no permission for the requested URL. */
    private static final String UNAUTHORIZED_PATH = "/admin/execption/unauthorized";

    /** Upgrade notice page; excluded from the upgrade gate so forwarding to it cannot loop. */
    private static final String UPGRADE_INFO_PATH = "/systemConfig/getUpgradeInfo";

    @Autowired
    private RedisCacheService redisCacheService;

    /**
     * Populates the {@code @Autowired} fields from the Spring context attached to the servlet
     * context (this filter is presumably instantiated by the container rather than by Spring).
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        SpringBeanAutowiringSupport.processInjectionBasedOnServletContext(this, filterConfig.getServletContext());
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        HttpServletResponse httpResponse = (HttpServletResponse) response;
        httpResponse.setHeader("Access-Control-Allow-Origin", "*");
        httpResponse.setContentType("application/json; charset=utf-8");
        httpResponse.setHeader("pragma", "no-cache");
        httpResponse.setHeader("cache-control", "no-cache");
        LoginUserUtil.removeCurrentUserInfo();

        String requestUrl = httpRequest.getRequestURI();
        // Match the whitelist against the container-resolved path, not the raw URI: the raw URI
        // keeps ";..." path parameters, so e.g. "/admin/x;.js" would otherwise skip authentication.
        String appPath = pathWithinApplication(httpRequest);
        if (isWhitelisted(appPath)) {
            filterChain.doFilter(request, response);
            return;
        }

        String userAgent = httpRequest.getHeader("User-Agent");
        boolean isJsonType = isJsonContentType(httpRequest.getContentType());

        // 1) token from the request parameters (query string / urlencoded form)
        String token = null;
        Map<String, String[]> param = request.getParameterMap();
        String[] tokenValues = param.get("token");
        if (tokenValues != null && tokenValues.length > 0) {
            token = tokenValues[0];
            // debug, not info: the token is a credential and should not end up in routine logs
            logger.debug("authFilter中get請求token值為{}", token);
        }

        // 2) token from the JSON body; the body is buffered so it can be read again downstream
        if (token == null && isJsonType) {
            RequestWrapper wrapper = new RequestWrapper(httpRequest);
            httpRequest = wrapper;
            String body = wrapper.getBody();
            // debug, not info: the body carries the token
            logger.debug("post param is {}", body);
            JSONObject params = parseJsonObject(body);
            if (params == null) {
                forward(httpRequest, response, TOKEN_ERROR_PATH);
                return;
            }
            token = params.getString("token");
        }

        // 3) token from a multipart form field
        if (token == null && isMultipart(httpRequest.getContentType())) {
            token = tokenFromMultipart(httpRequest);
        }

        // A single check: calling tokenAvailable twice hit the store (and refreshed it) twice.
        if (!tokenAvailable(token, userAgent)) {
            forward(httpRequest, response, TOKEN_ERROR_PATH);
            return;
        }

        ScmUser adminUser = JSON.parseObject(redisCacheService.getToken(token), ScmUser.class);
        // NOTE(review): UserThread is set here but never cleared by this filter; if it is
        // ThreadLocal-backed, a whitelisted request on a pooled thread may still see a previous
        // request's user — confirm.
        UserThread.setValue(adminUser);
        String userName = null;
        if (adminUser != null) {
            userName = adminUser.getUserName();
            LoginUserUtil.setCurrentUserName(userName);
        }

        if (upgrade(userName, appPath)) {
            forward(httpRequest, response, UPGRADE_INFO_PATH);
            return;
        }

        logger.info("data的值{}", userName);
        // NOTE(review): permissions are looked up by the raw request URI and only an explicit
        // "None" denies access, so a URI variant the permission store has no entry for (e.g. one
        // with ";" path parameters) is let through — confirm what getUserPermission returns there.
        String permiCode = (String) redisCacheService.getUserPermission(userName, requestUrl);
        logger.info("requestUrl={},permiCode={}", requestUrl, permiCode);
        if ("None".equals(permiCode)) {
            forward(httpRequest, response, UNAUTHORIZED_PATH);
            return;
        }
        filterChain.doFilter(httpRequest, httpResponse);
    }

    /**
     * Returns true while a system upgrade is in progress (the {@code scm_upgrade_tag} user set
     * exists in the cache) and {@code userName} is not in that set, i.e. the user must be sent to
     * the upgrade page. The upgrade page itself is excluded so forwarding to it cannot loop.
     *
     * @param userName   current user name, may be null
     * @param requestUrl request path relative to the context root; it is compared with the forward
     *                   target, which is also context-relative
     */
    private boolean upgrade(String userName, String requestUrl) {
        Set<String> userNames = redisCacheService.get("scm_upgrade_tag", Set.class);
        return userNames != null
                && !UPGRADE_INFO_PATH.equals(requestUrl)
                && !userNames.contains(userName);
    }

    /**
     * Checks that {@code token} maps to stored session data and re-stores that data with a
     * 1200 expiry (units as defined by RedisCacheService#set — presumably seconds).
     *
     * @param token     token taken from the request, may be null
     * @param userAgent currently unused
     * @return false if the token is empty, has no stored data, or cannot be verified because the
     *         lookup failed (fail closed); true otherwise
     */
    private boolean tokenAvailable(String token, String userAgent) {
        if (Strings.isNullOrEmpty(token)) {
            return false;
        }
        String data;
        try {
            data = redisCacheService.getToken(token);
        } catch (Exception e) {
            // Previously this fell through to "return true", accepting unverifiable tokens.
            logger.warn("校驗token失敗", e);
            return false;
        }
        if (Utils.isEmpty(data)) {
            return false;
        }
        try {
            redisCacheService.set(token, data, 1200);
        } catch (Exception e) {
            // Refreshing the expiry is best-effort; the token itself was verified above.
            logger.warn("刷新token有效期失敗", e);
        }
        return true;
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }

    /**
     * Returns true for paths served without a token: static resources, error pages and the
     * login / check-code endpoints.
     * NOTE(review): substring matching is broad — any path containing e.g. ".js" (including
     * ".json") is whitelisted; consider suffix/prefix rules.
     */
    private static boolean isWhitelisted(String path) {
        return path.contains(".css") || path.contains(".js") || path.contains(".png")
                || path.contains(".jpg") || path.contains("/execption/")
                || path.contains("/sys/main/login") || path.contains("/sys/main/checkCode");
    }

    /**
     * Returns the request path relative to the context root as resolved by the container
     * (servlet path + path info). Unlike {@code getRequestURI()} it is decoded and, in common
     * containers such as Tomcat, excludes ";" path parameters.
     */
    private static String pathWithinApplication(HttpServletRequest request) {
        String servletPath = request.getServletPath();
        String pathInfo = request.getPathInfo();
        return (servletPath == null ? "" : servletPath) + (pathInfo == null ? "" : pathInfo);
    }

    /** Returns true if any ';'-separated part of the content type is "application/json" (case-insensitive). */
    private static boolean isJsonContentType(String contentType) {
        if (contentType == null) {
            return false;
        }
        for (String cell : contentType.split(";")) {
            // equalsIgnoreCase is locale-independent, unlike toLowerCase() without a Locale
            if ("application/json".equalsIgnoreCase(cell.trim())) {
                return true;
            }
        }
        return false;
    }

    /** Returns true for multipart form content types. */
    private static boolean isMultipart(String contentType) {
        return contentType != null && contentType.contains("multipart/form");
    }

    /**
     * Parses {@code body} as a JSON object. Returns null when fastjson yields nothing (e.g. an
     * empty body) or when the body is not a valid JSON object, so the caller can answer with a
     * token error instead of failing with a 500.
     */
    private static JSONObject parseJsonObject(String body) {
        try {
            return JSON.parseObject(body);
        } catch (RuntimeException e) {
            // fastjson reports malformed input with an unchecked JSONException
            logger.warn("request body is not a JSON object", e);
            return null;
        }
    }

    /**
     * Parses a multipart request only to read its {@code token} field; returns null if parsing
     * fails. The parsed request is discarded, so its temporary files are cleaned up here.
     * NOTE(review): parsing likely consumes the request body, so downstream multipart handling
     * may see an empty body — consider passing the parsed request down the chain instead.
     */
    private static String tokenFromMultipart(HttpServletRequest httpRequest) {
        CommonsMultipartResolver resolver = new CommonsMultipartResolver();
        MultipartHttpServletRequest multiReq = null;
        try {
            multiReq = resolver.resolveMultipart(httpRequest);
            String token = multiReq.getParameter("token");
            logger.debug("multiReq token={}", token);
            return token;
        } catch (RuntimeException e) {
            logger.info("MultipartHttpServletRequest異常", e);
            return null;
        } finally {
            if (multiReq != null) {
                resolver.cleanupMultipart(multiReq);
            }
        }
    }

    /** Forwards {@code request} (the possibly body-buffering wrapper) to {@code path}. */
    private static void forward(HttpServletRequest request, ServletResponse response, String path)
            throws ServletException, IOException {
        request.getRequestDispatcher(path).forward(request, response);
    }
}
package com.hellobike.manage.flter;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;

/**
 * Request wrapper that buffers the request body in memory so it can be read more than once.
 *
 * <p>The servlet input stream can only be consumed once, so the constructor drains it. Every
 * later {@link #getInputStream()} / {@link #getReader()} call then returns a fresh stream over
 * the buffered bytes. Individual parameters can be overridden with
 * {@link #setParameter(String, Object)}; only {@link #getParameter(String)} sees those overrides.
 */
public class RequestWrapper extends HttpServletRequestWrapper {

    /** Raw body bytes, replayed byte-for-byte by every getInputStream() call. */
    private final byte[] bodyBytes;

    /** Charset used to decode the body: the request's declared encoding, else UTF-8. */
    private final java.nio.charset.Charset charset;

    /** Body decoded with {@link #charset}; exposed through {@link #getBody()}. */
    private final String body;

    /** Parameters set via setParameter; they take precedence over the wrapped request's. */
    private final Map<String, String[]> params = new HashMap<String, String[]>();

    /**
     * Reads and buffers the entire body of {@code request}, closing its input stream.
     *
     * @throws IOException if reading the body fails
     */
    public RequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.charset = resolveCharset(request.getCharacterEncoding());
        // Keep the raw bytes: decoding and re-encoding with a guessed charset could corrupt them.
        this.bodyBytes = readFully(request.getInputStream());
        this.body = new String(bodyBytes, charset);
    }

    /** Returns the charset named by {@code encoding}, or UTF-8 if it is null or unsupported. */
    private static java.nio.charset.Charset resolveCharset(String encoding) {
        if (encoding != null) {
            try {
                return java.nio.charset.Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // illegal or unsupported charset name: fall back to UTF-8 below
            }
        }
        return java.nio.charset.StandardCharsets.UTF_8;
    }

    /** Reads {@code in} to the end and closes it; a null stream yields an empty array. */
    private static byte[] readFully(InputStream in) throws IOException {
        if (in == null) {
            return new byte[0];
        }
        try (InputStream source = in) {
            java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
            byte[] buffer = new byte[4096];
            int n;
            while ((n = source.read(buffer)) != -1) {
                out.write(buffer, 0, n);
            }
            return out.toByteArray();
        }
    }

    /** Returns a new stream positioned at the start of the buffered body. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(bodyBytes);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The body is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reads are not supported; the body is fully buffered.
            }

            @Override
            public int read() {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return source.read(b, off, len);
            }

            @Override
            public int available() {
                return source.available();
            }
        };
    }

    /** Returns a reader over the buffered body, decoded with the request's charset (default UTF-8). */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), charset));
    }

    /** Returns the buffered body as text. */
    public String getBody() {
        return this.body;
    }

    /**
     * Overrides parameter {@code name}. A {@code String[]} is stored as-is; any other non-null
     * value is stored as a single {@code String.valueOf(value)}. A null value is ignored.
     */
    public void setParameter(String name, Object value) {
        if (value == null) {
            return;
        }
        if (value instanceof String[]) {
            params.put(name, (String[]) value);
        } else {
            params.put(name, new String[]{String.valueOf(value)});
        }
    }

    /**
     * Returns the first value of an overridden parameter; for parameters never set via
     * setParameter, delegates to the wrapped request (previously these were hidden as null).
     */
    @Override
    public String getParameter(String name) {
        String[] values = params.get(name);
        if (values == null) {
            return super.getParameter(name);
        }
        return values.length == 0 ? null : values[0];
    }
}

  下面說一下我遇到的坑。系統要做一個簡單的系統升級界面,升級期間用戶不可訪問,因此做了一個請求轉發。在springboot1.5.1版本中,FilterRegistrationBean默認會攔截轉發的請求,此時request.getRequestDispatcher("/systemConfig/getUpgradeInfo").forward(request, response)轉發出去的請求會被過濾器再次攔截,而原始request中的流已被讀取,所以此時RequestWrapper會報stream closed。在springboot2.0.4版本中,FilterRegistrationBean默認不會攔截轉發(FORWARD)類型的請求,所以不會報stream closed。因此第一個坑就是FilterRegistrationBean默認攔截類型(DispatcherType)的設置。解決辦法是把轉發方法的傳參改為httpRequest,即使用RequestWrapper封裝後的httpRequest,轉發方式變為request.getRequestDispatcher("/systemConfig/getUpgradeInfo").forward(httpRequest, response)。但此時如果配置FilterRegistrationBean攔截轉發(FORWARD)類型的請求,轉發後的請求會再次進入過濾器,並再次被判斷為需要轉發,於是出現無限循環,一直轉發下去,這就是第二個坑。因此在判斷是否需要轉發時,要把轉發的目標URL本身排除在外,即

requestUrl.equals("/systemConfig/getUpgradeInfo")為true時不再進行轉發。(注意:requestUrl來自getRequestURI(),包含context path;如果應用配置了context path,比較時需要一併考慮。)

  問題雖然簡單,但是前前後後還是浪費了不少時間,特此記錄一下。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM