Beam Search的問題
先解釋一下什么要對Beam Search進行改進,因為Beam Search雖然比貪心強了不少,但還是會生成出空洞、重復、前后矛盾的文本,如果你有文本生成經驗,一定對這些現象并不陌生,在語言模型還不像如今的BERT、GPT這么厲害的時候,這種現象更加明顯,
沒有經驗也沒關系,我們來看一個論文里面的例子,輸入模型的引文(context)
"The study, published in the Proceedings of the They were cattle called Bolivian Cavalleros; they live in a National Academy of Sciences of the United States of remote desert uninterrupted by town, and they speak huge, America (PNAS), was conducted by researchers from the beautiful, paradisiacal Bolivian linguistic thing. They say, Universidad Nacional Autónoma de México (UNAM) and
GPT-2模型, Beam Search, num_beams=32的生成結果:
'Lunch, marge.' They don't tell what the lunch is," director the Universidad Nacional Autónoma de México Professor Chuperas Omwell told Sky News. "They've only (UNAM/Universidad Nacional Autónoma de been talking to scientists, like we're being interviewed by TV México/Universidad Nacional Autónoma de reporters. We don't even stick around to be interviewed by México/Universidad Nacional Autónoma de TV reporters. Maybe that's how they figured out that they're México/Universidad Nacional Autónoma de ...”
可以發現即使是如今最頂級的語言模型加上足夠長的引文輸入,還是無法得到高質量的生成結果,
論文認為這種問題是由于這種試圖最大化序列條件概率的解碼策略從根上就有問題,他們對比了給定同樣引文的情況下人類續寫和機器生成的詞用語言模型計算出來的概率,如下圖所示,人類選擇的詞(橙線)并不是像機器選擇的(藍線)那樣總是那些條件概率最大的詞,從生成的結果也可以看出,機器生成的結果有大量重復,
機器選詞和人類選詞的概率對比圖
解決對策
人們其實嘗試了各種辦法對Beam Search進行改進,其實都很好理解,這篇論文總結的也比較到位,
隨機采樣
第一種方法是用隨機采樣(sampling)代替取概率最大的詞,采樣的依據就是解碼器輸出的詞典中每個詞的概率分布,相比于按概率“掐尖”,這樣會增大所選詞的范圍,引入更多的隨機性,當時那篇論文的結論就是這種隨機采樣的方法遠好于Beam Search,但這其實也是有條件的,隨機采樣容易產生前后不一致的問題,而在開放閑聊領域,生成文本的長度都比較短,這種問題就被自然的淡化了,
采樣的時候有一個可以控制的超引數,稱為溫度(temperature, ),解碼器的輸出層后面通常會跟一個softmax函式來將輸出概率歸一化,通過改變 可以控制概率分布的形貌,softmax的公式如下,當 大的時候,概率分布趨向平均,隨機性增大;當 小的時候,概率密度趨向于集中,即強者愈強,隨機性降低,會更多地采樣出“放之四海而皆準”的詞匯,
top-k采樣
這個方法就是在采樣前將輸出的概率分布截斷,取出概率最大的k個詞構成一個集合,然后將這個子集詞的概率再歸一化,最后從新的概率分布中采樣詞匯,這個辦法據說可以獲得比Beam Search好很多的效果,但也有一個問題,就是這個k不太好選,
While top-k sampling leads to considerably higher quality text than either beam search or sampling from the full distribution, the use of a constant k is sub-optimal across varying contexts.
為啥呢?因為這個概率分布變化比較大,有時候可能很均勻(flat),有的時候比較集中(peaked),對于集中的情況還好說,當分布均勻時,一個較小的k容易丟掉很多優質候選詞,但如果k定的太大,這個方法又會退化回普通采樣,
兩種分布,左邊是均勻的,右邊是集中的
核采樣(Nucleus sampling)
首先表示我不確定這個翻譯是不是對的,
這是這篇論文提出的方式,也是相比前面那些都更好的采樣方式,這個方法不再取一個固定的k,而是固定候選集合的概率密度和在整個概率分布中的比例,也就是構造一個最小候選集V ,使得

選出來這個集合之后也和top-k采樣一樣,重新歸一化集合內詞的概率,并把集合外詞的概率設為0,這種方式也稱為top-p采樣,
論文有一個圖,對比了這幾種采樣方式的效果,
效果對比圖,紅字是前后不符,藍字是重復,Nucleus效果拔群,
懲罰重復
為了解決重復問題,還可以通過懲罰因子將出現過詞的概率變小或者強制不使用重復詞來解決,懲罰因子來自于同樣廣為流傳的《CTRL: A Conditional Transformer Language Model for Controllable Generation》[2],如果大家感興趣的話后面可以專門寫一期可控文本生成方向的解讀,
代碼決議
其實上述各種采樣方式在HuggingFace的庫里都已經實作了(感動!),我們來看一下代碼,
先看top-k和top-p采樣
1 # 代碼輸入的是logits,而且考慮很周全(我感覺漏了考慮k和p都給了的情況,這應該是不合適的) 2 # 巧妙地使用了torch.cumsum 3 # 避免了一個詞都選不出來的尷尬情況 4 def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=https://www.cnblogs.com/cs-markdown10086/p/-float("Inf"), min_tokens_to_keep=1): 5 """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 6 Args: 7 logits: logits distribution shape (batch size, vocabulary size) 8 if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 9 if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 10 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 11 Make sure we keep at least min_tokens_to_keep per batch example in the output 12 From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 13 """ 14 if top_k > 0: 15 top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check 16 # Remove all tokens with a probability less than the last token of the top-k 17 indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 18 logits[indices_to_remove] = filter_value 19 20 if top_p < 1.0: 21 sorted_logits, sorted_indices = torch.sort(logits, descending=True) 22 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 23 24 # Remove tokens with cumulative probability above the threshold (token with 0 are kept) 25 sorted_indices_to_remove = cumulative_probs > top_p 26 if min_tokens_to_keep > 1: 27 # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 28 sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 29 # Shift the indices to the right to keep also the first token above the threshold 30 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 31 sorted_indices_to_remove[..., 0] = 0 32 33 # scatter sorted tensors to original indexing 34 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 35 logits[indices_to_remove] = filter_value 36 return logits
再看看重復懲罰
1 # 輸入的同樣是logits(lprobs) 2 # 同時輸入了之前出現過的詞以及懲罰系數(大于1的) 3 # 考慮到了logit是正和負時處理方式應該不一樣 4 def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): 5 """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """ 6 for i in range(batch_size * num_beams): 7 for previous_token in set(prev_output_tokens[i].tolist()): 8 # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability 9 if lprobs[i, previous_token] < 0: 10 lprobs[i, previous_token] *= repetition_penalty 11 else: 12 lprobs[i, previous_token] /= repetition_penalty
最后是重復詞去除
1 # 這個函式將會回傳一個不可使用的詞表 2 # 生成n-gram的巧妙方式大家可以借鑒一下 3 # 下面是一個3-gram的例子 4 # a = [1,2,3,4,5] 5 # for ngram in zip(*[a[i:] for i in range(3)]): 6 # print(ngram) 7 def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len): 8 # Copied from fairseq for no_repeat_ngram in beam_search""" 9 if cur_len + 1 < no_repeat_ngram_size: 10 # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet 11 return [[] for _ in range(num_hypos)] 12 generated_ngrams = [{} for _ in range(num_hypos)] 13 for idx in range(num_hypos): 14 gen_tokens = prev_input_ids[idx].numpy().tolist() 15 generated_ngram = generated_ngrams[idx] 16 # 就是這巧妙的一句 17 for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): 18 prev_ngram_tuple = tuple(ngram[:-1]) 19 generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] 20 def _get_generated_ngrams(hypo_idx): 21 # Before decoding the next token, prevent decoding of ngrams that have already appeared 22 start_idx = cur_len + 1 - no_repeat_ngram_size 23 ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist()) 24 return generated_ngrams[hypo_idx].get(ngram_idx, []) 25 banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] 26 return banned_tokens
以上這些代碼應該在哪里呼叫相信看上一篇文章的朋友都應該知道了,這里就放出來最核心的差異,
1 if do_sample: 2 # 這是今天的采樣方式 3 _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) 4 # Top-p/top-k filtering,這一步重建了候選集 5 _scores = top_k_top_p_filtering( 6 _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 7 ) # (batch_size * num_beams, vocab_size) 8 # re-organize to group the beam together to sample from all beam_idxs 9 _scores = _scores.contiguous().view( 10 batch_size, num_beams * vocab_size 11 ) # (batch_size, num_beams * vocab_size) 12 13 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) 14 probs = F.softmax(_scores, dim=-1) 15 # 采樣 16 next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2) 17 # Compute next scores 18 next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) 19 # sort the sampled vector to make sure that the first num_beams samples are the best 20 next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1) 21 next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2) 22 else: 23 # 這是昨天的beam search方式 24 # 直接將log概率相加求條件概率 25 next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) 26 27 # re-organize to group the beam together (we are keeping top hypothesis accross beams) 28 next_scores = next_scores.view( 29 batch_size, num_beams * vocab_size 30 ) # (batch_size, num_beams * vocab_size) 31 32 next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
OK,謝謝各位看到這里,祝大家生成出高質量的文本!
參考資料
[1]
The Curious Case of Neural Text Degeneration: https://arxiv.org/abs/1904.09751
[2]
CTRL: A Conditional Transformer Language Model for Controllable Generation: https://arxiv.org/abs/1909.05858
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/285822.html
標籤:其他
下一篇:Celery異步任務
