77
88class Bitmap (Renderer ):
99 """Bitmap Renderer
10- Renders 3D tensors (CHW)
10+ Renders 3D tensors (CHW) and 2D tensors of ints (HW)
1111
1212 Usage:
1313 Drag: pan
@@ -20,20 +20,26 @@ class Bitmap(Renderer):
2020 Values can be 'viridis', 'plasma', 'inferno', 'magma', 'cividis'
2121 colors: List of colors for each channel for multi channels bitmaps (looped if necessary)
2222 rotate (int): Number of time to rotate the bitmap by 90 degree (counter-clockwise)
23+ invert (bool): Invert vertical axis.
2324 """
24- def __init__ (self , colormap = 'inferno' , colors = ['#ff0000' ,'#00ff00' ,'#0000ff' ,'#ffff00' ,'#00ffff' ,'#ff00ff' ], rotate = 0 ):
25+ def __init__ (self , colormap = 'inferno' , colors = ['#ff0000' ,'#00ff00' ,'#0000ff' ,'#ffff00' ,'#00ffff' ,'#ff00ff' ], rotate = 0 , invert = False ):
2526 super ().__init__ ()
2627 self .colormap = colormap
2728 self .colors = colors
2829 self .rotate = rotate
30+ self .invert = invert
2931
3032 def render (self , title , tensor , size , dpi , shift = (0 ,0 ,0 ,0 ), scale = (1 ,1 ,1 ,1 ), input_tensors = [], target_tensor = None , labels = []):
3133 #check dimensions
32- if len (tensor .shape )!= 3 :
33- print ("Bitmap renderer requires a 3D tensor, got a " + str (len (tensor .shape ))+ "D tensor." , file = sys .stderr )
34+ print (str (tensor .dtype ))
35+ if len (tensor .shape )!= 3 and (len (tensor .shape )!= 2 or 'int' not in str (tensor .dtype )):
36+ print ("Bitmap renderer requires a 3D tensor or 2D tensor of ints, got a " + str (len (tensor .shape ))+ "D tensor." , file = sys .stderr )
3437 return None
3538
3639 #flatten
40+ if len (tensor .shape )== 2 and 'int' in str (tensor .dtype ):
41+ tensor = np .expand_dims (tensor , axis = 0 )
42+
3743 if tensor .shape [0 ]> 1 :
3844 zero = np .zeros ((3 ,tensor .shape [1 ], tensor .shape [2 ]))
3945 for i in range (tensor .shape [0 ]):
@@ -81,15 +87,21 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
8187 render_size = (xmax - xmin ,ymin - ymax )
8288 xmin -= shift [0 ]/ scale [1 ]* render_size [0 ]
8389 xmax -= shift [0 ]/ scale [1 ]* render_size [0 ]
84- ymin += shift [1 ]/ scale [1 ]* render_size [1 ]
85- ymax += shift [1 ]/ scale [1 ]* render_size [1 ]
90+ if self .invert :
91+ ymin -= shift [1 ]/ scale [1 ]* render_size [1 ]
92+ ymax -= shift [1 ]/ scale [1 ]* render_size [1 ]
93+ else :
94+ ymin += shift [1 ]/ scale [1 ]* render_size [1 ]
95+ ymax += shift [1 ]/ scale [1 ]* render_size [1 ]
8696
8797 #scale
8898 render_center = (xmin + render_size [0 ]/ 2.0 ,ymax + render_size [1 ]/ 2.0 )
8999 xmin = render_center [0 ]- (render_size [0 ]/ scale [1 ]/ 2.0 )
90100 xmax = render_center [0 ]+ (render_size [0 ]/ scale [1 ]/ 2.0 )
91101 ymax = render_center [1 ]- (render_size [1 ]/ scale [1 ]/ 2.0 )
92102 ymin = render_center [1 ]+ (render_size [1 ]/ scale [1 ]/ 2.0 )
103+ if self .invert :
104+ ymin , ymax = ymax , ymin
93105
94106 #render
95107 plt .axis (xmin = xmin ,xmax = xmax ,ymin = ymin ,ymax = ymax )
@@ -102,13 +114,3 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
102114 img = PIL .Image .frombytes ('RGB' ,canvas .get_width_height (),canvas .tostring_rgb ())
103115 plt .close ()
104116 return img
105-
106- #from PIL import ImageDraw
107- #img = PIL.Image.new( mode = "RGB", size = (512, 512), color = (209, 123, 193) )
108- #draw = PIL.ImageDraw.Draw(img)
109- #font = PIL.ImageFont.truetype("arial.ttf", 72)
110- #draw.text((40, 100),"Sample Text\nSecond Line\nThird Line\nEnd",fill=(255,255,255), font=font)
111- #tensor = (np.array(img).astype(np.float32)/255).transpose((2,0,1))[[0,1]]
112- #print(tensor[0][0][0],tensor.dtype)
113- #img = chw_3d(tensor, (400,300), 192)
114- #img.save('output.png')
0 commit comments