@@ -1002,43 +1002,23 @@ def __init__(self, ax, labels, actives=None):
10021002 if actives is None :
10031003 actives = [False ] * len (labels )
10041004
1005- if len (labels ) > 1 :
1006- dy = 1. / (len (labels ) + 1 )
1007- ys = np .linspace (1 - dy , dy , len (labels ))
1008- else :
1009- dy = 0.25
1010- ys = [0.5 ]
1011-
1012- axcolor = ax .get_facecolor ()
1013-
1014- self .labels = []
1015- self .lines = []
1016- self .rectangles = []
1017-
1018- lineparams = {'color' : 'k' , 'linewidth' : 1.25 ,
1019- 'transform' : ax .transAxes , 'solid_capstyle' : 'butt' }
1020- for y , label , active in zip (ys , labels , actives ):
1021- t = ax .text (0.25 , y , label , transform = ax .transAxes ,
1022- horizontalalignment = 'left' ,
1023- verticalalignment = 'center' )
1024-
1025- w , h = dy / 2 , dy / 2
1026- x , y = 0.05 , y - h / 2
1027-
1028- p = Rectangle (xy = (x , y ), width = w , height = h , edgecolor = 'black' ,
1029- facecolor = axcolor , transform = ax .transAxes )
1005+ ys = np .linspace (1 , 0 , len (labels )+ 2 )[1 :- 1 ]
1006+ text_size = mpl .rcParams ["font.size" ] / 2
10301007
1031- l1 = Line2D ([x , x + w ], [y + h , y ], ** lineparams )
1032- l2 = Line2D ([x , x + w ], [y , y + h ], ** lineparams )
1008+ self .labels = [
1009+ ax .text (0.25 , y , label , transform = ax .transAxes ,
1010+ horizontalalignment = "left" , verticalalignment = "center" )
1011+ for y , label in zip (ys , labels )]
10331012
1034- l1 .set_visible (active )
1035- l2 .set_visible (active )
1036- self .labels .append (t )
1037- self .rectangles .append (p )
1038- self .lines .append ((l1 , l2 ))
1039- ax .add_patch (p )
1040- ax .add_line (l1 )
1041- ax .add_line (l2 )
1013+ self ._squares = ax .scatter (
1014+ [0.15 ] * len (ys ), ys , marker = 's' , s = text_size ** 2 ,
1015+ c = "none" , linewidth = 1 , transform = ax .transAxes , edgecolor = "k"
1016+ )
1017+ self ._crosses = ax .scatter (
1018+ [0.15 ] * len (ys ), ys , marker = 'x' , linewidth = 1 , s = text_size ** 2 ,
1019+ c = ["k" if active else "none" for active in actives ],
1020+ transform = ax .transAxes
1021+ )
10421022
10431023 self .connect_event ('button_press_event' , self ._clicked )
10441024
@@ -1047,11 +1027,27 @@ def __init__(self, ax, labels, actives=None):
10471027 def _clicked (self , event ):
10481028 if self .ignore (event ) or event .button != 1 or event .inaxes != self .ax :
10491029 return
1050- for i , (p , t ) in enumerate (zip (self .rectangles , self .labels )):
1051- if (t .get_window_extent ().contains (event .x , event .y ) or
1052- p .get_window_extent ().contains (event .x , event .y )):
1053- self .set_active (i )
1054- break
1030+ pclicked = self .ax .transAxes .inverted ().transform ((event .x , event .y ))
1031+ distances = {}
1032+ if hasattr (self , "_rectangles" ):
1033+ for i , (p , t ) in enumerate (zip (self ._rectangles , self .labels )):
1034+ x0 , y0 = p .get_xy ()
1035+ if (t .get_window_extent ().contains (event .x , event .y )
1036+ or (x0 <= pclicked [0 ] <= x0 + p .get_width ()
1037+ and y0 <= pclicked [1 ] <= y0 + p .get_height ())):
1038+ distances [i ] = np .linalg .norm (pclicked - p .get_center ())
1039+ else :
1040+ _ , square_inds = self ._squares .contains (event )
1041+ coords = self ._squares .get_offset_transform ().transform (
1042+ self ._squares .get_offsets ()
1043+ )
1044+ for i , t in enumerate (self .labels ):
1045+ if (i in square_inds ["ind" ]
1046+ or t .get_window_extent ().contains (event .x , event .y )):
1047+ distances [i ] = np .linalg .norm (pclicked - coords [i ])
1048+ if len (distances ) > 0 :
1049+ closest = min (distances , key = distances .get )
1050+ self .set_active (closest )
10551051
10561052 def set_active (self , index ):
10571053 """
@@ -1072,9 +1068,20 @@ def set_active(self, index):
10721068 if index not in range (len (self .labels )):
10731069 raise ValueError (f'Invalid CheckButton index: { index } ' )
10741070
1075- l1 , l2 = self .lines [index ]
1076- l1 .set_visible (not l1 .get_visible ())
1077- l2 .set_visible (not l2 .get_visible ())
1071+ cross_facecolors = self ._crosses .get_facecolor ()
1072+ cross_facecolors [index ] = colors .to_rgba (
1073+ "black"
1074+ if colors .same_color (
1075+ cross_facecolors [index ], colors .to_rgba ("none" )
1076+ )
1077+ else "none"
1078+ )
1079+ self ._crosses .set_facecolor (cross_facecolors )
1080+
1081+ if hasattr (self , "_lines" ):
1082+ l1 , l2 = self ._lines [index ]
1083+ l1 .set_visible (not l1 .get_visible ())
1084+ l2 .set_visible (not l2 .get_visible ())
10781085
10791086 if self .drawon :
10801087 self .ax .figure .canvas .draw ()
@@ -1086,7 +1093,8 @@ def get_status(self):
10861093 """
10871094 Return a list of the status (True/False) of all of the check buttons.
10881095 """
1089- return [l1 .get_visible () for (l1 , l2 ) in self .lines ]
1096+ return [not colors .same_color (color , colors .to_rgba ("none" ))
1097+ for color in self ._crosses .get_facecolors ()]
10901098
10911099 def on_clicked (self , func ):
10921100 """
@@ -1100,6 +1108,57 @@ def disconnect(self, cid):
11001108 """Remove the observer with connection id *cid*."""
11011109 self ._observers .disconnect (cid )
11021110
1111+ @_api .deprecated ("3.7" )
1112+ @property
1113+ def rectangles (self ):
1114+ if not hasattr (self , "_rectangles" ):
1115+ ys = np .linspace (1 , 0 , len (self .labels )+ 2 )[1 :- 1 ]
1116+ dy = 1. / (len (self .labels ) + 1 )
1117+ w , h = dy / 2 , dy / 2
1118+ rectangles = self ._rectangles = [
1119+ Rectangle (xy = (0.05 , ys [i ] - h / 2 ), width = w , height = h ,
1120+ edgecolor = "black" ,
1121+ facecolor = "none" ,
1122+ transform = self .ax .transAxes
1123+ )
1124+ for i , y in enumerate (ys )
1125+ ]
1126+ self ._squares .set_visible (False )
1127+ for rectangle in rectangles :
1128+ self .ax .add_patch (rectangle )
1129+ if not hasattr (self , "_lines" ):
1130+ with _api .suppress_matplotlib_deprecation_warning ():
1131+ _ = self .lines
1132+ return self ._rectangles
1133+
1134+ @_api .deprecated ("3.7" )
1135+ @property
1136+ def lines (self ):
1137+ if not hasattr (self , "_lines" ):
1138+ ys = np .linspace (1 , 0 , len (self .labels )+ 2 )[1 :- 1 ]
1139+ self ._crosses .set_visible (False )
1140+ dy = 1. / (len (self .labels ) + 1 )
1141+ w , h = dy / 2 , dy / 2
1142+ self ._lines = []
1143+ current_status = self .get_status ()
1144+ lineparams = {'color' : 'k' , 'linewidth' : 1.25 ,
1145+ 'transform' : self .ax .transAxes ,
1146+ 'solid_capstyle' : 'butt' }
1147+ for i , y in enumerate (ys ):
1148+ x , y = 0.05 , y - h / 2
1149+ l1 = Line2D ([x , x + w ], [y + h , y ], ** lineparams )
1150+ l2 = Line2D ([x , x + w ], [y , y + h ], ** lineparams )
1151+
1152+ l1 .set_visible (current_status [i ])
1153+ l2 .set_visible (current_status [i ])
1154+ self ._lines .append ((l1 , l2 ))
1155+ self .ax .add_patch (l1 )
1156+ self .ax .add_patch (l2 )
1157+ if not hasattr (self , "_rectangles" ):
1158+ with _api .suppress_matplotlib_deprecation_warning ():
1159+ _ = self .rectangles
1160+ return self ._lines
1161+
11031162
11041163class TextBox (AxesWidget ):
11051164 """
@@ -1457,8 +1516,10 @@ def set_active(self, index):
14571516 if index not in range (len (self .labels )):
14581517 raise ValueError (f'Invalid RadioButton index: { index } ' )
14591518 self .value_selected = self .labels [index ].get_text ()
1460- self ._buttons .get_facecolor ()[:] = colors .to_rgba ("none" )
1461- self ._buttons .get_facecolor ()[index ] = colors .to_rgba (self .activecolor )
1519+ button_facecolors = self ._buttons .get_facecolor ()
1520+ button_facecolors [:] = colors .to_rgba ("none" )
1521+ button_facecolors [index ] = colors .to_rgba (self .activecolor )
1522+ self ._buttons .set_facecolor (button_facecolors )
14621523 if hasattr (self , "_circles" ): # Remove once circles is removed.
14631524 for i , p in enumerate (self ._circles ):
14641525 p .set_facecolor (self .activecolor if i == index else "none" )
0 commit comments