diff --git a/lib/mpl_toolkits/mplot3d/art3d.py b/lib/mpl_toolkits/mplot3d/art3d.py index e051e44fb23d..e579c6e48595 100644 --- a/lib/mpl_toolkits/mplot3d/art3d.py +++ b/lib/mpl_toolkits/mplot3d/art3d.py @@ -58,11 +58,11 @@ def get_dir_vector(zdir): x, y, z : array The direction vector. """ - if zdir == 'x': + if cbook._str_equal(zdir, 'x'): return np.array((1, 0, 0)) - elif zdir == 'y': + elif cbook._str_equal(zdir, 'y'): return np.array((0, 1, 0)) - elif zdir == 'z': + elif cbook._str_equal(zdir, 'z'): return np.array((0, 0, 1)) elif zdir is None: return np.array((0, 0, 0)) diff --git a/lib/mpl_toolkits/mplot3d/tests/test_art3d.py b/lib/mpl_toolkits/mplot3d/tests/test_art3d.py index 8ff6050443ab..aca943f9e0c0 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_art3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_art3d.py @@ -1,15 +1,32 @@ import numpy as np +import numpy.testing as nptest +import pytest import matplotlib.pyplot as plt from matplotlib.backend_bases import MouseEvent from mpl_toolkits.mplot3d.art3d import ( + get_dir_vector, Line3DCollection, Poly3DCollection, _all_points_on_plane, ) +@pytest.mark.parametrize("zdir, expected", [ + ("x", (1, 0, 0)), + ("y", (0, 1, 0)), + ("z", (0, 0, 1)), + (None, (0, 0, 0)), + ((1, 2, 3), (1, 2, 3)), + (np.array([4, 5, 6]), (4, 5, 6)), +]) +def test_get_dir_vector(zdir, expected): + res = get_dir_vector(zdir) + assert isinstance(res, np.ndarray) + nptest.assert_array_equal(res, expected) + + def test_scatter_3d_projection_conservation(): fig = plt.figure() ax = fig.add_subplot(projection='3d')