slim.arg_scope()的使用


【https://blog.csdn.net/u013921430 轉載】

     slim是一種輕量級的tensorflow庫,可以使模型的構建,訓練,測試都變得更加簡單。slim庫中對很多常用的函數進行了定義,slim.arg_scope()是slim庫中經常用到的函數之一。函數的定義如下;


   
   
  
  
          
  1. @tf_contextlib.contextmanager
  2. def arg_scope(list_ops_or_scope, **kwargs):
  3. """Stores the default arguments for the given set of list_ops.
  4. For usage, please see examples at top of the file.
  5. Args:
  6. list_ops_or_scope: List or tuple of operations to set argument scope for or
  7. a dictionary containing the current scope. When list_ops_or_scope is a
  8. dict, kwargs must be empty. When list_ops_or_scope is a list or tuple,
  9. then every op in it need to be decorated with @add_arg_scope to work.
  10. **kwargs: keyword=value that will define the defaults for each op in
  11. list_ops. All the ops need to accept the given set of arguments.
  12. Yields:
  13. the current_scope, which is a dictionary of {op: {arg: value}}
  14. Raises:
  15. TypeError: if list_ops is not a list or a tuple.
  16. ValueError: if any op in list_ops has not be decorated with @add_arg_scope.
  17. """
  18. if isinstance(list_ops_or_scope, dict):
  19. # Assumes that list_ops_or_scope is a scope that is being reused.
  20. if kwargs:
  21. raise ValueError( 'When attempting to re-use a scope by suppling a'
  22. 'dictionary, kwargs must be empty.')
  23. current_scope = list_ops_or_scope.copy()
  24. try:
  25. _get_arg_stack().append(current_scope)
  26. yield current_scope
  27. finally:
  28. _get_arg_stack().pop()
  29. else:
  30. # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
  31. if not isinstance(list_ops_or_scope, (list, tuple)):
  32. raise TypeError( 'list_ops_or_scope must either be a list/tuple or reused'
  33. 'scope (i.e. dict)')
  34. try:
  35. current_scope = current_arg_scope().copy()
  36. for op in list_ops_or_scope:
  37. key_op = _key_op(op)
  38. if not has_arg_scope(op):
  39. raise ValueError( '%s is not decorated with @add_arg_scope',
  40. _name_op(op))
  41. if key_op in current_scope:
  42. current_kwargs = current_scope[key_op].copy()
  43. current_kwargs.update(kwargs)
  44. current_scope[key_op] = current_kwargs
  45. else:
  46. current_scope[key_op] = kwargs.copy()
  47. _get_arg_stack().append(current_scope)
  48. yield current_scope
  49. finally:
  50. _get_arg_stack().pop()

     如注釋中所說,這個函數的作用是給list_ops中的內容設置默認值。但是每個list_ops中的每個成員需要用@add_arg_scope修飾才行。所以使用slim.arg_scope()有兩個步驟:

  1. 使用@slim.add_arg_scope修飾目標函數
  2.  slim.arg_scope()為目標函數設置默認參數.

     例如如下代碼;首先用@slim.add_arg_scope修飾目標函數fun1(),然后利用slim.arg_scope()為它設置默認參數。


   
   
  
  
          
  1. import tensorflow as tf
  2. slim =tf.contrib.slim
  3. @slim.add_arg_scope
  4. def fun1(a=0,b=0):
  5. return (a+b)
  6. with slim.arg_scope([fun1],a= 10):
  7. x=fun1(b= 30)
  8. print(x)

     運行結果為:

40
  
  
 
 
         

    平常所用到的slim.conv2d( ),slim.fully_connected( ),slim.max_pool2d( )等函數在他被定義的時候就已經添加了@add_arg_scope。以slim.conv2d( )為例;


   
   
  
  
          
  1. @ add_arg_scope
  2. def convolution(inputs,
  3. num_outputs,
  4. kernel_size,
  5. stride=1,
  6. padding='SAME',
  7. data_format=None,
  8. rate=1,
  9. activation_fn=nn.relu,
  10. normalizer_fn=None,
  11. normalizer_params=None,
  12. weights_initializer=initializers.xavier_initializer(),
  13. weights_regularizer=None,
  14. biases_initializer=init_ops.zeros_initializer(),
  15. biases_regularizer=None,
  16. reuse=None,
  17. variables_collections=None,
  18. outputs_collections=None,
  19. trainable=True,
  20. scope=None):

     所以,在使用過程中可以直接slim.conv2d( )等函數設置默認參數。例如在下面的代碼中,不做單獨聲明的情況下,slim.conv2d, slim.max_pool2d, slim.avg_pool2d三個函數默認的步長都設為1,padding模式都是'VALID'的。但是也可以在調用時進行單獨聲明。這種參數設置方式在構建網絡模型時,尤其是較深的網絡時,可以節省時間。


   
   
  
  
          
  1. with slim.arg_scope(
  2. [slim.conv2d, slim.max_pool2d, slim.avg_pool2d],stride = 1, padding = 'VALID'):
  3. net = slim.conv2d(inputs, 32, [ 3, 3], stride = 2, scope = 'Conv2d_1a_3x3')
  4. net = slim.conv2d(net, 32, [ 3, 3], scope = 'Conv2d_2a_3x3')
  5. net = slim.conv2d(net, 64, [ 3, 3], padding = 'SAME', scope = 'Conv2d_2b_3x3')

@修飾符     

     其實這種用法是python中常用到的。在python中@修飾符放在函數定義的上方,它將被修飾的函數作為參數,並返回修飾后的同名函數。形式如下;


   
   
  
  
          
  1. @fun_a #等價於fun_a(fun_b)
  2. def fun_b():

      這在本質上講跟直接調用被修飾的函數沒什么區別,但是有時候也有用處,例如在調用被修飾函數前需要輸出時間信息,我們可以在@后方的函數中添加輸出時間信息的語句,這樣每次我們只需要調用@后方的函數即可。


   
   
  
  
          
  1. def funs(fun,factor=20):
  2. x=fun()
  3. print(factor*x)
  4. @funs #等價funs(add(),fator=20)
  5. def add(a=10,b=20):
  6. return(a+b)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM