利用Filter和HttpServletRequestWrapper实现请求体中token校验


  先说一下项目背景：系统以JSON格式传参，token是其中一个必传参数。此时如果在过滤器中直接读取request的输入流，后续controller中通过@RequestBody注解封装请求参数时就会报stream closed异常，因为InputStream是基础流，只能被读取一次。解决办法是用HttpServletRequestWrapper把请求体缓存起来，供后续重复读取。代码如下：

package com.hellobike.scm.filter;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Strings;
import com.hellobike.basic.model.sys.ScmUser;
import com.hellobike.basic.util.Utils;
import com.hellobike.scm.conf.UserThread;
import com.hellobike.scm.service.RedisCacheService;
import com.hellobike.scm.service.RedisService;
import com.hellobike.scm.util.LoginUserUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.context.support.SpringBeanAutowiringSupport;
import org.springframework.web.multipart.MultipartHttpServletRequest;
import org.springframework.web.multipart.commons.CommonsMultipartResolver;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Map;
import java.util.Set;

/**
 * Servlet filter that authenticates every non-public request by a {@code token}
 * and then applies the upgrade-mode and per-URL permission checks kept in Redis.
 *
 * <p>The token is looked up, in order, in the request parameters, the JSON body
 * (read through {@link RequestWrapper} so the controller can still read it), and
 * finally a multipart form field. Once the JSON body has been read, the wrapper is
 * what gets passed down the chain and forwarded, because the original input stream
 * has already been consumed.
 */
public class AuthFilter implements Filter {

    private static final Logger logger = LoggerFactory.getLogger(AuthFilter.class);

    /** Forward target when the token is missing, the JSON body is unusable, or the token is unknown. */
    private static final String TOKEN_ERROR_PATH = "/admin/execption/tokenError";

    /** Forward target when the user's permission code for the URL is "None". */
    private static final String UNAUTHORIZED_PATH = "/admin/execption/unauthorized";

    /** Forward target while a system upgrade is flagged in Redis. */
    private static final String UPGRADE_INFO_PATH = "/systemConfig/getUpgradeInfo";

    /** Redis key holding the set of user names still let through during an upgrade. */
    private static final String UPGRADE_TAG_KEY = "scm_upgrade_tag";

    /** Expiry passed to RedisCacheService.set on each successful check (unit defined by that service — presumably seconds, TODO confirm). */
    private static final int TOKEN_TTL = 1200;

    @Autowired
    private RedisCacheService redisCacheService;

    /**
     * Lets Spring inject {@code @Autowired} fields into this container-created filter.
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        SpringBeanAutowiringSupport.processInjectionBasedOnServletContext(this, filterConfig.getServletContext());
    }

    /**
     * Authenticates the request, then either forwards it to the token-error,
     * upgrade, or unauthorized page, or passes it on down the chain.
     *
     * @throws IOException      if reading the body or forwarding fails
     * @throws ServletException if forwarding or a downstream filter fails
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        HttpServletResponse httpResponse = (HttpServletResponse) response;
        httpResponse.setHeader("Access-Control-Allow-Origin", "*");
        httpResponse.setContentType("application/json; charset=utf-8");
        httpResponse.setHeader("pragma", "no-cache");
        httpResponse.setHeader("cache-control", "no-cache");
        LoginUserUtil.removeCurrentUserInfo();

        String requestUrl = httpRequest.getRequestURI();
        if (isPublicPath(httpRequest)) {
            // Static assets and the login / captcha / error endpoints need no token.
            filterChain.doFilter(request, response);
            return;
        }
        String userAgent = httpRequest.getHeader("User-Agent");
        boolean isJsonType = shouldLog(httpRequest.getContentType());

        // 1) Query-string / form parameters.
        String token = null;
        Map<String, String[]> param = request.getParameterMap();
        if (param.containsKey("token")) {
            token = param.get("token")[0];
            logger.info("authFilter中get请求token值为{}", token);
        }

        // 2) JSON body. Reading it consumes the original input stream, so from here on
        //    the wrapper (which replays the cached body) is used for chain and forwards.
        if (token == null && isJsonType) {
            RequestWrapper wrapper = new RequestWrapper(httpRequest);
            httpRequest = wrapper;
            String body = wrapper.getBody();
            logger.info("post param is {}", body);
            JSONObject params = parseJsonObject(body);
            if (params == null) {
                httpRequest.getRequestDispatcher(TOKEN_ERROR_PATH).forward(httpRequest, response);
                return;
            }
            token = params.getString("token");
        }

        // 3) Multipart form field.
        // NOTE(review): resolveMultipart parses (and so consumes) the body, yet the
        // unparsed request is what continues down the chain; confirm uploads still
        // reach the controller, or pass multiReq on instead.
        if (token == null) {
            String contentType = request.getContentType();
            if (contentType != null && contentType.contains("multipart/form")) {
                try {
                    MultipartHttpServletRequest multiReq = new CommonsMultipartResolver().resolveMultipart(httpRequest);
                    token = multiReq.getParameter("token");
                    logger.info("multiReq token={}", token);
                } catch (Exception e) {
                    logger.warn("MultipartHttpServletRequest异常", e);
                }
            }
        }

        if (!tokenAvailable(token, userAgent)) {
            httpRequest.getRequestDispatcher(TOKEN_ERROR_PATH).forward(httpRequest, response);
            return;
        }

        // URL permission check.
        ScmUser adminUser = JSON.parseObject(redisCacheService.getToken(token), ScmUser.class);
        // NOTE(review): UserThread is set here but never reset by this filter (only
        // LoginUserUtil is cleared at the top), so a pooled thread may keep a previous
        // request's user on public paths — confirm it is cleared elsewhere.
        UserThread.setValue(adminUser);
        String userName = null;
        if (adminUser != null) {
            userName = adminUser.getUserName();
            LoginUserUtil.setCurrentUserName(userName);
        }
        // System-upgrade mode: users not on the upgrade whitelist see the upgrade page.
        if (upgrade(userName, requestUrl)) {
            httpRequest.getRequestDispatcher(UPGRADE_INFO_PATH).forward(httpRequest, response);
            return;
        }

        logger.info("data的值{}", userName);
        String permiCode = (String) redisCacheService.getUserPermission(userName, requestUrl);
        logger.info("requestUrl={},permiCode={}", requestUrl, permiCode);
        // Only an explicit "None" denies; a missing entry (null) is let through.
        if ("None".equals(permiCode)) {
            httpRequest.getRequestDispatcher(UNAUTHORIZED_PATH).forward(httpRequest, response);
            return;
        }
        filterChain.doFilter(httpRequest, httpResponse);
    }

    /**
     * Returns whether the request targets a path that bypasses the token check:
     * static assets (.css/.js/.png/.jpg) and the error, login and captcha endpoints.
     *
     * <p>Matches on servletPath + pathInfo rather than the raw request URI. The servlet
     * API returns these decoded, and common containers (e.g. Tomcat) also strip
     * {@code ;param} segments and resolve dot-segments. A crafted raw URI such as
     * {@code /admin/x;.js} or {@code /execption/../admin/x} therefore no longer
     * matches the whitelist.
     *
     * <p>NOTE(review): the extension checks are still substring matches. If Spring MVC
     * suffix-pattern matching is enabled, {@code /admin/foo.js} may reach the
     * {@code /admin/foo} handler without a token. Consider excluding static resources
     * through the filter's URL patterns instead.
     */
    private static boolean isPublicPath(HttpServletRequest request) {
        String servletPath = request.getServletPath();
        String pathInfo = request.getPathInfo();
        String path = (servletPath == null ? "" : servletPath) + (pathInfo == null ? "" : pathInfo);
        return path.contains(".css") || path.contains(".js") || path.contains(".png")
                || path.contains(".jpg") || path.contains("/execption/")
                || path.contains("/sys/main/login") || path.contains("/sys/main/checkCode");
    }

    /**
     * Parses {@code body} as a JSON object.
     *
     * <p>Returns {@code null} when the parser yields nothing or throws, for example on
     * malformed input or a top-level JSON array. The parser's runtime exception is not
     * allowed to escape the filter as a 500.
     */
    private static JSONObject parseJsonObject(String body) {
        try {
            return JSON.parseObject(body);
        } catch (RuntimeException e) {
            logger.info("post body is not a JSON object: {}", e.toString());
            return null;
        }
    }

    /**
     * Returns true when an upgrade whitelist exists in Redis, the user is not on it,
     * and the request is not already for the upgrade page.
     *
     * <p>The last condition prevents an endless forward loop when this filter also
     * runs on FORWARD dispatches.
     * NOTE(review): requestUrl comes from getRequestURI() and so includes the context
     * path. With a non-empty context path this equality never holds — confirm the app
     * is deployed at the root context.
     */
    private boolean upgrade(String userName, String requestUrl) {
        Set<String> userNames = redisCacheService.get(UPGRADE_TAG_KEY, Set.class);
        return userNames != null
                && !UPGRADE_INFO_PATH.equals(requestUrl)
                && !userNames.contains(userName);
    }

    /**
     * Checks that {@code token} is non-empty and known to Redis. On success, it
     * re-stores the token data with {@link #TOKEN_TTL} to extend its lifetime.
     *
     * <p>Fails closed: if the lookup itself throws, the token is rejected. If only the
     * expiry refresh fails, the error is logged and the (already verified) token is
     * still accepted.
     *
     * @param token     token to check; may be null
     * @param userAgent currently unused
     * @return whether the request may proceed
     */
    private boolean tokenAvailable(String token, String userAgent) {
        if (Strings.isNullOrEmpty(token)) {
            return false;
        }
        String data;
        try {
            data = redisCacheService.getToken(token);
        } catch (Exception e) {
            logger.warn("校验token失败", e);
            return false;
        }
        if (Utils.isEmpty(data)) {
            return false;
        }
        try {
            redisCacheService.set(token, data, TOKEN_TTL);
        } catch (Exception e) {
            logger.warn("校验token失败", e);
        }
        return true;
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }

    /**
     * Returns whether the Content-Type header declares {@code application/json}.
     *
     * <p>Parameters such as {@code charset} are ignored. The comparison is
     * case-insensitive and does not depend on the default locale.
     */
    private boolean shouldLog(String contentType) {
        if (contentType == null) {
            return false;
        }
        for (String cell : contentType.split(";")) {
            if ("application/json".equalsIgnoreCase(cell.trim())) {
                return true;
            }
        }
        return false;
    }
}
package com.hellobike.manage.flter;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

/**
 * Request wrapper that reads the request body once, at construction, and replays it
 * to every later {@link #getInputStream()} / {@link #getReader()} call. A filter can
 * therefore inspect the body and the controller (e.g. {@code @RequestBody}) can
 * still read it.
 *
 * <p>The raw bytes are replayed unchanged. They are decoded with the request's
 * declared character encoding, or UTF-8 when none (or an unknown one) is declared,
 * only for {@link #getBody()} and {@link #getReader()}.
 */
public class RequestWrapper extends HttpServletRequestWrapper {

    /** Raw body bytes, replayed verbatim by every getInputStream() call. */
    private final byte[] bodyBytes;

    /** Charset used to decode the body and to build getReader(). */
    private final Charset charset;

    /** Decoded body text. */
    private final String body;

    /** Parameter overrides registered via setParameter; consulted before the wrapped request. */
    private final Map<String, String[]> params = new HashMap<String, String[]>();

    /**
     * Wraps {@code request} and reads its whole body into memory, closing the
     * original input stream.
     *
     * @throws IOException if reading the body fails
     */
    public RequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.charset = resolveCharset(request.getCharacterEncoding());
        this.bodyBytes = readFully(request.getInputStream());
        this.body = new String(bodyBytes, charset);
    }

    /** Maps a declared encoding name to a Charset, defaulting to UTF-8 when absent or invalid. */
    private static Charset resolveCharset(String encoding) {
        if (encoding == null || encoding.trim().isEmpty()) {
            return StandardCharsets.UTF_8;
        }
        try {
            return Charset.forName(encoding.trim());
        } catch (IllegalArgumentException e) {
            // IllegalCharsetNameException / UnsupportedCharsetException from a bad header.
            return StandardCharsets.UTF_8;
        }
    }

    /** Reads {@code in} to the end and closes it; a null stream yields an empty array. */
    private static byte[] readFully(InputStream in) throws IOException {
        if (in == null) {
            return new byte[0];
        }
        try (InputStream source = in) {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            byte[] chunk = new byte[4096];
            int n;
            while ((n = source.read(chunk)) != -1) {
                buffer.write(chunk, 0, n);
            }
            return buffer.toByteArray();
        }
    }

    /**
     * Returns a fresh stream over the cached body bytes. Each call starts from the
     * beginning.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(bodyBytes);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The body is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Async (non-blocking) reading is not supported; the body is served synchronously.
            }

            @Override
            public int read() throws IOException {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return source.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return source.available();
            }
        };
    }

    /** Returns a reader over the cached body, decoded with the resolved charset. */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), charset));
    }

    /** Returns the decoded request body. */
    public String getBody() {
        return this.body;
    }

    /**
     * Registers a parameter override.
     *
     * <p>A {@code String[]} value is stored as-is; any other non-null value is stored
     * as a one-element array holding {@code String.valueOf(value)}. A null value is
     * ignored.
     */
    public void setParameter(String name, Object value) {
        if (value == null) {
            return;
        }
        String[] values;
        if (value instanceof String[]) {
            values = (String[]) value;
        } else {
            values = new String[]{String.valueOf(value)};
        }
        params.put(name, values);
    }

    /**
     * Returns the first overridden value for {@code name}, or null if the override is
     * empty. When no override exists, delegates to the wrapped request so that real
     * query/form parameters stay visible.
     */
    @Override
    public String getParameter(String name) {
        if (params.containsKey(name)) {
            String[] values = params.get(name);
            return (values == null || values.length == 0) ? null : values[0];
        }
        return super.getParameter(name);
    }

    /**
     * Returns a copy of the overridden values for {@code name}. When no override
     * exists, delegates to the wrapped request.
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] values = params.get(name);
        return values != null ? values.clone() : super.getParameterValues(name);
    }
}

  下面说一下我遇到的坑。系统要做一个简单的系统升级界面，升级期间用户不可访问，因此做了一个请求转发。在springboot 1.5.1版本中，FilterRegistrationBean默认会拦截转发的请求，此时request.getRequestDispatcher("/systemConfig/getUpgradeInfo").forward(request, response)会被过滤器再次拦截；由于原始request中的流已被读取，再次构造RequestWrapper时就会报stream closed。而在springboot 2.0.4版本中，FilterRegistrationBean默认不会拦截转发(FORWARD)类型的请求，所以不会报stream closed。因此第一个坑就是FilterRegistrationBean默认拦截类型(DispatcherType)的设置。解决办法是把转发方法的参数改为RequestWrapper封装后的httpRequest，即request.getRequestDispatcher("/systemConfig/getUpgradeInfo").forward(httpRequest, response)。但此时如果配置FilterRegistrationBean拦截转发(FORWARD)类型的请求，又会出现无限循环：转发后的请求再次进入过滤器，又被再次转发，一直循环下去。这是第二个坑。因此在判断是否需要转发时，要把转发的目标URL排除在外，即

当requestUrl.equals("/systemConfig/getUpgradeInfo")为true时不再进行转发。

  问题虽然简单，但前前后后还是浪费了不少时间，特此记录。


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM