@@ -462,26 +462,6 @@ def draw(self, renderer):
462462 self .offsetText .set_ha (align )
463463 self .offsetText .draw (renderer )
464464
465- if self .axes ._draw_grid and len (ticks ):
466- # Grid points where the planes meet
467- xyz0 = np .tile (minmax , (len (ticks ), 1 ))
468- xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
469-
470- # Grid lines go from the end of one plane through the plane
471- # intersection (at xyz0) to the end of the other plane. The first
472- # point (0) differs along dimension index-2 and the last (2) along
473- # dimension index-1.
474- lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
475- lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
476- lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
477- self .gridlines .set_segments (lines )
478- gridinfo = info ['grid' ]
479- self .gridlines .set_color (gridinfo ['color' ])
480- self .gridlines .set_linewidth (gridinfo ['linewidth' ])
481- self .gridlines .set_linestyle (gridinfo ['linestyle' ])
482- self .gridlines .do_3d_projection ()
483- self .gridlines .draw (renderer )
484-
485465 # Draw ticks:
486466 tickdir = self ._get_tickdir ()
487467 tickdelta = deltas [tickdir ] if highs [tickdir ] else - deltas [tickdir ]
@@ -519,6 +499,46 @@ def draw(self, renderer):
519499 renderer .close_group ('axis3d' )
520500 self .stale = False
521501
502+ @artist .allow_rasterization
503+ def draw_grid (self , renderer ):
504+ if not self .axes ._draw_grid :
505+ return
506+
507+ self .label ._transform = self .axes .transData
508+ renderer .open_group ("grid3d" , gid = self .get_gid ())
509+
510+ ticks = self ._update_ticks ()
511+ if len (ticks ):
512+ # Get general axis information:
513+ info = self ._axinfo
514+ index = info ["i" ]
515+
516+ mins , maxs , tc , highs = self ._get_coord_info ()
517+
518+ minmax = np .where (highs , maxs , mins )
519+ maxmin = np .where (~ highs , maxs , mins )
520+
521+ # Grid points where the planes meet
522+ xyz0 = np .tile (minmax , (len (ticks ), 1 ))
523+ xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
524+
525+ # Grid lines go from the end of one plane through the plane
526+ # intersection (at xyz0) to the end of the other plane. The first
527+ # point (0) differs along dimension index-2 and the last (2) along
528+ # dimension index-1.
529+ lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
530+ lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
531+ lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
532+ self .gridlines .set_segments (lines )
533+ gridinfo = info ['grid' ]
534+ self .gridlines .set_color (gridinfo ['color' ])
535+ self .gridlines .set_linewidth (gridinfo ['linewidth' ])
536+ self .gridlines .set_linestyle (gridinfo ['linestyle' ])
537+ self .gridlines .do_3d_projection ()
538+ self .gridlines .draw (renderer )
539+
540+ renderer .close_group ('grid3d' )
541+
522542 # TODO: Get this to work (more) properly when mplot3d supports the
523543 # transforms framework.
524544 def get_tightbbox (self , renderer = None , * , for_layout_only = False ):
0 commit comments