Skip to content

Commit

Permalink
Finish unet architecture text visualizer
Browse files Browse the repository at this point in the history
  • Loading branch information
nthistle committed Jul 5, 2018
1 parent 60a67e6 commit 131de8c
Showing 1 changed file with 59 additions and 3 deletions.
62 changes: 59 additions & 3 deletions random_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def back_calc( output_shape, decrease_per=4, factor=2, stages=3 ):
def back_calc_nice( output_shape, decrease_per=2, num_times=2, factor=2, stages=3 ):
last = lambda l : l[-1][-1]

cur_stage = [output_shape]
cur_stage = [True, output_shape]
shapes = [cur_stage]

for _ in range(stages):
Expand All @@ -34,14 +34,70 @@ def back_calc_nice( output_shape, decrease_per=2, num_times=2, factor=2, stages=
cur_stage = [False, last(shapes)*factor]
shapes.append(cur_stage)
shapes.pop()
shapes.append([False])
shapes = [[not shapes[i+1][0],*x[1:][::-1] ] for i, x in enumerate(shapes[:-1])][::-1]
shapes[0].pop(0)

return shapes

## Pretty formatting!
## Get the architecture from back_calc_nice, preferably (similar structure).
def format_unet_nice( unet_arch ):
mat = [[None]]
extend_up = lambda : mat.insert([None]*len(mat[0]))
extend_up = lambda : mat.insert(0, [None]*len(mat[0])) ## borks when used
extend_right = lambda : [m.append(None) for m in mat]
extend_down = lambda : mat.append([None]*len(mat[0]))
## WIP

posi, posj, midx = 0, 0, 0
for stage in unet_arch:
if type(stage[0])==bool:
if stage[0]:
posi += 1
else:
posi -= 1
stage = stage[1:]
midx += 1
for val in stage:
if posi < 0:
extend_up()
if posi >= len(mat):
extend_down()
if posj >= len(mat[0]):
extend_right()
mat[posi][posj] = (val, midx)
midx += 1
posj += 1
posj -= 1
omat = mat
mat = [[str(x[0]) if x else "" for x in y] for y in mat]
maxwidth = max([max([len(x) for x in y]) for y in mat])
output_str = ["" for _ in mat]
for i,row in enumerate(mat):
for j,val in enumerate(row):
if j > 0:
if row[j-1] and val:
output_str[i] += " -> "
else:
output_str[i] += " "
output_str[i] += val + " "*(maxwidth-len(val))
final_output = ""
for i,row in enumerate(mat):
final_output += output_str[i] + "\n"
if i == len(mat)-1:
continue
nline1, nline2 = "", ""
for j in range(len(row)):
if mat[i][j] and mat[i+1][j]:
if omat[i+1][j][1] > omat[i][j][1]:
nline1 += "|" + " "*(maxwidth-1)
nline2 += "v" + " "*(maxwidth-1)
else:
nline1 += "^" + " "*(maxwidth-1)
nline2 += "|" + " "*(maxwidth-1)
else:
nline1 += " "*maxwidth
nline2 += " "*maxwidth
nline1 += " "
nline2 += " "
final_output += nline1 + "\n" + nline2 + "\n"
return final_output #"\n".join(output_str)

0 comments on commit 131de8c

Please sign in to comment.