flask_wtf flask 的 CSRF 源代碼初研究


因為要搞一個基於flask的前后端分離的個人網站,所以需要研究下flask的csrf防護原理.

用的擴展是flask_wtf,也算是比較官方的擴展庫了.

先上相關源代碼:

  1 def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
  2     """Check if the given data is a valid CSRF token. This compares the given
  3     signed token to the one stored in the session.
  4 
  5     :param data: The signed CSRF token to be checked.
  6     :param secret_key: Used to securely sign the token. Default is
  7         ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
  8     :param time_limit: Number of seconds that the token is valid. Default is
  9         ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
 10     :param token_key: Key where token is stored in session for comparison.
 11         Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
 12 
 13     :raises ValidationError: Contains the reason that validation failed.
 14 
 15     .. versionchanged:: 0.14
 16         Raises ``ValidationError`` with a specific error message rather than
 17         returning ``True`` or ``False``.
 18     """
 19 
 20     secret_key = _get_config(
 21         secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
 22         message='A secret key is required to use CSRF.'
 23     )
 24     field_name = _get_config(
 25         token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
 26         message='A field name is required to use CSRF.'
 27     )
 28     time_limit = _get_config(
 29         time_limit, 'WTF_CSRF_TIME_LIMIT', 3600, required=False
 30     )
 31 
 32     if not data:
 33         raise ValidationError('The CSRF token is missing.')
 34 
 35     if field_name not in session:
 36         raise ValidationError('The CSRF session token is missing.')
 37 
 38     s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')
 39 
 40     try:
 41         token = s.loads(data, max_age=time_limit)
 42     except SignatureExpired:
 43         raise ValidationError('The CSRF token has expired.')
 44     except BadData:
 45         raise ValidationError('The CSRF token is invalid.')
 46 
 47     if not safe_str_cmp(session[field_name], token):
 48         raise ValidationError('The CSRF tokens do not match.')
 49 
 50 
 51 class CSRFProtect(object):
 52     """Enable CSRF protection globally for a Flask app.
 53 
 54     ::
 55 
 56         app = Flask(__name__)
 57         csrf = CSRFProtect(app)
 58 
 59     Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
 60     header sent with JavaScript requests. Render the token in templates using
 61     ``{{ csrf_token() }}``.
 62 
 63     See the :ref:`csrf` documentation.
 64     """
 65 
 66     def __init__(self, app=None):
 67         self._exempt_views = set()
 68         self._exempt_blueprints = set()
 69 
 70         if app:
 71             self.init_app(app)
 72 
 73     def init_app(self, app):
 74         app.extensions['csrf'] = self
 75 
 76         app.config.setdefault('WTF_CSRF_ENABLED', True)
 77         app.config.setdefault('WTF_CSRF_CHECK_DEFAULT', True)
 78         app.config['WTF_CSRF_METHODS'] = set(app.config.get(
 79             'WTF_CSRF_METHODS', ['POST', 'PUT', 'PATCH', 'DELETE']
 80         ))
 81         app.config.setdefault('WTF_CSRF_FIELD_NAME', 'csrf_token')
 82         app.config.setdefault(
 83             'WTF_CSRF_HEADERS', ['X-CSRFToken', 'X-CSRF-Token']
 84         )
 85         app.config.setdefault('WTF_CSRF_TIME_LIMIT', 3600)
 86         app.config.setdefault('WTF_CSRF_SSL_STRICT', True)
 87 
 88         app.jinja_env.globals['csrf_token'] = generate_csrf        <><><><><><><><><><><><><><><><><><><>
 89         app.context_processor(lambda: {'csrf_token': generate_csrf})
 90 
 91         @app.before_request
 92         def csrf_protect():
 93             if not app.config['WTF_CSRF_ENABLED']:
 94                 return
 95 
 96             if not app.config['WTF_CSRF_CHECK_DEFAULT']:
 97                 return
 98 
 99             if request.method not in app.config['WTF_CSRF_METHODS']:
100                 return
101 
102             if not request.endpoint:
103                 return
104 
105             view = app.view_functions.get(request.endpoint)
106 
107             if not view:
108                 return
109 
110             if request.blueprint in self._exempt_blueprints:
111                 return
112 
113             dest = '%s.%s' % (view.__module__, view.__name__)
114 
115             if dest in self._exempt_views:
116                 return
117 
118             self.protect()
119 
120     def _get_csrf_token(self):
121         # find the ``csrf_token`` field in the submitted form
122         # if the form had a prefix, the name will be
123         # ``{prefix}-csrf_token``
124         field_name = current_app.config['WTF_CSRF_FIELD_NAME']
125 
126         for key in request.form:
127             if key.endswith(field_name):
128                 csrf_token = request.form[key]
129 
130                 if csrf_token:
131                     return csrf_token
132 
133         for header_name in current_app.config['WTF_CSRF_HEADERS']:
134             csrf_token = request.headers.get(header_name)
135 
136             if csrf_token:
137                 return csrf_token
138 
139         return None
140 
141     def protect(self):
142         if request.method not in current_app.config['WTF_CSRF_METHODS']:
143             return
144 
145         try:
146             validate_csrf(self._get_csrf_token())
147         except ValidationError as e:
148             logger.info(e.args[0])
149             self._error_response(e.args[0])
150 
151         if request.is_secure and current_app.config['WTF_CSRF_SSL_STRICT']:
152             if not request.referrer:
153                 self._error_response('The referrer header is missing.')
154 
155             good_referrer = 'https://{0}/'.format(request.host)
156 
157             if not same_origin(request.referrer, good_referrer):
158                 self._error_response('The referrer does not match the host.')
159 
160         g.csrf_valid = True  # mark this request as CSRF valid

 先說明下csrftoken的普通機制,上面代碼中有一行代碼后面被我加了一串<>符號,這行代碼表明,默認的jinja2渲染的方式就是通過generate_csrf 方法生成csrftoken字符串,所以前后端分離的話,可以直接通過這個方法獲取csrftoken,效果是一樣的.

進入generate_csrf函數內部,會發現它做了這麼幾件事:生成token,放在session裡,然後返回一個加工過的token.這說明不同的訪問觸發該函數時,服務器session內的csrftoken值會不一樣.因此token獲取一次之後,在有效期(一個小時)內可以重複使用,但是不建議這麼做.另外,如果不是form表單提交,該csrf系統不會從json中獲取token,而會從請求頭獲取,所以需要在請求頭內添加關鍵字段X-CSRFToken,並將其值設為獲取到的token.

首先獲取csrftoken的方式: _get_csrf_token

會先從表單中查找關鍵字段,如果獲取,那么返回該值,獲取不到,從請求頭獲取,方式和django的基本一致,畢竟也就這兩種規范方式.

 91         @app.before_request
 92         def csrf_protect():

這兩行代碼表明wtf是如何實現校驗的,通過flask的鈎子函數在每次請求開始時進行校驗,這是在初始化wtf init_app(app)的時候就已經添加了該鈎子函數.

在django裡面,一旦中間件的process_request返回了響應(即任何非None的值),中間件即開始執行響應回調,視圖不再執行.那麼上面的兩行代碼下面好像不停地return了好多次,到底是什麼意思呢?只好再找源碼看看.相關源碼在下面:

    @setupmethod
    def before_request(self, f):
        """Register ``f`` as a hook that runs before every request.

        Typical uses are opening a database connection or loading the
        logged-in user from the session.

        The hook is invoked with no arguments. Returning anything other
        than ``None`` makes that value the response, exactly as if the view
        had returned it, and stops any further request handling.

        :param f: The callable to register.
        :return: ``f`` itself, so this method can be used as a decorator.
        """
        # Hooks stored under the ``None`` key are the application-wide ones.
        app_wide_hooks = self.before_request_funcs.setdefault(None, [])
        app_wide_hooks.append(f)
        return f

可以看到添加鈎子函數的裝飾器執行了什么操作,他只是把鈎子函數放進了一個函數列表里,然后我們看看這個函數列表是什么方式處理的.源碼如下:

 

    def preprocess_request(self):
        """Run the pre-dispatch hooks for the current request.

        First every function in :attr:`url_value_preprocessors` is called
        (application-wide ones, then those of the current blueprint, if
        any) with the request endpoint and view arguments. Afterwards the
        :attr:`before_request_funcs` are called in the same order.

        :return: The first non-``None`` value returned by a
            :meth:`before_request` handler; it is treated as the view's
            return value and the remaining handlers are skipped. ``None``
            if every handler returned ``None``.
        """
        blueprint = _request_ctx_stack.top.request.blueprint
        in_blueprint = blueprint is not None

        # URL value preprocessors: app-wide first, then blueprint-specific.
        preprocessors = self.url_value_preprocessors.get(None, ())
        if in_blueprint and blueprint in self.url_value_preprocessors:
            preprocessors = chain(
                preprocessors, self.url_value_preprocessors[blueprint]
            )
        for preprocessor in preprocessors:
            preprocessor(request.endpoint, request.view_args)

        # before_request handlers: app-wide first, then blueprint-specific.
        handlers = self.before_request_funcs.get(None, ())
        if in_blueprint and blueprint in self.before_request_funcs:
            handlers = chain(handlers, self.before_request_funcs[blueprint])
        for handler in handlers:
            response = handler()
            if response is not None:
                # A handler produced a response; stop here.
                return response

該方法的注釋說明了,如果鈎子函數返回任意不為None的值,那麼該值等同於視圖的響應,後續的請求處理會被終止.所以僅僅return(即返回None)不會中斷請求處理,視圖仍然會被執行.

現在可以解釋def csrf_protect():函數的內容了,即:未開啟防護(WTF_CSRF_ENABLED)或關閉默認檢查(WTF_CSRF_CHECK_DEFAULT)時跳過校驗,請求方式不在保護范圍內時跳過校驗,視圖無效時跳過校驗,藍圖或視圖被豁免時也跳過校驗;其余情況才會調用self.protect()進行校驗.

 

csrf_protect 中會執行 protect ,protect 會執行 validate_csrf(),validate_csrf()是校驗的關鍵,源代碼如下:

def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    """Check if the given data is a valid CSRF token. This compares the given
    signed token to the one stored in the session.

    :param data: The signed CSRF token to be checked.
    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param time_limit: Number of seconds that the token is valid. Default is
        ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
    :param token_key: Key where token is stored in session for comparison.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.

    :raises ValidationError: Contains the reason that validation failed.

    .. versionchanged:: 0.14
        Raises ``ValidationError`` with a specific error message rather than
        returning ``True`` or ``False``.
    """
    # Local import: the constant-time comparison comes from the standard
    # library instead of Werkzeug's ``safe_str_cmp``, which was removed in
    # Werkzeug 2.1.
    import hmac

    secret_key = _get_config(
        secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
        message='A secret key is required to use CSRF.'
    )
    field_name = _get_config(
        token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
        message='A field name is required to use CSRF.'
    )
    time_limit = _get_config(
        time_limit, 'WTF_CSRF_TIME_LIMIT', 3600, required=False
    )

    if not data:
        raise ValidationError('The CSRF token is missing.')

    if field_name not in session:
        raise ValidationError('The CSRF session token is missing.')

    s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')

    try:
        # Verifies the signature and, via ``max_age``, the token's age.
        token = s.loads(data, max_age=time_limit)
    except SignatureExpired:
        raise ValidationError('The CSRF token has expired.')
    except BadData:
        raise ValidationError('The CSRF token is invalid.')

    # Constant-time comparison of the session token with the submitted one.
    if not hmac.compare_digest(session[field_name], token):
        raise ValidationError('The CSRF tokens do not match.')

該方法前面部分就是在獲取相關密鑰和關鍵字,如果沒有自定義的話,這一塊通常不會出問題.後面可以看到,方法會先確認session中存在csrftoken字段,然後解析提交上來的token,最後一步進行比對.所以,wtf是通過比對session中的csrftoken和請求中提交的csrftoken是否一致來完成校驗的.

 所以前后端分離方式開發的話,需要將csrftoken通過接口或者cookie的方式傳給前端,前端將該部分數據取出保存,提交表單的時候帶上.

至於關鍵字,最上面那段代碼寫得很清楚:默認情況下,表單字段是csrf_token,請求頭是X-CSRFToken(或X-CSRF-Token).

 
       


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM