栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

batch

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

batch

这个bug发生于写BERT的在线inference代码。错误如下:

Traceback (most recent call last):
  File "classifier.py", line 94, in 
    result = clf.predict(content)
  File "classifier.py", line 77, in predict
    outputs = self.model(ids, mask, token_type_ids)
  File "/data/angchen/env/torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "../src/bert.py", line 29, in forward
    _, output_1 = self.l1(ids, attention_mask = mask, token_type_ids=token_type_ids)
  File "/data/angchen/env/torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/angchen/env/torch/lib/python3.6/site-packages/transformers/models/bert/modeling_bert.py", line 950, in forward
    batch_size, seq_length = input_shape
ValueError: not enough values to unpack (expected 2, got 1)

inference部分代码如下:

46     def tokenize(self, text):
 47         text = text.strip()
 48         inputs = self.tokenizer.encode_plus(
 49             text,
 50             None,
 51             add_special_tokens=True,
 52             max_length=self.max_len,
 53             pad_to_max_length=True, #不足部分填充
 54             return_token_type_ids=True,
 55             truncation=True #超过部分截断
 56             )
 57 
 58         #提取id, attention mask, token type id
 59         ids = torch.tensor(inputs['input_ids'], dtype=torch.long).to(device, dtype=torch.long)
 60         mask = torch.tensor(inputs['attention_mask'], dtype=torch.long).to(device, dtype=torch.long)
 61         token_type_ids = torch.tensor(inputs['token_type_ids'], dtype=torch.long).to(device, dtype=torch.long)
 62 
 63         return ids, mask, token_type_ids
 64 
 65 
 66     def predict(self, content):
 67         result = {}
 68         #如果content为空 (文本分类器可以预测不包含汉字的文本) 则直接返回结果
 69         if not content:
 70             result["category"] = -1
 71             result["catScores"] = []
 72             return result
 73         ids, mask, token_type_ids = self.tokenize(content)
 74         print("ids:n", ids, "size:n", ids.size())
 75         print("mask:n", mask)
 76         print("token_type_ids:n", token_type_ids)
 77         outputs = self.model(ids, mask, token_type_ids)
 78         print(outputs)
 79         """

后来发现原因是预测部分喂给模型的是一个长度为 [200]的tensor,而模型要求tensor至少还有一个维度,比如[1, 200],所以导致了 batch_size, seq_length = input_shape
ValueError: not enough values to unpack (expected 2, got 1)

解决方案:
因为缺少一个维度,所以要增加一个维度,这里用unsqueeze函数。修改tokenize函数代码如下:

46     def tokenize(self, text):
 47         text = text.strip()
 48         inputs = self.tokenizer.encode_plus(
 49             text,
 50             None,
 51             add_special_tokens=True,
 52             max_length=self.max_len,
 53             pad_to_max_length=True, #不足部分填充
 54             return_token_type_ids=True,
 55             truncation=True #超过部分截断
 56             )
 57 
 58         #提取id, attention mask, token type id
 59         ids = torch.tensor(inputs['input_ids'], dtype=torch.long).to(device, dtype=torch.long).unsqueeze(0)
 60         mask = torch.tensor(inputs['attention_mask'], dtype=torch.long).to(device, dtype=torch.long).unsqueeze(0)
 61         token_type_ids = torch.tensor(inputs['token_type_ids'], dtype=torch.long).to(device, dtype=torch.long).unsqueeze(0)
 62 
 63         return ids, mask, token_type_ids

之后就能跑通了

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/588903.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号