3232
3333from  mplfinance  import  _styles 
3434
35- from  mplfinance ._arg_validators  import  _check_and_prepare_data , _mav_validator 
35+ from  mplfinance ._arg_validators  import  _check_and_prepare_data , _mav_validator ,  _label_validator 
3636from  mplfinance ._arg_validators  import  _get_valid_plot_types , _fill_between_validator 
3737from  mplfinance ._arg_validators  import  _process_kwargs , _validate_vkwargs_dict 
3838from  mplfinance ._arg_validators  import  _kwarg_not_implemented , _bypass_kwarg_validation 
@@ -765,6 +765,8 @@ def plot( data, **kwargs ):
765765
766766        elif  not  _list_of_dict (addplot ):
767767            raise  TypeError ('addplot must be `dict`, or `list of dict`, NOT ' + str (type (addplot )))
768+         
769+         contains_legend_label = [] # a list of axes that contains legend labels 
768770
769771        for  apdict  in  addplot :
770772
@@ -788,10 +790,28 @@ def plot( data, **kwargs ):
788790                else :
789791                    havedf  =  False       # must be a single series or array 
790792                    apdata  =  [apdata ,]  # make it iterable 
793+                 if  havedf  and  apdict ['label' ]:
794+                     if  not  isinstance (apdict ['label' ],(list ,tuple ,np .ndarray )):
795+                        nlabels  =  1 
796+                     else :
797+                        nlabels  =  len (apdict ['label' ])
798+                     ncolumns  =  len (apdata .columns )
799+                     #print('nlabels=',nlabels,'ncolumns=',ncolumns) 
800+                     if  nlabels  <  ncolumns :
801+                         warnings .warn ('\n  =======================================\n ' + 
802+                                       ' addplot MISMATCH between data and labels:\n ' + 
803+                                       ' have ' + str (ncolumns )+ ' columns to plot \n ' + 
804+                                       ' BUT  ' + str (nlabels )+ ' labels for them.\n ' )
805+                 colcount  =  0 
791806                for  column  in  apdata :
792807                    ydata  =  apdata .loc [:,column ] if  havedf  else  column 
793-                     ax  =  _addplot_columns (panid ,panels ,ydata ,apdict ,xdates ,config )
808+                     ax  =  _addplot_columns (panid ,panels ,ydata ,apdict ,xdates ,config , colcount )
794809                    _addplot_apply_supplements (ax ,apdict ,xdates )
810+                     colcount  +=  1 
811+                     if  apdict ['label' ]: # not supported for aptype == 'ohlc' or 'candle' 
812+                         contains_legend_label .append (ax )
813+         for  ax  in  set (contains_legend_label ): # there might be duplicates 
814+             ax .legend ()
795815
796816    # fill_between is NOT supported for external_axes_mode 
797817    # (caller can easily call ax.fill_between() themselves). 
@@ -1079,7 +1099,7 @@ def _addplot_collections(panid,panels,apdict,xdates,config):
10791099    ax .autoscale_view ()
10801100    return  ax 
10811101
1082- def  _addplot_columns (panid ,panels ,ydata ,apdict ,xdates ,config ):
1102+ def  _addplot_columns (panid ,panels ,ydata ,apdict ,xdates ,config , colcount ):
10831103    external_axes_mode  =  apdict ['ax' ] is  not None 
10841104    if  not  external_axes_mode :
10851105        secondary_y  =  False 
@@ -1101,6 +1121,10 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
11011121        ax  =  apdict ['ax' ]
11021122
11031123    aptype  =  apdict ['type' ]
1124+     if  isinstance (apdict ['label' ],(list ,tuple ,np .ndarray )):
1125+         label  =  apdict ['label' ][colcount ]
1126+     else : # isinstance(...,str) 
1127+         label  =  apdict ['label' ]
11041128    if  aptype  ==  'scatter' :
11051129        size   =  apdict ['markersize' ]
11061130        mark   =  apdict ['marker' ]
@@ -1111,27 +1135,27 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
11111135
11121136        if  isinstance (mark ,(list ,tuple ,np .ndarray )):
11131137            _mscatter (xdates , ydata , ax = ax , m = mark , s = size , color = color , alpha = alpha , edgecolors = edgecolors , linewidths = linewidths )
1114-         else :
1115-             ax .scatter (xdates , ydata , s = size , marker = mark , color = color , alpha = alpha , edgecolors = edgecolors , linewidths = linewidths ) 
1138+         else :  
1139+             ax .scatter (xdates , ydata , s = size , marker = mark , color = color , alpha = alpha , edgecolors = edgecolors , linewidths = linewidths , label = label )  
11161140    elif  aptype  ==  'bar' :
11171141        width   =  0.8  if  apdict ['width' ] is  None  else  apdict ['width' ]
11181142        bottom  =  apdict ['bottom' ]
11191143        color   =  apdict ['color' ]
11201144        alpha   =  apdict ['alpha' ]
1121-         ax .bar (xdates ,ydata ,width = width ,bottom = bottom ,color = color ,alpha = alpha )
1145+         ax .bar (xdates ,ydata ,width = width ,bottom = bottom ,color = color ,alpha = alpha , label = label )
11221146    elif  aptype  ==  'line' :
11231147        ls      =  apdict ['linestyle' ]
11241148        color   =  apdict ['color' ]
11251149        width   =  apdict ['width' ] if  apdict ['width' ] is  not None  else  1.6 * config ['_width_config' ]['line_width' ]
11261150        alpha   =  apdict ['alpha' ]
1127-         ax .plot (xdates ,ydata ,linestyle = ls ,color = color ,linewidth = width ,alpha = alpha )
1151+         ax .plot (xdates ,ydata ,linestyle = ls ,color = color ,linewidth = width ,alpha = alpha , label = label )
11281152    elif  aptype  ==  'step' :
11291153        stepwhere  =  apdict ['stepwhere' ]
11301154        ls  =  apdict ['linestyle' ]
11311155        color   =  apdict ['color' ]
11321156        width   =  apdict ['width' ] if  apdict ['width' ] is  not None  else  1.6 * config ['_width_config' ]['line_width' ]
11331157        alpha   =  apdict ['alpha' ]
1134-         ax .step (xdates ,ydata ,where  =  stepwhere ,linestyle = ls ,color = color ,linewidth = width ,alpha = alpha )
1158+         ax .step (xdates ,ydata ,where  =  stepwhere ,linestyle = ls ,color = color ,linewidth = width ,alpha = alpha , label = label )
11351159    else :
11361160        raise  ValueError ('addplot type "' + str (aptype )+ '" NOT yet supported.' )
11371161
@@ -1384,6 +1408,9 @@ def _valid_addplot_kwargs():
13841408        'fill_between' : { 'Default'      : None ,    # added by Wen 
13851409                          'Description'  : " fill region" ,
13861410                          'Validator'    : _fill_between_validator  },
1411+         'label'       : {  'Default'      : None ,
1412+                           'Description'  : 'Label for the added plot. One per added plot.' ,
1413+                           'Validator'    : _label_validator  },
13871414
13881415    }
13891416
0 commit comments