diff --git a/python/private/pypi/extension.bzl b/python/private/pypi/extension.bzl
index 405c22f60e..961206eac3 100644
--- a/python/private/pypi/extension.bzl
+++ b/python/private/pypi/extension.bzl
@@ -472,6 +472,7 @@ You cannot use both the additive_build_content and additive_build_content_file a
                 index_url = pip_attr.experimental_index_url,
                 extra_index_urls = pip_attr.experimental_extra_index_urls or [],
                 index_url_overrides = pip_attr.experimental_index_url_overrides or {},
+                index_strategy = pip_attr.index_strategy,
                 sources = distributions,
                 envsubst = pip_attr.envsubst,
                 # Auth related info
@@ -681,6 +682,11 @@ stable.

This is equivalent to `--index-url` `pip` option.

+:::{warning}
+`rules_python` will fall back to using `pip` to download wheels if the requirements
+files do not have hashes.
+:::
+
 :::{versionchanged} 0.37.0
If {attr}`download_only` is set, then `sdist` archives will be discarded and `pip.parse` will
operate in wheel-only mode.
@@ -688,13 +694,17 @@ operate in wheel-only mode.
 """,
         ),
         "experimental_index_url_overrides": attr.string_dict(
+            # TODO @aignas 2025-03-01: consider using string_list_dict so that
+            # we could have per-platform index_url_overrides for each package,
+            # similar to what `uv` supports.
+            # See https://docs.astral.sh/uv/configuration/indexes/#-index-url-and-extra-index-url
             doc = """\
The index URL overrides for each package to use for downloading wheels using
bazel downloader.

This value is going to be subject to `envsubst` substitutions if necessary.

The key is the package name (will be normalized before usage) and the value is the
-index URL.
+index URLs separated by `,`.

This design pattern has been chosen in order to be fully deterministic about which
packages come from which source. We want to avoid issues similar to what happened in
@@ -702,6 +712,11 @@ https://pytorch.org/blog/compromised-nightly-dependency/.

The indexes must support Simple API as described here:
https://packaging.python.org/en/latest/specifications/simple-repository-api/
+
+:::{versionchanged} VERSION_NEXT_PATCH
+The value can contain comma-separated index URLs per package, allowing e.g.
+`torch` to be fetched from multiple indexes.
+:::
 """,
         ),
         "hub_name": attr.string(
@@ -724,6 +739,21 @@ is not required.

Each hub is a separate resolution of pip dependencies. This means if different
programs need different versions of some library, separate hubs can be created,
and each program can use its respective hub's targets. Targets from different
hubs should not be used together.
+""",
+        ),
+        "index_strategy": attr.string(
+            default = "first-index",
+            values = ["first-index", "unsafe"],
+            doc = """\
+The strategy used when fetching package locations from indexes. This allows fetching
+`torch` from both the `torch`-maintained index and PyPI, so that users can have
+different torch versions on different platforms (e.g. GPU-accelerated on Linux and
+CPU-only on the rest of the platforms).
+
+See https://docs.astral.sh/uv/configuration/indexes/#searching-across-multiple-indexes.
+
+:::{versionadded} VERSION_NEXT_PATCH
+:::
 """,
         ),
         "parallel_download": attr.bool(
diff --git a/python/private/pypi/simpleapi_download.bzl b/python/private/pypi/simpleapi_download.bzl
index ef39fb8723..098bcdda4f 100644
--- a/python/private/pypi/simpleapi_download.bzl
+++ b/python/private/pypi/simpleapi_download.bzl
@@ -31,6 +31,7 @@ def simpleapi_download(
         parallel_download = True,
         read_simpleapi = None,
         get_auth = None,
+        _print = print,
         _fail = fail):
     """Download Simple API HTML.
@@ -43,6 +44,8 @@ def simpleapi_download(
             separate packages.
           * extra_index_urls: Extra index URLs that will be looked up after
             the main is looked up.
+          * index_strategy: the strategy used when searching across the given
+            indexes. Can be either "first-index" or "unsafe".
          * sources: list[str], the sources to download things for. Each value is
            the contents of requirements files.
          * envsubst: list[str], the envsubst vars for performing substitution in index url.
@@ -61,6 +64,7 @@ def simpleapi_download(
        read_simpleapi: a function for reading and parsing of the SimpleAPI contents.
            Used in tests.
        get_auth: A function to get auth information passed to read_simpleapi. Used in tests.
+        _print: a function to print. Used in tests.
        _fail: a function to print a failure. Used in tests.

    Returns:
@@ -71,6 +75,9 @@ def simpleapi_download(
        for p, i in (attr.index_url_overrides or {}).items()
    }

+    if attr.index_strategy not in ["unsafe", "first-index"]:
+        fail("got index_strategy '{}', expected one of 'first-index' or 'unsafe'".format(attr.index_strategy))
+
    download_kwargs = {}
    if bazel_features.external_deps.download_has_block_param:
        download_kwargs["block"] = not parallel_download
@@ -80,8 +87,12 @@ def simpleapi_download(
    contents = {}
    index_urls = [attr.index_url] + attr.extra_index_urls
    read_simpleapi = read_simpleapi or _read_simpleapi
+    sources = {
+        pkg: normalize_name(pkg)
+        for pkg in attr.sources
+    }

-    found_on_index = {}
+    found_on_indexes = {}
    warn_overrides = False
    for i, index_url in enumerate(index_urls):
        if i != 0:
@@ -89,45 +100,82 @@ def simpleapi_download(
            warn_overrides = True

        async_downloads = {}
-        sources = [pkg for pkg in attr.sources if pkg not in found_on_index]
-        for pkg in sources:
+        for pkg, pkg_normalized in sources.items():
+            if pkg not in found_on_indexes:
+                # We have not found the pkg yet, so search for it.
+                pass
+            elif "first-index" == attr.index_strategy and pkg in found_on_indexes:
+                # We have already found the pkg and the strategy is safe, so
+                # stop searching.
+                continue
+            elif pkg in found_on_indexes and pkg_normalized in index_url_overrides:
+                # This pkg has been overridden, so be strict and implicitly use
+                # the `first-index` strategy.
+                continue
+            elif "unsafe" in attr.index_strategy:
+                # The `unsafe` strategy allows searching the remaining indexes.
+                pass
+            else:
+                fail("BUG: unknown state when searching the indexes for packages")

            pkg_normalized = normalize_name(pkg)
-            result = read_simpleapi(
-                ctx = ctx,
-                url = "{}/{}/".format(
-                    index_url_overrides.get(pkg_normalized, index_url).rstrip("/"),
-                    pkg,
-                ),
-                attr = attr,
-                cache = cache,
-                get_auth = get_auth,
-                **download_kwargs
-            )
-            if hasattr(result, "wait"):
-                # We will process it in a separate loop:
-                async_downloads[pkg] = struct(
-                    pkg_normalized = pkg_normalized,
-                    wait = result.wait,
+            override_urls = index_url_overrides.get(pkg_normalized, index_url)
+            for url in override_urls.split(","):
+                result = read_simpleapi(
+                    ctx = ctx,
+                    url = "{}/{}/".format(
+                        url.rstrip("/"),
+                        pkg,
+                    ),
+                    attr = attr,
+                    cache = cache,
+                    get_auth = get_auth,
+                    **download_kwargs
                )
-            elif result.success:
-                contents[pkg_normalized] = result.output
-                found_on_index[pkg] = index_url
+                if hasattr(result, "wait"):
+                    # We will process it in a separate loop:
+                    async_downloads.setdefault(pkg, []).append(
+                        struct(
+                            pkg_normalized = pkg_normalized,
+                            wait = result.wait,
+                        ),
+                    )
+                elif result.success:
+                    current = contents.get(
+                        pkg_normalized,
+                        struct(sdists = {}, whls = {}),
+                    )
+                    contents[pkg_normalized] = struct(
+                        # Always prefer the current values, so that the first index wins
+                        sdists = result.output.sdists | current.sdists,
+                        whls = result.output.whls | current.whls,
+                    )
+                    found_on_indexes.setdefault(pkg, []).append(url)

        if not async_downloads:
            continue

        # If we use `block` == False, then we need to have a second loop that is
        # collecting all of the results as they were being downloaded in parallel.
-        for pkg, download in async_downloads.items():
-            result = download.wait()
-
-            if result.success:
-                contents[download.pkg_normalized] = result.output
-                found_on_index[pkg] = index_url
-
-    failed_sources = [pkg for pkg in attr.sources if pkg not in found_on_index]
+        for pkg, downloads in async_downloads.items():
+            for download in downloads:
+                result = download.wait()
+
+                if result.success:
+                    current = contents.get(
+                        download.pkg_normalized,
+                        struct(sdists = {}, whls = {}),
+                    )
+                    contents[download.pkg_normalized] = struct(
+                        # Always prefer the current values, so that the first index wins
+                        sdists = result.output.sdists | current.sdists,
+                        whls = result.output.whls | current.whls,
+                    )
+                    found_on_indexes.setdefault(pkg, []).append(index_url)
+
+    failed_sources = [pkg for pkg in attr.sources if pkg not in found_on_indexes]
    if failed_sources:
-        _fail("Failed to download metadata for {} for from urls: {}".format(
+        _fail("Failed to download metadata for {} from urls: {}".format(
            failed_sources,
            index_urls,
        ))
@@ -135,13 +183,12 @@ def simpleapi_download(

    if warn_overrides:
        index_url_overrides = {
-            pkg: found_on_index[pkg]
+            pkg: ",".join(found_on_indexes[pkg])
            for pkg in attr.sources
-            if found_on_index[pkg] != attr.index_url
+            if found_on_indexes[pkg] != attr.index_url
        }

-        # buildifier: disable=print
-        print("You can use the following `index_url_overrides` to avoid the 404 warnings:\n{}".format(
+        _print("You can use the following `index_url_overrides` to avoid the 404 warnings:\n{}".format(
            render.dict(index_url_overrides),
        ))
diff --git a/tests/pypi/simpleapi_download/simpleapi_download_tests.bzl b/tests/pypi/simpleapi_download/simpleapi_download_tests.bzl
index 964d3e25ea..274c94ddb1 100644
--- a/tests/pypi/simpleapi_download/simpleapi_download_tests.bzl
+++ b/tests/pypi/simpleapi_download/simpleapi_download_tests.bzl
@@ -21,6 +21,7 @@ _tests = []

def _test_simple(env):
    calls = []
+    warnings_suggestion = []

    def read_simpleapi(ctx, url, attr, cache, get_auth, block):
        _ = ctx  # buildifier: disable=unused-variable
@@ -31,12 +32,18 @@ def _test_simple(env):
        calls.append(url)
        if "foo" in url and "main" in url:
            return struct(
-                output = "",
+                output = struct(
+                    sdists = {"": ""},
+                    whls = {},
+                ),
                success = False,
            )
        else:
            return struct(
-                output = "data from {}".format(url),
+                output = struct(
+                    sdists = {"": "data from {}".format(url)},
+                    whls = {},
+                ),
                success = True,
            )

@@ -48,12 +55,14 @@ def _test_simple(env):
            index_url_overrides = {},
            index_url = "main",
            extra_index_urls = ["extra"],
+            index_strategy = "first-index",
            sources = ["foo", "bar", "baz"],
            envsubst = [],
        ),
        cache = {},
        parallel_download = True,
        read_simpleapi = read_simpleapi,
+        _print = warnings_suggestion.append,
    )

    env.expect.that_collection(calls).contains_exactly([
@@ -63,13 +72,195 @@ def _test_simple(env):
        "main/foo/",
    ])
    env.expect.that_dict(contents).contains_exactly({
-        "bar": "data from main/bar/",
-        "baz": "data from main/baz/",
-        "foo": "data from extra/foo/",
+        "bar": struct(
+            sdists = {"": "data from main/bar/"},
+            whls = {},
+        ),
+        "baz": struct(
+            sdists = {"": "data from main/baz/"},
+            whls = {},
+        ),
+        "foo": struct(
+            sdists = {"": "data from extra/foo/"},
+            whls = {},
+        ),
    })
+    env.expect.that_collection(warnings_suggestion).contains_exactly([
+        """\
+You can use the following `index_url_overrides` to avoid the 404 warnings:
+{
+    "foo": "extra",
+    "bar": "main",
+    "baz": "main",
+}""",
+    ])

_tests.append(_test_simple)

+def _test_overrides_and_precedence(env):
+    calls = []
+
+    def read_simpleapi(ctx, url, attr, cache, get_auth, block):
+        _ = ctx  # buildifier: disable=unused-variable
+        _ = attr
+        _ = cache
+        _ = get_auth
+        env.expect.that_bool(block).equals(False)
+        calls.append(url)
+        if "foo" in url and "main" in url:
+            return struct(
+                output = struct(
+                    sdists = {"": ""},
+                    whls = {},
+                ),
+                # This will ensure that we fail the test if we go into this
+                # branch unexpectedly.
+                success = False,
+            )
+        else:
+            return struct(
+                output = struct(
+                    sdists = {"": "data from {}".format(url)},
+                    whls = {
+                        url: "whl from {}".format(url),
+                    } if "foo" in url else {},
+                ),
+                success = True,
+            )
+
+    contents = simpleapi_download(
+        ctx = struct(
+            os = struct(environ = {}),
+        ),
+        attr = struct(
+            index_url_overrides = {
+                "foo": "extra1,extra2",
+            },
+            index_url = "main",
+            extra_index_urls = [],
+            # Packages with overrides are fetched from every URL in their
+            # override list; packages without index_url_overrides honour the
+            # strategy setting.
+            index_strategy = "first-index",
+            sources = ["foo", "bar", "baz"],
+            envsubst = [],
+        ),
+        cache = {},
+        parallel_download = True,
+        read_simpleapi = read_simpleapi,
+        _print = fail,
+    )
+
+    env.expect.that_collection(calls).contains_exactly([
+        "extra1/foo/",
+        "extra2/foo/",
+        "main/bar/",
+        "main/baz/",
+    ])
+    env.expect.that_dict(contents).contains_exactly({
+        "bar": struct(
+            sdists = {"": "data from main/bar/"},
+            whls = {},
+        ),
+        "baz": struct(
+            sdists = {"": "data from main/baz/"},
+            whls = {},
+        ),
+        "foo": struct(
+            # We prioritize the first index
+            sdists = {"": "data from extra1/foo/"},
+            whls = {
+                "extra1/foo/": "whl from extra1/foo/",
+                "extra2/foo/": "whl from extra2/foo/",
+            },
+        ),
+    })
+
+_tests.append(_test_overrides_and_precedence)
+
+def _test_unsafe_strategy(env):
+    calls = []
+    warnings_suggestion = []
+
+    def read_simpleapi(ctx, url, attr, cache, get_auth, block):
+        _ = ctx  # buildifier: disable=unused-variable
+        _ = attr
+        _ = cache
+        _ = get_auth
+        env.expect.that_bool(block).equals(False)
+        calls.append(url)
+        return struct(
+            output = struct(
+                sdists = {"": "data from {}".format(url)},
+                whls = {
+                    url: "whl from {}".format(url),
+                } if "foo" in url else {},
+            ),
+            success = True,
+        )
+
+    contents = simpleapi_download(
+        ctx = struct(
+            os = struct(environ = {}),
+        ),
+        attr = struct(
+            index_url_overrides = {
+                "foo": "extra1,extra2",
+            },
+            index_url = "main",
+            # The extra index is only searched for packages without overrides.
+            extra_index_urls = ["extra"],
+            # Packages with overrides are fetched from every URL in their
+            # override list; packages without index_url_overrides honour the
+            # strategy setting.
+            index_strategy = "unsafe",
+            sources = ["foo", "bar", "baz"],
+            envsubst = [],
+        ),
+        cache = {},
+        parallel_download = True,
+        read_simpleapi = read_simpleapi,
+        _print = warnings_suggestion.append,
+    )
+
+    env.expect.that_collection(calls).contains_exactly([
+        "extra1/foo/",
+        "extra2/foo/",
+        "main/bar/",
+        "main/baz/",
+        "extra/bar/",
+        "extra/baz/",
+    ])
+    env.expect.that_dict(contents).contains_exactly({
+        "bar": struct(
+            sdists = {"": "data from main/bar/"},
+            whls = {},
+        ),
+        "baz": struct(
+            sdists = {"": "data from main/baz/"},
+            whls = {},
+        ),
+        "foo": struct(
+            # We prioritize the first index
+            sdists = {"": "data from extra1/foo/"},
+            whls = {
+                "extra1/foo/": "whl from extra1/foo/",
+                "extra2/foo/": "whl from extra2/foo/",
+            },
+        ),
+    })
+    env.expect.that_collection(warnings_suggestion).contains_exactly([
+        """\
+You can use the following `index_url_overrides` to avoid the 404 warnings:
+{
+    "foo": "extra1,extra2",
+    "bar": "main,extra",
+    "baz": "main,extra",
+}""",
+    ])
+
+_tests.append(_test_unsafe_strategy)
+
def _test_fail(env):
    calls = []
    fails = []
@@ -83,12 +274,18 @@ def _test_fail(env):
        calls.append(url)
        if "foo" in url:
            return struct(
-                output = "",
+                output = struct(
+                    sdists = {"": ""},
+                    whls = {},
+                ),
                success = False,
            )
        else:
            return struct(
-                output = "data from {}".format(url),
+                output = struct(
+                    sdists = {"": "data from {}".format(url)},
+                    whls = {},
+                ),
                success = True,
            )

@@ -100,6 +297,7 @@ def _test_fail(env):
            index_url_overrides = {},
            index_url = "main",
            extra_index_urls = ["extra"],
+            index_strategy = "first-index",
            sources = ["foo", "bar", "baz"],
            envsubst = [],
        ),
@@ -107,10 +305,11 @@ def _test_fail(env):
        parallel_download = True,
        read_simpleapi = read_simpleapi,
        _fail = fails.append,
+        _print = fail,
    )

    env.expect.that_collection(fails).contains_exactly([
-        """Failed to download metadata for ["foo"] for from urls: ["main", "extra"]""",
["main", "extra"]""", + """Failed to download metadata for ["foo"] from urls: ["main", "extra"]""", ]) env.expect.that_collection(calls).contains_exactly([ "extra/foo/", @@ -140,12 +339,14 @@ def _test_download_url(env): index_url_overrides = {}, index_url = "https://example.com/main/simple/", extra_index_urls = [], + index_strategy = "first-index", sources = ["foo", "bar", "baz"], envsubst = [], ), cache = {}, parallel_download = False, get_auth = lambda ctx, urls, ctx_attr: struct(), + _print = fail, ) env.expect.that_dict(downloads).contains_exactly({ @@ -175,12 +376,14 @@ def _test_download_url_parallel(env): index_url_overrides = {}, index_url = "https://example.com/main/simple/", extra_index_urls = [], + index_strategy = "first-index", sources = ["foo", "bar", "baz"], envsubst = [], ), cache = {}, parallel_download = True, get_auth = lambda ctx, urls, ctx_attr: struct(), + _print = fail, ) env.expect.that_dict(downloads).contains_exactly({ @@ -210,12 +413,14 @@ def _test_download_envsubst_url(env): index_url_overrides = {}, index_url = "$INDEX_URL", extra_index_urls = [], + index_strategy = "first-index", sources = ["foo", "bar", "baz"], envsubst = ["INDEX_URL"], ), cache = {}, parallel_download = False, get_auth = lambda ctx, urls, ctx_attr: struct(), + _print = fail, ) env.expect.that_dict(downloads).contains_exactly({