在10.7 節中,我們介紹了編碼器-解碼器架構,以及端到端訓練它們的標準技術。然而,當談到測試時間預測時,我們只提到了 貪心策略,我們在每個時間步選擇下一個預測概率最高的標記,直到在某個時間步,我們發現我們已經預測了特殊的序列結尾“”標記。在本節中,我們將從形式化這種貪婪搜索策略開始,并確定從業者往往會遇到的一些問題。隨后,我們將該策略與兩種替代方案進行比較:窮舉搜索(說明性但不實用)和 波束搜索(實踐中的標準方法)。
讓我們從設置我們的數學符號開始,借用第 10.7 節中的約定。隨時步t′,解碼器輸出表示詞匯表中每個標記出現在序列中的概率的預測(可能的值 yt′+1, 以先前的標記為條件 y1,…,yt′和上下文變量c,由編碼器產生以表示輸入序列。為了量化計算成本,表示為Y輸出詞匯表(包括特殊的序列結束標記“”)。我們還將輸出序列的最大標記數指定為 T′. 我們的目標是搜索所有的理想輸出 O(|Y|T′)可能的輸出序列。請注意,這稍微高估了不同輸出的數量,因為在“”標記出現之后沒有后續標記。然而,出于我們的目的,這個數字大致反映了搜索空間的大小。
10.8.1。貪心搜索
考慮第 10.7 節中的簡單貪婪搜索策略 。在這里,隨時步t′,我們只需從中選擇條件概率最高的標記 Y, IE,
(10.8.1)yt′=argmaxy∈YP(y∣y1,…,yt′?1,c).
一旦我們的模型輸出“”(或者我們達到最大長度 T′) 輸出序列完成。
這個策略看似合理,其實還不錯!考慮到它在計算上的要求是多么的低,你很難獲得更多的收益。然而,如果我們暫時擱置效率,搜索最有可能的序列似乎更合理,而不是(貪婪選擇的)最有可能的標記序列。事實證明,這兩個對象可能完全不同。最可能的序列是最大化表達式的序列 ∏t′=1T′P(yt′∣y1,…,yt′?1,c). 在我們的機器翻譯示例中,如果解碼器真正恢復了潛在生成過程的概率,那么這將為我們提供最有可能的翻譯。不幸的是,不能保證貪心搜索會給我們這個序列。
讓我們用一個例子來說明它。假設輸出字典中有四個標記“A”、“B”、“C”和“”。在 圖10.8.1中,每個時間步下的四個數字分別代表在該時間步生成“A”、“B”、“C”、“”的條件概率。
圖 10.8.1在每個時間步,貪婪搜索選擇條件概率最高的標記。
在每個時間步,貪心搜索選擇條件概率最高的標記。因此,將預測輸出序列“A”、“B”、“C”和“”(圖 10.8.1)。這個輸出序列的條件概率是 0.5×0.4×0.4×0.6=0.048.
接下來,讓我們看一下圖 10.8.2中的另一個例子。與圖 10.8.1不同,在時間步 2 中,我們選擇圖 10.8.2中的標記“C” ,它具有第二高的條件概率。
圖 10.8.2每個時間步下的四個數字代表在該時間步生成“A”、“B”、“C”和“”的條件概率。在時間步 2,選擇具有第二高條件概率的標記“C”。
由于時間步3所基于的時間步1和2的輸出子序列已經從圖10.8.1中的“A”和“B”變為圖10.8.2 中的“A”和“C” ,圖 10.8.2中每個標記在時間步長 3 的條件概率也發生了變化 。假設我們在時間步 3 選擇標記“B”。現在時間步 4 以前三個時間步“A”、“C”和“B”的輸出子序列為條件,這與“A”不同、“B”、“C”在圖 10.8.1中。因此,圖 10.8.2中第 4 步生成每個 token 的條件概率 也與 圖 10.8.1不同. 因此, 圖 10.8.2中輸出序列“A”、“C”、“B”和“”的條件概率為 0.5×0.3×0.6×0.6=0.054,大于圖 10.8.1中的貪心搜索。在本例中,貪心搜索得到的輸出序列“A”、“B”、“C”、“”并不是最優序列。
10.8.2。窮舉搜索
如果目標是獲得最可能的序列,我們可以考慮使用 窮舉搜索:窮舉所有可能的輸出序列及其條件概率,然后輸出得分最高的預測概率。
雖然這肯定會給我們想要的東西,但它的計算成本卻高得令人望而卻步 O(|Y|T′),序列長度呈指數增長,詞匯量很大。例如,當|Y|=10000和T′=10,我們需要評估1000010=1040序列。與實際應用程序相比,這些數字很小,但已經超出了任何可預見的計算機的能力。另一方面,貪心搜索的計算成本是 O(|Y|T′): 奇跡般地便宜,但遠非最佳。例如,當|Y|=10000和 T′=10, 我們只需要評估10000×10=105 序列。
10.8.3。波束搜索
您可以將序列解碼策略視為位于頻譜上, 波束搜索在貪婪搜索的效率和窮舉搜索的最優性之間做出折衷。波束搜索的最直接版本的特征在于單個超參數, 波束大小,k. 在時間步 1,我們選擇k具有最高預測概率的標記。他們每個人都將是第一個令牌k候選輸出序列,分別。在隨后的每個時間步,基于k上一時間步的候選輸出序列,我們繼續選擇k具有最高預測概率的候選輸出序列 k|Y|可能的選擇。
圖 10.8.3束搜索過程(束大小:2,輸出序列的最大長度:3)。候選輸出序列是A, C,AB,CE,ABD, 和CED.
圖 10.8.3舉例說明了 beam search 的過程。假設輸出詞匯表只包含五個元素:Y={A,B,C,D,E},其中之一是“”。令波束大小為 2,輸出序列的最大長度為 3。在時間步長 1,假設具有最高條件概率的標記P(y1∣c)是A 和C. 在時間步 2,對于所有y2∈Y,我們計算
(10.8.2)P(A,y2∣c)=P(A∣c)P(y2∣A,c),P(C,y2∣c)=P(C∣c)P(y2∣C,c),
并在這十個值中選擇最大的兩個,比如說 P(A,B∣c)和P(C,E∣c). 然后在第 3 步,對于所有y3∈Y, 我們計算
(10.8.3)P(A,B,y3∣c)=P(A,B∣c)P(y3∣A,B,c),P(C,E,y3∣c)=P(C,E∣c)P(y3∣C,E,c),
并在這十個值中選擇最大的兩個,比如說 P(A,B,D∣c)和 P(C,E,D∣c).結果,我們得到六個候選輸出序列:(i)A; (二)C; (三)A, B; (四)C,E; (五)A,B, D; (六)C,E,D.
最后,我們根據這六個序列得到最終候選輸出序列的集合(例如,丟棄包括“”和“”之后的部分)。然后我們選擇以下得分最高的序列作為輸出序列:
(10.8.4)1Lαlog?P(y1,…,yL∣c)=1Lα∑t′=1Llog?P(yt′∣y1,…,yt′?1,c),
在哪里L是最終候選序列的長度, α通常設置為 0.75。由于較長的序列在(10.8.4)的總和中具有更多的對數項,因此項Lα在分母中懲罰長序列。
beam search的計算成本是 O(k|Y|T′). 這個結果介于貪婪搜索和窮舉搜索之間。貪心搜索可以看作是波束大小為 1 時出現的波束搜索的特例。
10.8.4。概括
序列搜索策略包括貪婪搜索、窮舉搜索和波束搜索。波束搜索通過其對波束大小的靈活選擇,在準確性與計算成本之間進行權衡。
10.8.5。練習
我們可以將窮舉搜索視為一種特殊類型的波束搜索嗎?為什么或者為什么不?
在10.7 節的機器翻譯問題中應用集束搜索 。光束大小如何影響翻譯結果和預測速度?
在第 9.5 節中,我們使用語言建模來生成遵循用戶提供的前綴的文本。它使用哪種搜索策略?你能改進它嗎?
Discussions
-
解碼器
+關注
關注
9文章
1147瀏覽量
40875 -
pytorch
+關注
關注
2文章
808瀏覽量
13330
發布評論請先 登錄
相關推薦
評論