栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

tensorflow2.x 高阶操作

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

tensorflow2.x 高阶操作

高阶操作
        • where
        • scatter_nd
        • tf.meshgrid()
        • tf.meshgrid()介绍
        • 例:画出函数풛 = 풔풊풏 풙 + 풔풊풏(풚)的曲面和等高线

where
  • 一个参数:condition=mask, x=None, y=None, name=None
    • 返回 mask中值为True的元素的坐标
  • 3个参数:condition=mask, x=A, y=B, name=None
    • condition:为一个True和False的tensor
    • x,y必须和condition形状一样
    • 如果为True,将A中对应的值填入;如果为False,将B中对应的值填入
    • 返回的就是这个由A、B中元素组成的新的tensor
a = tf.random.normal([3,3])
print(a)
mask = a > 0
print(mask)#返回的是True和False形成的一个tensor
print(tf.boolean_mask(a, mask))#取出所有的满足条件的值(True)
indices = tf.where(mask)
print(indices)#所有True对应的索引
print(tf.gather_nd(a, indices))

A = tf.ones([3,3])
B = tf.zeros([3,3])
print(tf.where(mask, A, B))

scatter_nd

import tensorflow as tf

# tf.scatter_nd()
indices = tf.constant([[4],[3],[1],[7]])#更新值对应的索引
updates = tf.constant([9,10,11,12])#更新的值
shape = tf.constant([8])#底板 —— 长度为8,值全为0
print(tf.scatter_nd(indices, updates, shape))
# tf.Tensor([ 0 11  0 10  9  0  0 12], shape=(8,), dtype=int32)

indices = tf.constant([[0],[2]])#更新的值对应
updatas = tf.constant([[[5,5,5,5],[6,6,6,6],
                        [7,7,7,7],[8,8,8,8]],
                       [[5,5,5,5],[6,6,6,6],
                        [7,7,7,7],[8,8,8,8]]])
print(updatas.shape)
shape = tf.constant([4,4,4])
print(tf.scatter_nd(indices, updatas, shape))

tf.meshgrid()

要求:得到下图右侧的点集

  • 用numpy实现
import numpy as np

#numpy实现 没有实现GPU加速
points = []
for y in np.linspace(-2,2,5):
    for x in np.linspace(-2,2,5):
        points.append([x,y])
  • 用tf.meshgrid 、tf.stack实现
#tensorflow实现  实现GPU加速
y = tf.linspace(-2.,2,5)
# print(y)#[-2. -1.  0.  1.  2.]
x = tf.linspace(-2.,2,5)
# print(x)#[-2. -1.  0.  1.  2.]
points_x,points_y = tf.meshgrid(x,y)
print(points_x.shape)# (5, 5)
print(points_x)
print(points_y)
points = tf.stack([points_x,points_y],axis=2)#张量的拼接——在原基础上增加维度
print(points)
print(points.shape)#(5, 5, 2)

tf.meshgrid()介绍

def meshgrid(*args, **kwargs):
给定N个一维坐标数组’ *args ‘,返回一个列表’ outputs '的N-D坐标数组,用于在N-D网格上计算表达式。
(来自有道翻译)

——自己的理解(小白学python,不一定准确)

points_x,points_y = tf.meshgrid(X,Y)
传递了两个一维数组坐标X、Y,
X为1行m列,Y为1行n列

  • 对于points_x :
    • points_x.shape = (n,m)
    • 每一行的值都与X相同
  • 对于points_y:
    • points_y.shape = (n,m)
    • 每i行的值,都等于Y中i列的值

总的来说:
他就是给定n个一维数组坐标,结合stack,那么,所有的给定的坐标所能形成的点都可以用tensor表示出来
然后如果是二维,则形成一个网格表,三维也差不多是这个意思

例:画出函数풛 = 풔풊풏 풙 + 풔풊풏(풚)的曲面和等高线

该例子的话是给定xy的范围(或值),根据函数풛 = 풔풊풏 풙 + 풔풊풏(풚),计算出z,再用图形表示出来这个函数f(x,y)=z
points 对应的是[x,y]形成的(500,500,2)的形状的tensor,因为给定的x、y个有500个值

通过x[…,0],x[…,1]取出x、y,并放入到函数풛 = 풔풊풏 풙 + 풔풊풏(풚),计算出z,最后作图

# -*- coding = utf-8 -*-
# @Time :2021/9/30 11:30
# @Author : Min
# @File : t9
# @Software : PyCharm
import tensorflow as tf
import matplotlib.pyplot as plt

def func(x):
    #实现函数
    z = tf.math.sin(x[...,0]) + tf.math.sin(x[...,1])
    return z;

x = tf.linspace(0.,2*3.14,500)
y = tf.linspace(0.,2*3.14,500)
points_x,points_y = tf.meshgrid(x,y)
points = tf.stack([points_x,points_y],axis=2)
z = func(points)

#画图
plt.figure('plot 2d func value')
plt.imshow(z,origin='lower',interpolation='none')
plt.colorbar()
plt.figure('plot 2d func contour')
plt.contour(points_x,points_y,z)
plt.colorbar()
plt.show()


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/283342.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号