@@ -20,13 +20,15 @@ class Spectrogram(Renderer):
2020 Values can be 'viridis', 'plasma', 'inferno', 'magma', 'cividis'
2121 colors: List of colors for each channel for multi channels spectrograms (looped if necessary)
2222 rotate (int): Number of time to rotate the bitmap by 90 degree (counter-clockwise)
23+ normalize (bool): Normalize values
2324 """
24- def __init__ (self , colormap = 'inferno' , colors = ['#ff0000' ,'#00ff00' ,'#0000ff' ,'#ffff00' ,'#00ffff' ,'#ff00ff' ], rotate = 0 , invert = False ):
25+ def __init__ (self , colormap = 'inferno' , colors = ['#ff0000' ,'#00ff00' ,'#0000ff' ,'#ffff00' ,'#00ffff' ,'#ff00ff' ], rotate = 0 , invert = False , normalize = False ):
2526 super ().__init__ ()
2627 self .colormap = colormap
2728 self .colors = colors
2829 self .rotate = rotate
2930 self .invert = invert
31+ self .normalize = normalize
3032
3133 def render (self , title , tensor , size , dpi , shift = (0 ,0 ,0 ,0 ), scale = (1 ,1 ,1 ,1 ), input_tensors = [], target_tensor = None , labels = []):
3234 #check dimensions
@@ -35,8 +37,8 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
3537 return None
3638
3739 if np .iscomplexobj (tensor )== False and tensor .shape [0 ]% 2 != 0 :
38- print ( "Spectrogram renderer requires a complex tensor or a tensor with an even number of channels" , file = sys . stderr )
39- return None
40+ #add missing channel (needs pairs to be interpred as complex channels)
41+ tensor = np . append ( tensor , np . zeros (( 1 , tensor . shape [ 1 ], tensor . shape [ 2 ])), axis = 0 )
4042
4143 #convert complex spectrogram to amplitude spectrogram
4244 if np .iscomplexobj (tensor ):
@@ -55,6 +57,12 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
5557 if self .rotate > 0 :
5658 tensor = np .rot90 (tensor , self .rotate , axes = (1 , 2 ))
5759
60+ tensor = tensor .astype (np .float32 )
61+ if self .normalize :
62+ max_value = np .amax (tensor )
63+ if max_value > 0 :
64+ tensor = tensor / max_value
65+
5866 #apply brightness, gamma and conversion to uint8, then transform CHW to HWC
5967 tensor = np .multiply (np .clip (np .power (tensor * scale [0 ],1 / scale [3 ]),0 ,1 ),255 ).astype (np .uint8 )
6068 tensor = tensor .transpose ((1 , 2 , 0 ))
0 commit comments