Skip to content

Commit 859385a

Browse files
handle png masks and proper spectrogram representation
1 parent 3eb0e67 commit 859385a

File tree

4 files changed

+39
-23
lines changed

4 files changed

+39
-23
lines changed

torchstudio/datasets/genericloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def __init__(self, path:str='', classification:bool=True, separator:str='/', ext
128128
def to_tensors(self, path:str):
129129
if path.endswith('.jpg') or path.endswith('.jpeg') or path.endswith('.png'):
130130
img=Image.open(path)
131-
if img.getpalette():
131+
if img.mode=='1' or img.mode=='L' or img.mode=='P':
132132
return [torch.from_numpy(np.array(img, dtype=np.uint8))]
133133
else:
134134
trans=torchvision.transforms.ToTensor()

torchstudio/renderers/bitmap.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class Bitmap(Renderer):
99
"""Bitmap Renderer
10-
Renders 3D tensors (CHW)
10+
Renders 3D tensors (CHW) and 2D tensors of ints (HW)
1111
1212
Usage:
1313
Drag: pan
@@ -20,20 +20,26 @@ class Bitmap(Renderer):
2020
Values can be 'viridis', 'plasma', 'inferno', 'magma', 'cividis'
2121
colors: List of colors for each channel for multi channels bitmaps (looped if necessary)
2222
rotate (int): Number of times to rotate the bitmap by 90 degrees (counter-clockwise)
23+
invert (bool): Invert vertical axis.
2324
"""
24-
def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0):
25+
def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0, invert=False):
2526
super().__init__()
2627
self.colormap=colormap
2728
self.colors=colors
2829
self.rotate=rotate
30+
self.invert=invert
2931

3032
def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]):
3133
#check dimensions
32-
if len(tensor.shape)!=3:
33-
print("Bitmap renderer requires a 3D tensor, got a "+str(len(tensor.shape))+"D tensor.", file=sys.stderr)
34+
print(str(tensor.dtype))
35+
if len(tensor.shape)!=3 and (len(tensor.shape)!=2 or 'int' not in str(tensor.dtype)):
36+
print("Bitmap renderer requires a 3D tensor or 2D tensor of ints, got a "+str(len(tensor.shape))+"D tensor.", file=sys.stderr)
3437
return None
3538

3639
#flatten
40+
if len(tensor.shape)==2 and 'int' in str(tensor.dtype):
41+
tensor=np.expand_dims(tensor, axis=0)
42+
3743
if tensor.shape[0]>1:
3844
zero = np.zeros((3,tensor.shape[1], tensor.shape[2]))
3945
for i in range(tensor.shape[0]):
@@ -81,15 +87,21 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
8187
render_size=(xmax-xmin,ymin-ymax)
8288
xmin-=shift[0]/scale[1]*render_size[0]
8389
xmax-=shift[0]/scale[1]*render_size[0]
84-
ymin+=shift[1]/scale[1]*render_size[1]
85-
ymax+=shift[1]/scale[1]*render_size[1]
90+
if self.invert:
91+
ymin-=shift[1]/scale[1]*render_size[1]
92+
ymax-=shift[1]/scale[1]*render_size[1]
93+
else:
94+
ymin+=shift[1]/scale[1]*render_size[1]
95+
ymax+=shift[1]/scale[1]*render_size[1]
8696

8797
#scale
8898
render_center=(xmin+render_size[0]/2.0,ymax+render_size[1]/2.0)
8999
xmin=render_center[0]-(render_size[0]/scale[1]/2.0)
90100
xmax=render_center[0]+(render_size[0]/scale[1]/2.0)
91101
ymax=render_center[1]-(render_size[1]/scale[1]/2.0)
92102
ymin=render_center[1]+(render_size[1]/scale[1]/2.0)
103+
if self.invert:
104+
ymin, ymax = ymax, ymin
93105

94106
#render
95107
plt.axis(xmin=xmin,xmax=xmax,ymin=ymin,ymax=ymax)
@@ -102,13 +114,3 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
102114
img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
103115
plt.close()
104116
return img
105-
106-
#from PIL import ImageDraw
107-
#img = PIL.Image.new( mode = "RGB", size = (512, 512), color = (209, 123, 193) )
108-
#draw = PIL.ImageDraw.Draw(img)
109-
#font = PIL.ImageFont.truetype("arial.ttf", 72)
110-
#draw.text((40, 100),"Sample Text\nSecond Line\nThird Line\nEnd",fill=(255,255,255), font=font)
111-
#tensor = (np.array(img).astype(np.float32)/255).transpose((2,0,1))[[0,1]]
112-
#print(tensor[0][0][0],tensor.dtype)
113-
#img = chw_3d(tensor, (400,300), 192)
114-
#img.save('output.png')

torchstudio/renderers/spectrogram.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ class Spectrogram(Renderer):
2121
colors: List of colors for each channel for multi channels spectrograms (looped if necessary)
2222
rotate (int): Number of times to rotate the bitmap by 90 degrees (counter-clockwise)
2323
"""
24-
def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0):
24+
def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0, invert=False):
2525
super().__init__()
2626
self.colormap=colormap
2727
self.colors=colors
2828
self.rotate=rotate
29+
self.invert=invert
2930

3031
def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]):
3132
#check dimensions
@@ -91,15 +92,21 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
9192
render_size=(xmax-xmin,ymin-ymax)
9293
xmin-=shift[0]/scale[1]*render_size[0]
9394
xmax-=shift[0]/scale[1]*render_size[0]
94-
ymin+=shift[1]/scale[1]*render_size[1]
95-
ymax+=shift[1]/scale[1]*render_size[1]
95+
if self.invert:
96+
ymin-=shift[1]/scale[1]*render_size[1]
97+
ymax-=shift[1]/scale[1]*render_size[1]
98+
else:
99+
ymin+=shift[1]/scale[1]*render_size[1]
100+
ymax+=shift[1]/scale[1]*render_size[1]
96101

97102
#scale
98103
render_center=(xmin+render_size[0]/2.0,ymax+render_size[1]/2.0)
99104
xmin=render_center[0]-(render_size[0]/scale[1]/2.0)
100105
xmax=render_center[0]+(render_size[0]/scale[1]/2.0)
101106
ymax=render_center[1]-(render_size[1]/scale[1]/2.0)
102107
ymin=render_center[1]+(render_size[1]/scale[1]/2.0)
108+
if self.invert:
109+
ymin, ymax = ymax, ymin
103110

104111
#render
105112
plt.axis(xmin=xmin,xmax=xmax,ymin=ymin,ymax=ymax)

torchstudio/renderers/volume.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ class Volume(Renderer):
2222
colors: List of colors for each channel for multi channels volumes (looped if necessary)
2323
rotate (int): Number of times to rotate the bitmap by 90 degrees (counter-clockwise)
2424
"""
25-
def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0):
25+
def __init__(self, colormap='inferno', colors=['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff'], rotate=0, invert=False):
2626
super().__init__()
2727
self.colormap=colormap
2828
self.colors=colors
2929
self.rotate=rotate
30+
self.invert=invert
3031

3132
def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), input_tensors=[], target_tensor=None, labels=[]):
3233
#check dimensions
@@ -88,15 +89,21 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
8889
render_size=(xmax-xmin,ymin-ymax)
8990
xmin-=shift[0]/scale[1]*render_size[0]
9091
xmax-=shift[0]/scale[1]*render_size[0]
91-
ymin+=shift[1]/scale[1]*render_size[1]
92-
ymax+=shift[1]/scale[1]*render_size[1]
92+
if self.invert:
93+
ymin-=shift[1]/scale[1]*render_size[1]
94+
ymax-=shift[1]/scale[1]*render_size[1]
95+
else:
96+
ymin+=shift[1]/scale[1]*render_size[1]
97+
ymax+=shift[1]/scale[1]*render_size[1]
9398

9499
#scale
95100
render_center=(xmin+render_size[0]/2.0,ymax+render_size[1]/2.0)
96101
xmin=render_center[0]-(render_size[0]/scale[1]/2.0)
97102
xmax=render_center[0]+(render_size[0]/scale[1]/2.0)
98103
ymax=render_center[1]-(render_size[1]/scale[1]/2.0)
99104
ymin=render_center[1]+(render_size[1]/scale[1]/2.0)
105+
if self.invert:
106+
ymin, ymax = ymax, ymin
100107

101108
#render
102109
plt.axis(xmin=xmin,xmax=xmax,ymin=ymin,ymax=ymax)

0 commit comments

Comments
 (0)