1- #otherwise conda install may fail
2- del __file__
3- __package__ = None
4- __spec__ = None
5-
61import sys
72import importlib
83import importlib .util
94import argparse
5+ import subprocess
106parser = argparse .ArgumentParser ()
11- parser .add_argument ("--channel" , help = "pytorch channel" , type = str , default = 'pytorch' )
127parser .add_argument ("--cuda" , help = "install nvidia gpu support" , action = "store_true" , default = False )
138parser .add_argument ("--package" , help = "install specific package" , action = 'append' , nargs = '+' , default = [])
149args , unknown = parser .parse_known_args ()
1510
16- if importlib .util .find_spec ("conda " ) is None :
17- print ("Error: A Conda environment is required to install the required packages." , file = sys .stderr )
11+ if importlib .util .find_spec ("pip " ) is None :
12+ print ("Error: Pip is required to install the required packages." , file = sys .stderr )
1813 exit ()
1914
20- import conda .cli .python_api as Conda
21-
22- #increase rows (from default 20 when no terminal is found) to display all parallel packages downloads at once
23- from tqdm import tqdm
24- init_source = tqdm .__init__
25- def init_patch (self , ** kwargs ):
26- kwargs ['ncols' ]= 80
27- kwargs ['nrows' ]= 80
28- init_source (self , ** kwargs )
29- tqdm .__init__ = init_patch
30-
3115if not args .package :
32- #https://edcarp.github.io/introduction-to-conda-for-data-scientists/03-using-packages-and-channels/index.html#alternative-syntax-for-installing-packages-from-specific-channels
33- conda_install = f"pytorch torchvision torchaudio torchtext"
16+ #first install python-graphviz as it only exist as a conda package, and conda is recommended before pip: https://www.anaconda.com/blog/using-pip-in-a-conda-environment
17+ if importlib .util .find_spec ("conda" ) is None :
18+ print ("Error: Conda is required to install the graphviz package." , file = sys .stderr )
19+ exit ()
20+ else :
21+ print ("Downloading and installing the graphviz package..." )
22+ print ("" )
23+ result = subprocess .run ([sys .executable , "-m" , "conda" , "install" , "-y" , "python-graphviz" , "-c" , "conda-forge" ])
24+ if result .returncode != 0 :
25+ exit (result .returncode )
26+ print ("" )
27+
28+ pip_install = "torch torchvision torchaudio torchtext"
3429 if (sys .platform .startswith ('win' ) or sys .platform .startswith ('linux' )):
3530 if args .cuda :
3631 print ("Checking the latest supported CUDA version..." )
37- highest_cuda_version = ( 11 , 6 ) # highest supported cuda version for PyTorch 1.12
32+ highest_cuda_version = 118 #11.8 highest supported cuda version for PyTorch 2.0
3833 import requests
3934 try :
40- pytorch_repo = requests .get ("https://anaconda. org/" + args . channel + "/pytorch/files " )
35+ pytorch_repo = requests .get ("https://download.pytorch. org/whl/torch " )
4136 except :
4237 print ("Could not retrieve the latest supported CUDA version" )
4338 else :
4439 import re
45- regex_request = re .compile ("cuda([0-9]+. [0-9]+)" )
40+ regex_request = re .compile ("cu( [0-9]+)" )
4641 results = re .findall (regex_request , pytorch_repo .text )
47- highest_cuda_version = ( 11 , 6 )
42+ highest_cuda_version = 118
4843 for cuda_string in results :
49- cuda_version = tuple ( int (i ) for i in cuda_string . split ( '.' ) )
44+ cuda_version = int (cuda_string )
5045 if cuda_version > highest_cuda_version :
5146 highest_cuda_version = cuda_version
52- highest_cuda_string = '.' . join ([ str (value ) for value in highest_cuda_version ])
47+ highest_cuda_string = str ( highest_cuda_version )[: 2 ] + "." + str (highest_cuda_version )[ 2 :]
5348 print ("Using CUDA " + highest_cuda_string )
5449 print ("" )
55- conda_install += " pytorch-cuda=" + highest_cuda_string + " -c " + args .channel + " -c nvidia"
56- else :
57- conda_install += " cpuonly -c " + args .channel
58- else :
59- conda_install += " -c " + args .channel
60- print (f"Downloading and installing { args .channel } packages..." )
50+ pip_install += " --index-url https://download.pytorch.org/whl/cu" + str (highest_cuda_version )
51+
52+ print ("Downloading and installing pytorch packages..." )
6153 print ("" )
62- # https://stackoverflow.com/questions/41767340/using-conda-install-within-a-python-script
63- ( stdout_str , stderr_str , return_code_int ) = Conda . run_command ( Conda . Commands . INSTALL , conda_install .split (), use_exception_handler = True , stdout = sys . stdout , stderr = sys . stderr )
64- if return_code_int != 0 :
65- exit (return_code_int )
54+
55+ result = subprocess . run ([ sys . executable , "-m" , "pip" , "install" ] + pip_install .split ())
56+ if result . returncode != 0 :
57+ exit (result . returncode )
6658 print ("" )
6759
6860 # onnx required for onnx export
@@ -73,16 +65,13 @@ def init_patch(self, **kwargs):
7365 # python-graphviz required by torchstudio graph
7466 # paramiko required for ssh connections (+updated cffi required on intel mac)
7567 # pysoundfile required by torchaudio datasets: https://pytorch.org/audio/stable/backend.html#soundfile-backend
76- conda_install = "onnx datasets scipy pandas matplotlib-base python-graphviz paramiko pysoundfile"
77- if sys .platform .startswith ('darwin' ):
78- conda_install += " cffi"
68+ pip_install = "onnx datasets scipy pandas matplotlib paramiko pysoundfile"
7969
8070else :
81- conda_install = " " .join (args .package [0 ])
71+ pip_install = " " .join (args .package [0 ])
8272
83- print ("Downloading and installing conda-forge packages..." )
73+ print ("Downloading and installing additional packages..." )
8474print ("" )
85- conda_install += " -c conda-forge"
86- (stdout_str , stderr_str , return_code_int ) = Conda .run_command (Conda .Commands .INSTALL ,conda_install .split (),use_exception_handler = True ,stdout = sys .stdout ,stderr = sys .stderr )
87- if return_code_int != 0 :
88- exit (return_code_int )
75+ result = subprocess .run ([sys .executable , "-m" , "pip" , "install" ]+ pip_install .split ())
76+ if result .returncode != 0 :
77+ exit (result .returncode )
0 commit comments