首頁 > 軟體

TensorFlow和Numpy矩陣操作中axis理解及axis=-1的解釋

2022-03-24 13:01:21

1. axis的基本使用

axis常常用在numpy和tensorflow中用到,作為對矩陣(張量)進行操作時需要指定的重要引數之一。設定axis=-1,0,1...,用來指定執行操作的資料如何劃分。

一句話解釋:設axis=i,則沿著第i個下標變化的方向進行操作![1]

簡單例子就不舉了,其他部落格有很多,這裡舉一個稍微複雜一點的三維矩陣的例子:

設embeddings是一個shape=[3,4,5]的矩陣,如下:

embeddings =    [[[-0.30166972  0.25741747 -0.07442257  0.24321035 -0.3538919 ]
                  [-0.22572032  0.1288028  -0.4686908  -0.07217035  0.05287632]
                  [ 0.15845934  0.07064888  0.00922218  0.2841002  -0.24992025]
                  [ 0.43347922 -0.43738696 -0.08176881  0.34185413 -0.2826353 ]]
                
                 [[-0.08590135  0.06792518 -0.07807922 -0.28746927 -0.10613027]
                  [ 0.07476929  0.132256   -0.0926154   0.39621904  0.2497718 ]
                  [-0.15389556  0.0867373   0.19403657 -0.11003655  0.317669  ]
                  [ 0.3949038  -0.17275128  0.34710506 -0.02576578 -0.17427891]]
                
                 [[-0.27703786  0.02631402  0.22129896 -0.07714707  0.41439041]
                  [-0.08512023  0.19059369 -0.13418713 -0.12881753 -0.26143318]
                  [-0.333749    0.27034065  0.45429572 -0.46164128 -0.3955955 ]
                  [ 0.24430516 -0.3841647   0.37126407 -0.463441   -0.1441828 ]]]

對embeddings矩陣執行下面操作:

a = tf.math.argmax(embeddings, axis=-1)   # tf.math.argmax=tf.argmax,用來返回最大數值對應的index
b = tf.math.argmax(embeddings, axis=1)
c = tf.math.argmax(embeddings, axis=0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
   
    print(embeddings.eval())   # a.eval 在列印時,等同於 sess.run(a)
    print(a.eval())
    print(b.eval())
    print(c.eval())

得到的結果是:

[[1 1 3 0]
 [1 3 4 0]
 [4 1 2 2]]   # axis=-1, shape=[3,4]
 
[[3 0 2 3 1]
 [3 1 3 1 2]
 [3 2 2 0 0]]   # axis=1, shape=[3,5]
 
[[1 0 2 0 2]
 [1 2 1 1 1]
 [0 2 2 0 1]
 [0 1 2 0 2]]   # axis=0, shape=[4,5]

看懂了嗎?參考上面的一句話解釋,再結合矩陣的下標表示理解一下。

剛剛的矩陣寫成下標表示就是:

embeddings = [[[a000,a001,a002,a003,a004],
               [a010,a011,a012,a013,a014],
               [a020,a021,a022,a023,a024],              
               [a030,a031,a032,a033,a034]]
              
              [[a100,a101,a102,a103,a104],
               [a110,a111,a112,a113,a114],
               [a120,a121,a122,a123,a124],              
               [a130,a131,a132,a133,a134]]
              
              [[a200,a201,a202,a203,a204],
               [a210,a211,a212,a213,a214],
               [a220,a221,a222,a223,a224],              
               [a230,a231,a232,a233,a234]]                

以axis=0為例,則沿著第0個下標(最左邊的下標)變化的方向進行操作,也就是將除了第0個下標外,其他兩個下標都相同的部分分成一組,然後再進行操作。具體分組如下(省略了一些組):

從上圖可以看出,每3個數分成一組,所以現在總共是分了4*5個組(所以最終返回的結果也是一個shape=[4,5]的矩陣),對每個組都執行一次 reduce_max操作,將每個組的三個數中數值最大的數的index返回構成矩陣即可。

這裡需要特別說明一下axis=-1的操作,可能對python不熟悉的人會不理解這裡的-1是哪個維度。在pyhton中,-1代表倒數第一個,也就是說,假如你的矩陣shape=[3,4,5],那麼對這個矩陣來說,axis=-1,其實也就等於axis=2。因為這是個三維矩陣,所以axis可能的取值為0,1,2,所以最後一個就是2。你可以自己試試看兩個取值結果是否相同。

2. 對axis的理解

通過上面的例子,你可能已經發現了,axis是將矩陣進行分組,然後再操作。而分組則意味著會降維。

以剛剛的例子,原始矩陣的shape=[3,4,5],取axis=0再進行操作後,得到的矩陣shape=[4,5]。同樣的,取axis=1再進行操作後,得到的矩陣shape=[3,5]。取axis=-1(axis=2)再操作後,shape=[3,4]。掌握這一點,能有利於你在神經網路中的變換或是資料操作中明確矩陣變換前後的形狀,從而加快對模型的理解。

總結

到此這篇關於TensorFlow和Numpy矩陣操作中axis理解及axis=-1解釋的文章就介紹到這了,更多相關矩陣操作中axis=-1的解釋內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!

參考資料

[1] 這個一句話解釋來源於:https://www.jb51.net/article/242077.htm


IT145.com E-mail:sddin#qq.com