@@ -24,7 +24,7 @@
 PUBLIC_BRANCH = "publication"
 URL_DOWNLOAD = f"https://github.com/PyTorchLightning/{REPO_NAME}/raw/{DEFAULT_BRANCH}"
 ENV_DEVICE = "ACCELERATOR"
-DEVICE_ACCELERATOR = os.environ.get(ENV_DEVICE, 'cpu').lower()
+DEVICE_ACCELERATOR = os.environ.get(ENV_DEVICE, "cpu").lower()
 TEMPLATE_HEADER = f"""# %%%% [markdown]
 #
 # # %(title)s
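Note on the template above: inside the f-string, "%" is not special, so "%%%%" survives verbatim; the later "%"-formatting (TEMPLATE_HEADER % meta) then halves it to "%%", which is exactly the jupytext markdown-cell marker. A minimal sketch, with TEMPLATE standing in for TEMPLATE_HEADER:

# TEMPLATE is a stand-in for TEMPLATE_HEADER above
TEMPLATE = "# %%%% [markdown]\n#\n# # %(title)s"
print(TEMPLATE % {"title": "My notebook"})
# prints:
# # %% [markdown]
# #
# # # My notebook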
@@ -92,7 +92,7 @@
 def default_requirements(path_req: str = PATH_REQ_DEFAULT) -> list:
     with open(path_req) as fp:
         req = fp.readlines()
-    req = [r[:r.index("#")] if "#" in r else r for r in req]
+    req = [r[: r.index("#")] if "#" in r else r for r in req]
     req = [r.strip() for r in req]
     req = [r for r in req if r]
     return req
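For reference, a worked example of the comment-stripping in default_requirements, with made-up file content:

lines = ["torch>=1.8  # pinned\n", "\n", "tqdm\n"]  # hypothetical requirements.txt lines
req = [r[: r.index("#")] if "#" in r else r for r in lines]
req = [r.strip() for r in req]
req = [r for r in req if r]
print(req)  # ['torch>=1.8', 'tqdm']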
@@ -101,6 +101,7 @@ def default_requirements(path_req: str = PATH_REQ_DEFAULT) -> list:
 def get_running_cuda_version() -> str:
     try:
         import torch
+
         return torch.version.cuda or ""
     except ImportError:
         return ""
@@ -109,8 +110,9 @@ def get_running_cuda_version() -> str:
 def get_running_torch_version():
     try:
         import torch
+
         ver = torch.__version__
-        return ver[:ver.index('+')] if '+' in ver else ver
+        return ver[: ver.index("+")] if "+" in ver else ver
     except ImportError:
         return ""
 
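The slice above strips the local build suffix that torch.__version__ may carry; a small sketch with an assumed version string:

ver = "1.9.0+cu111"  # assumed example of torch.__version__
print(ver[: ver.index("+")] if "+" in ver else ver)  # 1.9.0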
@@ -119,8 +121,8 @@ def get_running_torch_version():
 CUDA_VERSION = get_running_cuda_version()
 RUNTIME_VERSIONS = dict(
     TORCH_VERSION_FULL=TORCH_VERSION,
-    TORCH_VERSION=TORCH_VERSION[:TORCH_VERSION.index('+')] if '+' in TORCH_VERSION else TORCH_VERSION,
-    TORCH_MAJOR_DOT_MINOR='.'.join(TORCH_VERSION.split('.')[:2]),
+    TORCH_VERSION=TORCH_VERSION[: TORCH_VERSION.index("+")] if "+" in TORCH_VERSION else TORCH_VERSION,
+    TORCH_MAJOR_DOT_MINOR=".".join(TORCH_VERSION.split(".")[:2]),
     CUDA_VERSION=CUDA_VERSION,
     CUDA_MAJOR_MINOR=CUDA_VERSION.replace(".", ""),
     DEVICE=f"cu{CUDA_VERSION.replace('.', '')}" if CUDA_VERSION else "cpu",
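Sketch of how the derived CUDA fields come out, with an assumed version string:

CUDA_VERSION = "11.1"  # assumed value
print(CUDA_VERSION.replace(".", ""))  # 111  (CUDA_MAJOR_MINOR)
print(f"cu{CUDA_VERSION.replace('.', '')}" if CUDA_VERSION else "cpu")  # cu111  (DEVICE)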
@@ -130,7 +132,7 @@ def get_running_torch_version():
 class HelperCLI:
 
     DIR_NOTEBOOKS = ".notebooks"
-    META_REQUIRED_FIELDS = ('title', 'author', 'license', 'description')
+    META_REQUIRED_FIELDS = ("title", "author", "license", "description")
     SKIP_DIRS = (
         ".actions",
         ".azure-pipelines",
@@ -144,7 +146,7 @@ class HelperCLI:
     META_FILE_REGEX = ".meta.{yaml,yml}"
     REQUIREMENTS_FILE = "requirements.txt"
     PIP_ARGS_FILE = "pip_arguments.txt"
-    META_PIP_KEY = 'pip__'
+    META_PIP_KEY = "pip__"
 
     @staticmethod
     def _meta_file(folder: str) -> str:
@@ -171,7 +173,7 @@ def augment_script(fpath: str):
             generated=datetime.now().isoformat(),
         )
 
-        meta['description'] = meta['description'].replace(os.linesep, f"{os.linesep}# ")
+        meta["description"] = meta["description"].replace(os.linesep, f"{os.linesep}# ")
 
         header = TEMPLATE_HEADER % meta
         requires = set(default_requirements() + meta["requirements"])
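The replace() above turns a multi-line description into continuation lines of the generated markdown comment block; a minimal sketch with a made-up description:

import os

description = f"First line{os.linesep}Second line"  # made-up value
print(description.replace(os.linesep, f"{os.linesep}# "))
# First line
# # Second line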
@@ -203,22 +205,22 @@ def _replace_images(lines: list, local_dir: str) -> list:
                 url_path = p_img
                 im = requests.get(p_img, stream=True).raw.read()
             else:
-                url_path = '/'.join([URL_DOWNLOAD, local_dir, p_img])
+                url_path = "/".join([URL_DOWNLOAD, local_dir, p_img])
                 p_local_img = os.path.join(local_dir, p_img)
                 with open(p_local_img, "rb") as fp:
                     im = fp.read()
             im_base64 = base64.b64encode(im).decode("utf-8")
             _, ext = os.path.splitext(p_img)
             md = md.replace(f'src="{p_img}"', f'src="{url_path}"')
-            md = md.replace(f']({p_img})', f'](data:image/{ext[1:]};base64,{im_base64})')
+            md = md.replace(f"]({p_img})", f"](data:image/{ext[1:]};base64,{im_base64})")
 
         return [ln + os.linesep for ln in md.split(os.linesep)]
 
     @staticmethod
     def _is_ipynb_parent_dir(dir_path: str) -> bool:
         if HelperCLI._meta_file(dir_path):
             return True
-        sub_dirs = [d for d in glob.glob(os.path.join(dir_path, '*')) if os.path.isdir(d)]
+        sub_dirs = [d for d in glob.glob(os.path.join(dir_path, "*")) if os.path.isdir(d)]
         return any(HelperCLI._is_ipynb_parent_dir(d) for d in sub_dirs)
 
     @staticmethod
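The second replace() above inlines local images as base64 data URIs; a minimal sketch with placeholder bytes and extension:

import base64

im = b"\x89PNG\r\n..."  # placeholder image bytes
im_base64 = base64.b64encode(im).decode("utf-8")
print(f"](data:image/png;base64,{im_base64})")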
@@ -296,15 +298,14 @@ def parse_requirements(dir_path: str):
         meta = yaml.safe_load(open(fpath))
         pprint(meta)
 
-        req = meta.get('requirements', [])
+        req = meta.get("requirements", [])
         fname = os.path.join(dir_path, HelperCLI.REQUIREMENTS_FILE)
         print(f"File for requirements: {fname}")
         with open(fname, "w") as fp:
             fp.write(os.linesep.join(req))
 
         pip_args = {
-            k.replace(HelperCLI.META_PIP_KEY, ''): v
-            for k, v in meta.items() if k.startswith(HelperCLI.META_PIP_KEY)
+            k.replace(HelperCLI.META_PIP_KEY, ""): v for k, v in meta.items() if k.startswith(HelperCLI.META_PIP_KEY)
         }
         cmd_args = []
         for pip_key in pip_args:
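The reshaped dict comprehension keeps only the "pip__*" keys from the meta file and strips the prefix; a sketch with hypothetical meta keys:

META_PIP_KEY = "pip__"
meta = {"title": "demo", "pip__extra-index-url": "https://example.org/simple"}  # hypothetical keys
pip_args = {k.replace(META_PIP_KEY, ""): v for k, v in meta.items() if k.startswith(META_PIP_KEY)}
print(pip_args)  # {'extra-index-url': 'https://example.org/simple'}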
@@ -327,33 +328,31 @@ def _get_card_item_cell(path_ipynb: str) -> Dict[str, Any]:
 
         # Clamp description length
         wrapped_description = wrap(
-            meta.get('short_description', meta['description']).strip().replace(os.linesep, " "), 175
+            meta.get("short_description", meta["description"]).strip().replace(os.linesep, " "), 175
         )
         suffix = "..." if len(wrapped_description) > 1 else ""
-        meta['short_description'] = wrapped_description[0] + suffix
+        meta["short_description"] = wrapped_description[0] + suffix
 
         # Resolve some default tags based on accelerators and directory name
-        meta['tags'] = meta.get('tags', [])
+        meta["tags"] = meta.get("tags", [])
 
-        accelerators = meta.get("accelerator", ('CPU',))
-        if ('GPU' in accelerators) or ('TPU' in accelerators):
-            meta['tags'].append('GPU/TPU')
+        accelerators = meta.get("accelerator", ("CPU",))
+        if ("GPU" in accelerators) or ("TPU" in accelerators):
+            meta["tags"].append("GPU/TPU")
 
         dirname = os.path.basename(os.path.dirname(path_ipynb))
         if dirname != ".notebooks":
-            meta['tags'].append(dirname)
+            meta["tags"].append(dirname)
 
-        meta['tags'] = ",".join(meta['tags'])
+        meta["tags"] = ",".join(meta["tags"])
 
         # Build the notebook cell
         rst_cell = TEMPLATE_CARD_ITEM % meta
 
         return {
             "cell_type": "raw",
-            "metadata": {
-                "raw_mimetype": "text/restructuredtext"
-            },
-            "source": rst_cell.strip().splitlines(True)
+            "metadata": {"raw_mimetype": "text/restructuredtext"},
+            "source": rst_cell.strip().splitlines(True),
         }
 
     @staticmethod
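How the clamping behaves: wrap() chunks the text at 175 characters, and "..." is appended only when more than one chunk comes back. A sketch with a made-up description:

from textwrap import wrap

desc = "word " * 50  # made-up long description
wrapped = wrap(desc.strip(), 175)
print(wrapped[0] + ("..." if len(wrapped) > 1 else ""))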
@@ -365,27 +364,27 @@ def copy_notebooks(path_root: str, path_docs_ipynb: str = "docs/source/notebooks
             path_docs_ipynb: destination path to the notebooks location
         """
         ls_ipynb = []
-        for sub in (['*.ipynb'], ['**', '*.ipynb']):
+        for sub in (["*.ipynb"], ["**", "*.ipynb"]):
             ls_ipynb += glob.glob(os.path.join(path_root, HelperCLI.DIR_NOTEBOOKS, *sub))
 
         os.makedirs(path_docs_ipynb, exist_ok=True)
         ipynb_content = []
         for path_ipynb in tqdm.tqdm(ls_ipynb):
             ipynb = path_ipynb.split(os.path.sep)
-            sub_ipynb = os.path.sep.join(ipynb[ipynb.index(HelperCLI.DIR_NOTEBOOKS) + 1:])
+            sub_ipynb = os.path.sep.join(ipynb[ipynb.index(HelperCLI.DIR_NOTEBOOKS) + 1 :])
             new_ipynb = os.path.join(path_docs_ipynb, sub_ipynb)
             os.makedirs(os.path.dirname(new_ipynb), exist_ok=True)
-            print(f'{path_ipynb} -> {new_ipynb}')
+            print(f"{path_ipynb} -> {new_ipynb}")
 
             with open(path_ipynb) as f:
                 ipynb = json.load(f)
 
             ipynb["cells"].append(HelperCLI._get_card_item_cell(path_ipynb))
 
-            with open(new_ipynb, 'w') as f:
+            with open(new_ipynb, "w") as f:
                 json.dump(ipynb, f)
 
-            ipynb_content.append(os.path.join('notebooks', sub_ipynb))
+            ipynb_content.append(os.path.join("notebooks", sub_ipynb))
 
     @staticmethod
     def valid_accelerator(dir_path: str):
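How sub_ipynb is recovered relative to the .notebooks folder; the path below is made up:

import os

path_ipynb = os.path.join(".notebooks", "course", "intro.ipynb")  # made-up path
parts = path_ipynb.split(os.path.sep)
print(os.path.sep.join(parts[parts.index(".notebooks") + 1 :]))  # course/intro.ipynb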
@@ -397,7 +396,7 @@ def valid_accelerator(dir_path: str):
         assert fpath, f"Missing Meta file in {dir_path}"
         meta = yaml.safe_load(open(fpath))
         # default is CPU runtime
-        accels = [acc.lower() for acc in meta.get("accelerator", ('CPU'))]
+        accels = [acc.lower() for acc in meta.get("accelerator", ("CPU"))]
         dev_accels = DEVICE_ACCELERATOR.split(",")
         return int(any(ac in accels for ac in dev_accels))
 
@@ -413,7 +412,7 @@ def update_env_details(dir_path: str):
         # default is CPU runtime
         with open(PATH_REQ_DEFAULT) as fp:
             req = fp.readlines()
-        req += meta.get('requirements', [])
+        req += meta.get("requirements", [])
         req = [r.strip() for r in req]
 
         def _parse(pkg: str, keys: str = " <=>") -> str:
@@ -425,12 +424,12 @@ def _parse(pkg: str, keys: str = " <=>") -> str:
 
         require = {_parse(r) for r in req if r}
         env = {_parse(p): p for p in freeze.freeze()}
-        meta['environment'] = [env[r] for r in require]
-        meta['published'] = datetime.now().isoformat()
+        meta["environment"] = [env[r] for r in require]
+        meta["published"] = datetime.now().isoformat()
 
         fmeta = os.path.join(HelperCLI.DIR_NOTEBOOKS, dir_path) + ".yaml"
-        yaml.safe_dump(meta, stream=open(fmeta, 'w'), sort_keys=False)
+        yaml.safe_dump(meta, stream=open(fmeta, "w"), sort_keys=False)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     fire.Fire(HelperCLI)
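The body of _parse lies outside this diff; a plausible sketch matching its signature (cut a requirement at the first of " <=>"), for illustration only:

def _parse(pkg: str, keys: str = " <=>") -> str:
    # assumed behavior: strip version specifiers from a requirement line
    ix = [pkg.index(c) for c in keys if c in pkg]
    return pkg[: min(ix)] if ix else pkg

print(_parse("torch>=1.8"))  # torch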