diff --git a/setup.py b/setup.py index 8030c5408e..06b938bb31 100644 --- a/setup.py +++ b/setup.py @@ -86,12 +86,8 @@ def _main(): print("-- Git branch:", branch) print("-- Git SHA:", sha) print("-- Git tag:", tag) - # This used to be passed to install_requires - # which would cause pinning against a specific torch version in releases. - # I don't think we want to pin at all? - # TODO: revisit if needed. Maybe it's needed for nightlies. Unsure. - # pytorch_package_dep = _get_pytorch_version() - # print("-- PyTorch dependency:", pytorch_package_dep) + pytorch_package_dep = _get_pytorch_version() + print("-- PyTorch dependency:", pytorch_package_dep) version = _get_version(sha) print("-- Building version", version) @@ -139,7 +135,7 @@ def _main(): "build_ext": setup_helpers.get_build_ext(), "clean": clean, }, - install_requires=[], + install_requires=[pytorch_package_dep], zip_safe=False, )