The sample function is much simpler than beam_search, but note that it must be used together with a list of logits_warper processors; the corresponding processor classes are covered further below. The annotated source of sample follows and is fairly easy to read.
```python
# excerpt from transformers' GenerationMixin.sample (F is torch.nn.functional)
# auto-regressive generation
while True:
    # prepare model inputs
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # forward pass to get next token
    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        cur_len = cur_len + 1
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]  # logits predicted for the current step

    # pre-process distribution
    next_token_scores = logits_processor(input_ids, next_token_logits)
    next_token_scores = logits_warper(input_ids, next_token_scores)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_token_scores,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)
        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # sample: softmax over the warped scores, then draw one token per sequence
    probs = F.softmax(next_token_scores, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

    # finished sentences should have their next token be a padding token
    if eos_token_id is not None:
        assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

    # update generated ids, model inputs, and length for next step
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = self._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    )
    cur_len = cur_len + 1

    # if eos_token was found in one sentence, set sentence to finished
    if eos_token_id is not None:
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

    # stop when each sentence is finished, or if we exceed the maximum length
    if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True

return input_ids
```
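In everyday use this loop is rarely called directly: generate() dispatches to sample() when do_sample=True and assembles the logits_warper list from the temperature/top_k/top_p arguments. Here is a minimal usage sketch (the GPT-2 checkpoint and the prompt are arbitrary choices for illustration):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The meaning of life is", return_tensors="pt").input_ids

torch.manual_seed(0)  # sampling is stochastic, so seed for reproducibility
output_ids = model.generate(
    input_ids,
    do_sample=True,           # routes generation through sample()
    max_length=30,
    temperature=0.7,          # these three arguments become the
    top_k=50,                 # logits_warper list built from the
    top_p=0.95,               # classes discussed below
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no dedicated pad token
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```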
The logits_warper family of processors in the transformers library includes TemperatureLogitsWarper, TopKLogitsWarper, and TopPLogitsWarper.
- TemperatureLogitsWarper: reshapes the predicted distribution, mainly by dividing the logit scores by the temperature hyperparameter. Across most studies, the choice of temperature tends to follow these patterns (a numeric sketch follows the class below):
  - In practice, experiment with multiple temperature values! Text that keeps some randomness without breaking its structure is often the most interesting.
  - With a high temperature, the local structure of the text tends to break down, and most words degenerate into semi-random strings.
  - With a larger temperature, the generated text is more random, interesting, and even creative; sometimes it even produces novel (misspelled) words.
  - With a smaller temperature, the text becomes highly repetitive and predictable, but it also stays much closer to the corpus (highly realistic), with essentially every word drawn from the training data.
  - With a very small temperature, or one approaching 0, temperature sampling reduces to greedy search, which always picks the highest-probability token.
```python
import torch
from transformers import LogitsWarper


class TemperatureLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).

    Args:
        temperature (:obj:`float`):
            The value used to modulate the logits distribution.
    """

    def __init__(self, temperature: float):
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # T > 1 flattens the distribution, T < 1 sharpens it
        scores = scores / self.temperature
        return scores
```
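To make the effect concrete, here is a small self-contained sketch (the logit values are toy numbers chosen for illustration) showing how dividing by the temperature reshapes the softmax distribution:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[4.0, 2.0, 1.0, 0.5]])  # toy logits over a 4-token vocabulary

for t in (0.1, 1.0, 2.0):
    probs = F.softmax(logits / t, dim=-1)
    print(f"T={t}: {[round(p, 3) for p in probs.squeeze().tolist()]}")

# T=0.1 puts almost all mass on the argmax (close to greedy search),
# T=1.0 leaves the distribution unchanged,
# T=2.0 flattens it, giving lower-probability tokens a real chance of being sampled.
```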
- TopKLogitsWarper: sorts the model's predicted distribution for the current step and keeps only the top-k highest-probability entries.
  - Pros: top-k sampling improves generation quality because it discards the low-probability results (removing the tail), which keeps generation from drifting off topic.
  - Cons: the discarded tail may contain many plausible words, leaving us too few candidates to choose from; in other cases the tail contains few useful words, and removing it still leaves enough candidates to generate fairly rich text.
  - Hence the choice of k is critical to the generated result (see the sanity check after the class below).
```python
import torch
from transformers import LogitsWarper


class TopKLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.

    Args:
        top_k (:obj:`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = top_k
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1))  # Safety check
        # torch.topk(scores, top_k)[0][..., -1, None] is the k-th largest score per row;
        # every position whose score falls below that threshold is marked True
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        # set the marked positions to self.filter_value (usually -inf) so softmax zeroes them out
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
```
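A quick sanity check on toy scores (TopKLogitsWarper is importable from transformers; input_ids is a dummy here since this warper never reads it):

```python
import torch
import torch.nn.functional as F
from transformers import TopKLogitsWarper

scores = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0]])  # toy logits over a 5-token vocabulary
dummy_input_ids = torch.tensor([[0]])                # unused by TopKLogitsWarper

warped = TopKLogitsWarper(top_k=2)(dummy_input_ids, scores)
print(warped)                     # tensor([[-inf, 3., -inf, 2., -inf]]) -- only the top 2 survive
print(F.softmax(warped, dim=-1))  # the filtered tokens receive exactly zero probability
```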