栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch中的l2

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch中的l2

pytorch_l2_normalize.py

import torch
import tensorflow as tf

########## PyTorch Version 1 ################
x = torch.randn(5, 6)
norm_th = x/torch.norm(x, p=2, dim=1, keepdim=True)
norm_th[torch.isnan(norm_th)] = 0 # to avoid nan

########## PyTorch Version 2 ################
norm_th = torch.nn.functional.normalize(x, p=2, dim=1)

########### Equivalent to ############
norm_tf = tf.nn.l2_normalize(x.numpy(), axis=1)

print(norm_th)
print(norm_tf)

reference
https://gist.github.com/EdisonLeeeee/290691c8b1895427024875c3fafece67

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/344532.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号