tensorflow hub 嘗鮮

今年年初,伴隨著 tensorflow 更新到 1.7.0 版本,Google 發(fā)布了 tensorflow hub。tensorflow hub 的主要目標是為模型提供一種簡便的封裝方式,同時可以簡便地復用已封裝的模型,可以說 tf hub 是為遷移學習而生的。

熟悉自然語言處理的同學都知道大部分 nlp 模型的底層都是 word2vec 詞向量作為一個詞的特征,當然近幾年越來越多的模型會構建于語言模型之上,例如用 ELMo 代替詞向量。其實無論哪種方式,downstream 的任務都建立在這些預訓練好的向量之上,downstream 的任務與底層 embedding 的訓練是高度解耦的。所以完全可以有專門的團隊負責底層 embedding 的優(yōu)化與開發(fā),讓后將它們用 tf hub 封裝成 module 供下游應用團隊使用,這些 module 對于使用人員就是黑盒子,他們無需關心 module 的實現(xiàn)細節(jié)。
在圖像領域也是一樣的,通常一些 downstream 的任務都會建立在一些經(jīng)典的模型(vgg, resnet, mobilenet 等)之上,它們會利用這些模型預訓練好的權重及結構作為特征提取器。

這里結合 nlp 中的 embedding 的封裝和使用介紹一下 tensorflow hub 的細節(jié),安裝方式看 github。主要參考了官網(wǎng)和 github:
https://tensorflow.google.cn/hub/
https://github.com/tensorflow/hub

我們先看看怎么使用一個別人為我們封裝好的模型:

hub_module = hub.Module(self.get_temp_dir())
tokens = tf.constant(["cat", "lizard", "dog"])
embeddings = hub_module(tokens)
with tf.Session() as session:
    session.run(tf.tables_initializer())
    session.run(tf.global_variables_initializer())
    self.assertAllClose(
        session.run(embeddings),
        [[1.11, 2.56, 3.45], [0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])

代碼的第一行表示加載一個封裝好的 tensorflow hub 模型,參數(shù)可以是模型的路徑也可以是一個保存有模型的 http 地址。
第二行創(chuàng)建了一個包含三個字符串的張量。
第三行調用了剛剛創(chuàng)建的 hub 模型,字符串張量作為模型的輸入,embeddings就是模型的輸出了。
調用 session.run 就能得到具體的 embeddings 輸出值。

我第一次見到這個 demo 的時候驚喜之處在于,我們往常使用 embedding 時都需要將舒服的單詞轉換成相應的 id,將 id 作為輸入查詢相應的 embedding。難道這里把轉換操作也封裝進 hub module 里的?為了滿足好奇心我們可以看一下如何封裝這樣一個 module。

hub module 封裝:

def module_fn():
    """Spec function for a token embedding module."""
    tokens = tf.placeholder(shape=[None], dtype=tf.string, name="tokens")

    embeddings_var = tf.get_variable(
        initializer=tf.zeros([vocab_size + num_oov_buckets, embeddings_dim]),
        name=EMBEDDINGS_VAR_NAME,
        dtype=tf.float32)

    lookup_table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=vocabulary_file,
        num_oov_buckets=num_oov_buckets,
    )
    ids = lookup_table.lookup(tokens)
    combined_embedding = tf.nn.embedding_lookup(params=embeddings_var, ids=ids)
    hub.add_signature("default", {"tokens": tokens},
                      {"default": combined_embedding})

spec = hub.create_module_spec(module_fn)
with tf.Graph().as_default():
      m = hub.Module(spec)
      p_embeddings = tf.placeholder(tf.float32)
      load_embeddings = tf.assign(m.variable_map[EMBEDDINGS_VAR_NAME],
                                  p_embeddings)

      with tf.Session() as sess:
        sess.run([load_embeddings], feed_dict={p_embeddings: embeddings})
        m.export(export_path, sess)

根據(jù)上面代碼,創(chuàng)建 hub module 的流程如下:
1、調用 hub.create_module_spec 創(chuàng)建一個 spec,函數(shù)的參數(shù)是 module 的計算圖創(chuàng)建函數(shù)
2、調用 hub.Module 創(chuàng)建一個 module 對象,參數(shù)是上一步創(chuàng)建的 spec
3、在 session 中訓練模型,這個 demo 里面沒有訓練,而是直接利用 tf 的賦值操作將一個 numpy 矩陣賦值給了模型的參數(shù)
4、調用 export 函數(shù)將當前的 session 保存到某個路徑中

在 module_fn 函數(shù)的最后調用了 hub.add_signature,第一個參數(shù)是創(chuàng)建的這個 hub module 的名稱,第二個參數(shù)是 module 的輸入,它是一個字典,支持多個輸入,第三個參數(shù)是 module 能提供的輸出,同樣也是字典,支持輸出多個數(shù)據(jù)。
可以看到,module 的輸入定義為一個 字符創(chuàng)類型的 placeholder,然后利用 index_table_from_file 創(chuàng)建了一個 lookup_table,這個 table 就可以將字符串轉化為相應的 id,這里就解答了之前的好奇。

總結

第一次看到 tensorflow hub 就覺得很優(yōu)雅,以前做 nlp 的工作會花大量的時間在準備數(shù)據(jù)上,而利用 tf hub 以后所有的任務都可以使用統(tǒng)一的 embedding module,并且可以直接將字符串作為輸入,不用再手動轉換。
hub module 在使用時還能設定為參數(shù)可訓練或者參數(shù)不可訓練,這樣對于不同的任務就能有更靈活的選擇。對于一些訓練樣本較少的情況,可以凍結底層 module 的參數(shù),做完全的遷移學習。
首次嘗試 tensorflow hub 還是相當欣喜的,以后也會盡量使用。


?著作權歸作者所有,轉載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容