环境设置:
import matplotlib.pyplot as plt
import numpy as np
# Print numpy arrays with 4 decimals and without scientific notation.
np.set_printoptions(suppress=True, precision=4)

# Named matplotlib colours used as plotting defaults further down.
red, orange, yellow = 'orangered', 'orange', 'yellow'
blue, purple, pink = 'deepskyblue', 'mediumpurple', 'violet'
齐次变换矩阵:
def trans(dx, dy, dz):
    '''Homogeneous transform: pure translation by (dx, dy, dz).

    :return: 4x4 float array, identity rotation, offset in the last column
    '''
    result = np.identity(4)
    # The top three entries of the last column hold the translation.
    result[:3, 3] = (dx, dy, dz)
    return result
def rot(theta, axis):
    '''Homogeneous transform: rotation by `theta` degrees about a principal axis.

    The 3x3 blocks are the standard right-handed Rx / Ry / Rz matrices,
    so a positive angle is a counter-clockwise rotation about the axis.

    :param theta: rotation angle in degrees
    :param axis: axis name 'x' / 'y' / 'z', or the equivalent index 0 / 1 / 2
    :return: 4x4 float array with the rotation in its upper-left 3x3 block
    :raises KeyError: `axis` is a string other than 'x', 'y' or 'z'
    :raises AssertionError: `axis` (after name lookup) is not 0, 1 or 2
    '''
    theta = theta / 180 * np.pi  # degrees -> radians
    sin = np.sin(theta)
    cos = np.cos(theta)
    mat = np.eye(4)
    axis_idx = {'x': 0, 'y': 1, 'z': 2}
    if isinstance(axis, str):
        axis = axis_idx[axis]  # axis name -> index
    if axis == 0:
        mat[1:3, 1:3] = np.array([[cos, -sin],
                                  [sin, cos]])
    elif axis == 1:
        mat[:3, :3] = np.array([[cos, 0, sin],
                                [0, 1, 0],
                                [-sin, 0, cos]])
    elif axis == 2:
        mat[:2, :2] = np.array([[cos, -sin],
                                [sin, cos]])
    else:
        # Report the offending value as well as the accepted axis names.
        raise AssertionError(f'axis: {axis!r}, expected one of {axis_idx}')
    return mat
命令执行:
def run(commands, start=None):
    ''' Execute motion commands and record every intermediate pose.

    Each command is split into `pace` equal sub-steps. Every sub-step
    transform is LEFT-multiplied onto the current pose, so the motion is
    expressed in the fixed (world) frame rather than the moving frame.

    Command formats:
        rot:   ('rot', angle, axis, pace)     angle in degrees
        trans: ('trans', dx, dy, dz, pace)

    :param commands: iterable of command tuples as above
    :param start: initial pose; defaults to the 4x4 identity
    :return: stacked poses, the start pose first, then one per sub-step
    :raises AssertionError: non-positive pace or unknown command type
    '''
    if start is None:
        start = np.eye(4)  # fresh array per call instead of a shared default
    state = [start]
    for i, com in enumerate(commands):
        type_, *args, pace = com
        # A zero pace used to divide by zero (rot) or be silently skipped (trans).
        if pace <= 0:
            raise AssertionError(f'Command {i + 1}: pace must be positive')
        if type_ == 'rot':
            angle_sum, axis = args
            mat = rot(angle_sum / pace, axis)  # one rotation sub-step
        elif type_ == 'trans':
            mat = trans(*(np.array(args) / pace))  # one translation sub-step
        else:
            raise AssertionError(f'Command {i + 1}: Illegal type')
        for _ in range(pace):
            state.append(mat @ state[-1])  # state transition
    return np.stack(state)
命令分步:
def get_commands(rot_pace, trans_pace):
    ''' Build the demo command list from the given step counts.

    The list holds two 90-degree rotations (about z, then about y), each
    split into `rot_pace` steps. They are followed by a translation of
    (4, -3, 7) split into `trans_pace` steps.
    '''
    rotations = [('rot', 90, axis, rot_pace) for axis in ('z', 'y')]
    translation = ('trans', 4, -3, 7, trans_pace)
    return rotations + [translation]
3D 坐标系绘制:
def plot_coord_sys(state, colors=(orange, yellow, blue),
                   labels='noa', linewidth=None, scale=.5):
    ''' Draw one coordinate frame on the current matplotlib axes.

    The frame origin is taken from the last column of `state`. Each of
    the first three columns is an axis direction, drawn as a segment
    from the origin to origin + scale * column.

    :param state: homogeneous transform of the frame (assumed 4x4)
    :param colors: one colour per axis; a tuple, so the default is immutable
    :param labels: one legend label per axis. The default 'noa' presumably
        names the robotics n / o / a vectors.
    :param linewidth: passed through to plt.plot
    :param scale: length multiplier for the drawn axes
    Only min(len(colors), len(labels)) axes are drawn, because zip truncates.
    '''
    origin = state[:3, -1]
    for idx, (color, label) in enumerate(zip(colors, labels)):
        tip = origin + state[:3, idx] * scale
        # zip(origin, tip) regroups the two endpoints into (xs, ys, zs).
        plt.plot(*zip(origin, tip), c=color, label=label, linewidth=linewidth)
def plot_global_sys(colors=(red, pink, purple), linewidth=None, scale=1.):
    ''' Draw the global (world) frame: the identity pose, axes labelled x / y / z.

    :param colors: one colour per axis; a tuple, so the default is immutable
    :param linewidth: passed through to the line plots
    :param scale: length of the drawn axes
    '''
    plot_coord_sys(np.eye(4), colors=colors, labels='xyz',
                   linewidth=linewidth, scale=scale)
def figure3d():
    ''' Create a 3D subplot on the current figure with labelled x / y / z axes.

    :return: the axes object returned by plt.subplot(projection='3d')
    '''
    ax = plt.subplot(projection='3d')
    ax.set(xlabel='x', ylabel='y', zlabel='z')
    return ax
播放动画/逐步变换:
def show(cartoon=False):
    ''' Visualise the demo command sequence.

    :param cartoon: True animates every intermediate pose in one
        interactive window. False opens one blocking window per pose,
        one pose per whole command.
    '''
    if not cartoon:
        # Step-by-step mode: one sub-step per command, one window per pose.
        for pose in run(get_commands(1, 1)):
            figure3d()
            plot_global_sys(linewidth=5)
            plot_coord_sys(pose)
            plt.legend()
            plt.show()
        return
    # Animation mode: 10 sub-steps per command, all drawn into one figure.
    figure3d()
    plt.ion()
    plot_global_sys(linewidth=5)
    plt.pause(3)
    poses = run(get_commands(10, 10))
    for frame_no, pose in enumerate(poses):
        plot_coord_sys(pose)
        if frame_no == 0:
            plt.legend()  # build the legend only once, after the first frame
        plt.pause(0.25)
    # NOTE(review): presumably meant to keep the last frame on screen. A
    # 0 timeout lets the GUI event loop run without timing out; confirm
    # this holds on the backend in use.
    plt.pause(0)
# Top-level call: running (or importing) this script starts the animation.
show(cartoon=True)
最终效果: