pytorch_l2_normalize.py
# Compare ways of L2-normalising each row of a 2-D tensor in PyTorch and
# TensorFlow: every row v becomes v / max(||v||_2, eps).
import torch

try:
    import tensorflow as tf
except ImportError:  # TensorFlow is optional; it is only used for the cross-check.
    tf = None

# Lower bound for the divisor; matches the default eps of
# torch.nn.functional.normalize.
EPS = 1e-12


def l2_normalize(t: torch.Tensor, dim: int = 1, eps: float = EPS) -> torch.Tensor:
    """Return ``t`` divided by its L2 norm along ``dim``.

    The norm is clamped below at ``eps``, so an all-zero slice maps to zeros
    rather than to 0/0 = NaN. NaNs already present in ``t`` are propagated,
    not hidden.

    Args:
        t: Tensor to normalise.
        dim: Dimension along which the norm is taken.
        eps: Smallest allowed divisor.

    Returns:
        Tensor of the same shape as ``t``.
    """
    return t / t.norm(p=2, dim=dim, keepdim=True).clamp_min(eps)


x = torch.randn(5, 6)

########## PyTorch Version 1 ################
# Clamping the norm replaces the old "divide, then overwrite NaNs with 0"
# approach, which also silently zeroed any NaN that was already in ``x``.
norm_th_v1 = l2_normalize(x, dim=1)

########## PyTorch Version 2 ################
norm_th = torch.nn.functional.normalize(x, p=2, dim=1, eps=EPS)

########### Equivalent to ############
# tf.nn.l2_normalize clamps the *squared* norm at its epsilon (default 1e-12),
# i.e. the norm at 1e-6, so it can differ from the PyTorch versions only for
# rows whose norm is below 1e-6.
norm_tf = tf.nn.l2_normalize(x.numpy(), axis=1) if tf is not None else None

print(norm_th)
if norm_tf is None:
    print("TensorFlow is not installed; skipped tf.nn.l2_normalize")
else:
    print(norm_tf)
Reference (original gist):
https://gist.github.com/EdisonLeeeee/290691c8b1895427024875c3fafece67



