栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

transformers.generator

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

transformers.generator

        sample函数相较于beam_search函数要简单的多,但是需要注意的一点是,sample需要搭配logits_warper处理器列表使用,相应的处理器函数在下面。sample函数的源码解释如下,比较浅显易懂。

# auto-regressive generation
while True:
    # prepare model inputs
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # forward pass to get next token
    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        cur_len = cur_len + 1
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :] #获取cur-step的预测输出

    # pre-process distribution
    next_token_scores = logits_processor(input_ids, next_token_logits)
    next_token_scores = logits_warper(input_ids, next_token_scores)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_token_scores,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # sample
    probs = F.softmax(next_token_scores, dim=-1) #对重构后的预测分布进行softmax
    # 多项式函数取样
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

    # finished sentences should have their next token be a padding token
    if eos_token_id is not None:
        assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

    # update generated ids, model inputs, and length for next step
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = self._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    )
    cur_len = cur_len + 1

    # if eos_token was found in one sentence, set sentence to finished
    if eos_token_id is not None:
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

    # stop when each sentence is finished, or if we exceed the maximum length
    if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True
return input_ids

        transfomrer库中的logits_warper系列处理函数包括TemperatureLogitsWarper、TopKLogitsWarper和TopPLogitsWarper。

  • TemperatureLogitsWarper函数:对生成的预测分布概率进行重构,主要是将分布分数除于Temperature超参。在大多数研究中, tempreature的选择,往往呈现如下规律:
    • 实际应用中,往往experiment with multiple temperature values! 当保持了一定的随机性又能不破坏结构时,往往会得到有意思的生成文本。
    • 当 设置高 temperature时,文本局部结构往往会被破坏,大多数词可能会时semi-random strings 的形式。
    • 当temperatures较大时, 生成的文本更具有随机性( random)、趣味性( interesting),甚至创造性( creative); 甚至有些时候能发现一些新词(misspelled words) 。
    • 当temperature较小时,会引发极大的 repetitive 和predictable文本,但是文本内容往往更贴合语料(highly realistic),基本所有的词都来自与语料库。
    • 当 temperature 设置为较小或者0的值时, Temperature Sampling 等同于 每次选择最大概率的 Greedy Search。
class TemperatureLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).

    Args:
        temperature (:obj:`float`):
            The value used to module the logits distribution.
    """

    def __init__(self, temperature: float):
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores = scores / self.temperature
        return scores
  • TopKLogitsWarper函数:对模型cur-step预测的概率分布进行排序,保留前topK个分布。
    • 优点:基本top k的采样方法,能够提升生成质量,因为它会把概率较低的结果丢弃( removing the tail),因此能使得生成过程不那么偏离主题。
    • 缺点:丢弃掉的部分(Tail)可能会包含很多的词语,这导致我们能选择的词汇较少。而在一些情况下,丢弃掉大部分可能包含的词汇较少,我们能生成较为丰富的文本。

    • 因此, k 值的选择对于生成结果极其重要。

class TopKLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.

    Args:
        top_k (:obj:`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = top_k
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        #torch.topk(scores, top_k)[0][..., -1]获得top_k排序后最后一个值,即把分布中小于该值的位置设置为true。
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        #将indices_to_remove中为true的值替换为self.filter_value(一般为-inf)
        scores = scores.masked_fill(indices_to_remove, self.filter_value) 
        return scores
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/580860.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号