50 commits
5635129
PUBLIC: Add `predecessor_pointers_to_permutation_matrix` and `permuta…
May 30, 2024
b7eb7a6
PUBLIC: Add an additional field to the `quicksort` probing, update it…
May 30, 2024
8ae153c
PUBLIC: Add `clrs_utils.py` to the CLRS dataset. This module contains…
May 31, 2024
8d561d7
PUBLIC: Add CLRS Text accuracy_graphs.ipynb colab and accuracy_data.c…
avlife May 31, 2024
21cf060
Adding huggingface generators for clrs text
mcleish7 Jun 6, 2024
e9d6270
PUBLIC: Update accuracy_data.csv. Update CLRS Text graphs (no undersc…
avlife Jun 6, 2024
f5a35f6
[JAX] Update users of jax.tree.map() to be more careful about how the…
hawkinsp Jun 9, 2024
a58f44d
Internal
RerRayne Jun 14, 2024
efa9af2
Improving quality of huggingface generators
mcleish7 Jun 21, 2024
d892c7d
Apply suggestions from code review
mcleish7 Jun 26, 2024
982185b
Add `ml_collections` to `requirements.txt`.
RerRayne Jun 27, 2024
5ef0ad8
Made 'text' a permanent field and switched to uniform random sampling…
mcleish7 Jun 28, 2024
4eab3ea
Combining the infinite and finite generators and general comment form…
mcleish7 Jun 30, 2024
addaa37
Fix minor linter and codestyle issues
grenlayk Jul 1, 2024
2c26d74
Modify tests and add copyright
grenlayk Jul 4, 2024
62c3852
Fix typo in tests params
grenlayk Jul 4, 2024
1ff4f09
Added script to generate json for all algorithms.
RerRayne Jul 4, 2024
90e496a
Merge pull request #135 from mcleish7/clrs-text-hf
RerRayne Jul 5, 2024
a5314f3
Add support for generating train and val datasets using this script.
RerRayne Jul 5, 2024
9bf6807
Add clrs_text to __init__.py
RerRayne Jul 5, 2024
5631f12
Merge pull request #135 from mcleish7:clrs-text-hf
copybara-github Jul 8, 2024
ea337b4
Roll forward PR #104
RerRayne Jul 10, 2024
956cff7
Merge pull request #146 from google-deepmind/test-branch
PetarV- Jul 10, 2024
b8e671d
Add CLRS-Text details in the README files.
PetarV- Jul 11, 2024
b7b71f4
Use pythonic swap for variables in sorting.py
gurux13 Jul 11, 2024
cdd2889
Merge pull request #147 from gurux13:master
copybara-github Jul 11, 2024
d12851a
Bump CLRS version to 2.0.0.
RerRayne Jul 15, 2024
3c1e364
Update pypi-publish.yml
RerRayne Jul 17, 2024
42c2968
Merge pull request #150 from google-deepmind:RerRayne-update-pypi-wor…
copybara-github Jul 17, 2024
8665a6f
Update pypi-publish.yml
RerRayne Jul 17, 2024
a887b07
Merge pull request #151 from google-deepmind:RerRayne-fix-pypi-yaml
copybara-github Jul 17, 2024
726742d
Update pypi-publish.yml
RerRayne Jul 17, 2024
25127ee
Update pypi-publish.yml
RerRayne Jul 17, 2024
9489aef
Merge pull request #152 from google-deepmind/RerRayne-patch-pypi-fix
RerRayne Jul 18, 2024
a9f524b
Add `huggingface_generators` to `__init__.py`.
RerRayne Jul 18, 2024
832ac32
Update CLRS GitHub workflow to use `actions/checkout@v4` and `actions…
RerRayne Jul 18, 2024
d1c2ad7
Bump CLRS version to 2.0.1
RerRayne Jul 18, 2024
9a5c2c0
Update version numbers from requirements.txt
RerRayne Sep 6, 2024
08ad3d1
Log the warning only for the first on-the-fly sampler.
sinopalnikov Sep 22, 2024
a891bc2
Update run.py.
grenlayk Nov 13, 2024
c9cf120
Adds a "debug" mode to Baselines, in order to support decoding state.
PetarV- Nov 27, 2024
d76e598
Resolve unsoundness caught by pytype --strict-none-binding.
Dec 19, 2024
97fd88d
[CLRS]
RerRayne Jan 20, 2025
dcf8643
[CLRS] Add num_decimals_in_float parameter to the data generators.
RerRayne Jan 20, 2025
7712b7b
[CLRS] Incorporate a collinearity check into the ConvexHull sampler t…
RerRayne Jan 20, 2025
fbaa762
Increment minor version number for 2.0.3 release.
RerRayne Jan 25, 2025
411692b
Internal.
RerRayne Jan 28, 2025
7172d73
Replace numpy cross product with hand-written version
RerRayne Mar 20, 2025
9850007
Internal change
h-joo Apr 30, 2025
9ca575c
Fix bug where train lengths were modified across algorithms.
RerRayne Oct 15, 2025
16 changes: 8 additions & 8 deletions .github/workflows/pypi-publish.yml
@@ -2,17 +2,21 @@ name: pypi

on:
release:
types: [created]

types: [created, published]
branches: [main, master]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v1
uses: actions/setup-python@v4
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Check consistency between the package version and release tag
run: |
RELEASE_VER=${GITHUB_REF#refs/*/}
@@ -21,10 +25,6 @@ jobs:
then
echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1
fi
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
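
The consistency step in this workflow derives the release version with the shell expansion `${GITHUB_REF#refs/*/}`, which removes the shortest prefix matching `refs/*/`. A minimal Python sketch of that behaviour follows. The helper names `strip_ref_prefix` and `versions_match` are hypothetical, and the exact comparison the workflow performs is elided in this diff, so plain string equality is an assumption.

```python
import re


def strip_ref_prefix(github_ref: str) -> str:
    """Mimics the shell expansion ${GITHUB_REF#refs/*/}.

    The `#` operator drops the shortest prefix matching the glob `refs/*/`,
    that is, `refs/` plus the first path component and its trailing slash.
    """
    return re.sub(r"^refs/[^/]*/", "", github_ref, count=1)


def versions_match(github_ref: str, package_version: str) -> bool:
    """Hypothetical check: assumes the tag must equal the package version."""
    return strip_ref_prefix(github_ref) == package_version


print(strip_ref_prefix("refs/tags/2.0.3"))       # 2.0.3
print(versions_match("refs/tags/2.0.2", "2.0.3"))  # False
```

A ref without the `refs/<kind>/` prefix is returned unchanged, which matches how the shell expansion behaves when the pattern does not match.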
25 changes: 23 additions & 2 deletions README.md
@@ -20,7 +20,7 @@ pip install dm-clrs
or directly from GitHub (updated more frequently):

```shell
pip install git+git://github.com/deepmind/clrs.git
pip install git+https://github.com/google-deepmind/clrs.git
```

You may prefer to install it in a virtual environment if any requirements
@@ -29,7 +29,7 @@ clash with your Python installation:
```shell
python3 -m venv clrs_env
source clrs_env/bin/activate
pip install git+git://github.com/deepmind/clrs.git
pip install git+https://github.com/google-deepmind/clrs.git
```

Once installed you can run our example baseline model:
@@ -225,6 +225,14 @@ for feedback in _iterate_sampler(batch_size=32):

```

Most recently, we are offering [**CLRS-Text**](https://github.com/google-deepmind/clrs/tree/master/clrs/_src/clrs_text),
a text-based variant of the benchmark suitable for training and evaluating the algorithmic reasoning
capabilities of language models. Please see the relevant subfolder for a
dedicated README file.

You may also see the [companion paper](https://arxiv.org/abs/2406.04229) on
CLRS-Text.

## Adding new algorithms

Adding a new algorithm to the task suite requires the following steps:
@@ -259,3 +267,16 @@ To cite the CLRS Algorithmic Reasoning Benchmark:
year={2022}
}
```

To cite the CLRS-Text Algorithmic Reasoning Language Benchmark:

```latex
@article{deepmind2024clrstext,
title={The CLRS-Text Algorithmic Reasoning Language Benchmark},
author={Larisa Markeeva and Sean McLeish and Borja Ibarz and Wilfried Bounsi
and Olga Kozlova and Alex Vitvitskyi and Charles Blundell and
Tom Goldstein and Avi Schwarzschild and Petar Veli\v{c}kovi\'{c}},
journal={arXiv preprint arXiv:2406.04229},
year={2024}
}
```
14 changes: 13 additions & 1 deletion clrs/__init__.py
@@ -16,21 +16,30 @@
"""The CLRS Algorithmic Reasoning Benchmark."""

from clrs import models

from clrs._src import algorithms
from clrs._src import clrs_text
from clrs._src import decoders
from clrs._src import processors
from clrs._src import specs

from clrs._src.dataset import chunkify
from clrs._src.dataset import CLRSDataset
from clrs._src.dataset import create_chunked_dataset
from clrs._src.dataset import create_dataset
from clrs._src.dataset import get_clrs_folder
from clrs._src.dataset import get_dataset_gcp_url

from clrs._src.evaluation import evaluate
from clrs._src.evaluation import evaluate_hints

from clrs._src.model import Model

from clrs._src.probing import DataPoint
from clrs._src.probing import predecessor_to_cyclic_predecessor_and_first

from clrs._src.processors import get_processor_factory

from clrs._src.samplers import build_sampler
from clrs._src.samplers import CLRS30
from clrs._src.samplers import Features
@@ -40,6 +49,7 @@
from clrs._src.samplers import process_random_pos
from clrs._src.samplers import Sampler
from clrs._src.samplers import Trajectory

from clrs._src.specs import ALGO_IDX_INPUT_NAME
from clrs._src.specs import CLRS_30_ALGS_SETTINGS
from clrs._src.specs import Location
@@ -49,7 +59,7 @@
from clrs._src.specs import Stage
from clrs._src.specs import Type

__version__ = "1.0.0"
__version__ = "2.0.3"

__all__ = (
"ALGO_IDX_INPUT_NAME",
@@ -59,6 +69,7 @@
"CLRS_30_ALGS_SETTINGS",
"create_chunked_dataset",
"create_dataset",
"clrs_text",
"get_clrs_folder",
"get_dataset_gcp_url",
"get_processor_factory",
@@ -67,6 +78,7 @@
"process_permutations",
"process_pred_as_input",
"process_random_pos",
"specs",
"evaluate",
"evaluate_hints",
"Features",
12 changes: 8 additions & 4 deletions clrs/_src/algorithms/searching.py
@@ -179,8 +179,10 @@ def partition(A, A_pos, p, r, target, probes):
'i': probing.mask_one(A_pos[i + 1], A.shape[0]),
'j': probing.mask_one(A_pos[j], A.shape[0]),
'i_rank': (i + 1) * 1.0 / A.shape[0],
'target': target * 1.0 / A.shape[0]
})
'target': target * 1.0 / A.shape[0],
'pivot': probing.mask_one(A_pos[r], A.shape[0]),
},
)

tmp = A[i + 1]
A[i + 1] = A[r]
@@ -199,8 +201,10 @@ def partition(A, A_pos, p, r, target, probes):
'i': probing.mask_one(A_pos[i + 1], A.shape[0]),
'j': probing.mask_one(A_pos[r], A.shape[0]),
'i_rank': (i + 1 - p) * 1.0 / A.shape[0],
'target': target * 1.0 / A.shape[0]
})
'target': target * 1.0 / A.shape[0],
'pivot': probing.mask_one(A_pos[i + 1], A.shape[0]),
},
)

return i + 1
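
The two new `'pivot'` probes in `searching.py` index `A_pos[r]` before the final swap and `A_pos[i + 1]` after it. A minimal sketch of why both can point at the same original element is given here. It assumes `probing.mask_one` yields a one-hot vector and that `A_pos` is swapped alongside `A` (that part of the hunk is elided); `mask_one` and `pivot_probes` below are simplified stand-ins, not the library's code.

```python
def mask_one(index: int, size: int) -> list:
    """Stand-in for `probing.mask_one`: a one-hot vector of length `size`."""
    if not 0 <= index < size:
        raise ValueError(f"index {index} out of range for size {size}")
    one_hot = [0.0] * size
    one_hot[index] = 1.0
    return one_hot


def pivot_probes(a_pos, i, r):
    """Returns the pivot probe before and after the final partition swap."""
    n = len(a_pos)
    # Before the swap, the pivot sits in slot r.
    before = mask_one(a_pos[r], n)
    # Assumed: A_pos is swapped together with A, moving the pivot to i + 1.
    a_pos = list(a_pos)
    a_pos[i + 1], a_pos[r] = a_pos[r], a_pos[i + 1]
    after = mask_one(a_pos[i + 1], n)
    return before, after


before, after = pivot_probes([2, 0, 3, 1], i=0, r=3)
print(before)           # [0.0, 1.0, 0.0, 0.0]
print(before == after)  # True
```

Under these assumptions the probe identifies the pivot by its original position, so it stays the same across the swap even though the slot index changes from `r` to `i + 1`.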

43 changes: 10 additions & 33 deletions clrs/_src/algorithms/sorting.py
@@ -123,13 +123,8 @@ def bubble_sort(A: _Array) -> _Out:
for i in range(A.shape[0] - 1):
for j in reversed(range(i + 1, A.shape[0])):
if A[j] < A[j - 1]:
tmp = A[j]
A[j] = A[j - 1]
A[j - 1] = tmp

tmp = A_pos[j]
A_pos[j] = A_pos[j - 1]
A_pos[j - 1] = tmp
A[j], A[j - 1] = A[j - 1], A[j]
A_pos[j], A_pos[j - 1] = A_pos[j - 1], A_pos[j]

probing.push(
probes,
@@ -190,13 +185,8 @@ def max_heapify(A, i, heap_size, ind, phase):
if r < heap_size and A[r] > A[largest]:
largest = r
if largest != i:
tmp = A[i]
A[i] = A[largest]
A[largest] = tmp

tmp = A_pos[i]
A_pos[i] = A_pos[largest]
A_pos[largest] = tmp
A[i], A[largest] = A[largest], A[i]
A_pos[i], A_pos[largest] = A_pos[largest], A_pos[i]

probing.push(
probes,
@@ -221,13 +211,8 @@ def build_max_heap(A):
build_max_heap(A)
heap_size = A.shape[0]
for i in reversed(range(1, A.shape[0])):
tmp = A[0]
A[0] = A[i]
A[i] = tmp

tmp = A_pos[0]
A_pos[0] = A_pos[i]
A_pos[i] = tmp
A[0], A[i] = A[i], A[0]
A_pos[0], A_pos[i] = A_pos[i], A_pos[0]

heap_size -= 1

@@ -268,12 +253,8 @@ def partition(A, A_pos, p, r, probes):
for j in range(p, r):
if A[j] <= x:
i += 1
tmp = A[i]
A[i] = A[j]
A[j] = tmp
tmp = A_pos[i]
A_pos[i] = A_pos[j]
A_pos[j] = tmp
A[i], A[j] = A[j], A[i]
A_pos[i], A_pos[j] = A_pos[j], A_pos[i]

probing.push(
probes,
@@ -286,12 +267,8 @@ def partition(A, A_pos, p, r, probes):
'j': probing.mask_one(A_pos[j], A.shape[0])
})

tmp = A[i + 1]
A[i + 1] = A[r]
A[r] = tmp
tmp = A_pos[i + 1]
A_pos[i + 1] = A_pos[r]
A_pos[r] = tmp
A[i + 1], A[r] = A[r], A[i + 1]
A_pos[i + 1], A_pos[r] = A_pos[r], A_pos[i + 1]

probing.push(
probes,
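
The `sorting.py` hunks replace a three-line temporary-variable swap with tuple assignment. A minimal sketch of the equivalence follows, using plain Python lists rather than the arrays in the source; the function names are illustrative only.

```python
def swap_with_tmp(values, i, j):
    """The original pattern: three statements and a temporary."""
    tmp = values[i]
    values[i] = values[j]
    values[j] = tmp


def swap_with_tuple(values, i, j):
    """The replacement: the right-hand side is evaluated in full first,
    then its two values are assigned to the targets from left to right."""
    values[i], values[j] = values[j], values[i]


a = [3, 1, 2]
b = list(a)
swap_with_tmp(a, 0, 2)
swap_with_tuple(b, 0, 2)
print(a, b)  # [2, 1, 3] [2, 1, 3]
```

The one-line form is safe when indexing yields independent scalar values, as it does for lists of numbers and for one-dimensional NumPy arrays. Swapping rows or slices of a NumPy array this way is a known pitfall, because the right-hand side then holds views rather than copies.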
27 changes: 22 additions & 5 deletions clrs/_src/baselines.py
@@ -155,6 +155,7 @@ def __init__(
hint_repred_mode: str = 'soft',
name: str = 'base_model',
nb_msg_passing_steps: int = 1,
debug: bool = False,
):
"""Constructor for BaselineModel.

@@ -199,6 +200,7 @@ def __init__(
- 'hard_on_eval', which is soft for training and hard for evaluation.
name: Model name.
nb_msg_passing_steps: Number of message passing steps per hint.
debug: If True, the model runs in debug mode, outputting all hidden states.

Raises:
ValueError: if `encode_hints=True` and `decode_hints=False`.
@@ -223,6 +225,7 @@ def __init__(
self.opt = optax.adam(learning_rate)

self.nb_msg_passing_steps = nb_msg_passing_steps
self.debug = debug

self.nb_dims = []
if isinstance(dummy_trajectory, _Feedback):
@@ -253,7 +256,8 @@ def _use_net(*args, **kwargs):
processor_factory, use_lstm, encoder_init,
dropout_prob, hint_teacher_forcing,
hint_repred_mode,
self.nb_dims, self.nb_msg_passing_steps)(*args, **kwargs)
self.nb_dims, self.nb_msg_passing_steps,
self.debug)(*args, **kwargs)

self.net_fn = hk.transform(_use_net)
pmap_args = dict(axis_name='batch', devices=jax.local_devices())
@@ -324,18 +328,25 @@ def _feedback(self, params, rng_key, feedback, opt_state, algorithm_index):
def _predict(self, params, rng_key: hk.PRNGSequence, features: _Features,
algorithm_index: int, return_hints: bool,
return_all_outputs: bool):
outs, hint_preds = self.net_fn.apply(
net_outputs = self.net_fn.apply(
params, rng_key, [features],
repred=True, algorithm_index=algorithm_index,
return_hints=return_hints,
return_all_outputs=return_all_outputs)
if self.debug:
outs, hint_preds, hidden_states = net_outputs
else:
outs, hint_preds = net_outputs
outs = decoders.postprocess(self._spec[algorithm_index],
outs,
sinkhorn_temperature=0.1,
sinkhorn_steps=50,
hard=True,
)
return outs, hint_preds
if self.debug:
return outs, hint_preds, hidden_states
else:
return outs, hint_preds

def compute_grad(
self,
@@ -394,12 +405,16 @@ def predict(self, rng_key: hk.PRNGSequence, features: _Features,

def _loss(self, params, rng_key, feedback, algorithm_index):
"""Calculates model loss f(feedback; params)."""
output_preds, hint_preds = self.net_fn.apply(
outputs = self.net_fn.apply(
params, rng_key, [feedback.features],
repred=False,
algorithm_index=algorithm_index,
return_hints=True,
return_all_outputs=False)
if self.debug:
output_preds, hint_preds, _ = outputs
else:
output_preds, hint_preds = outputs

nb_nodes = _nb_nodes(feedback, is_chunked=False)
lengths = feedback.features.lengths
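
With the `debug` flag, `_predict` and `_loss` unpack either two or three values from the network call. A minimal sketch of that branching pattern is shown here; `fake_net_apply` is a hypothetical stand-in for `self.net_fn.apply(...)` and the placeholder values carry no meaning.

```python
from typing import Any, Tuple


def fake_net_apply(debug: bool) -> Tuple[Any, ...]:
    """Hypothetical stand-in for the network call in the diff.

    Returns (outs, hint_preds) normally, and appends hidden states
    only when debug mode is on.
    """
    outs, hint_preds, hidden_states = {"out": 1}, [{"hint": 2}], ["h0", "h1"]
    if debug:
        return outs, hint_preds, hidden_states
    return outs, hint_preds


def predict(debug: bool) -> Tuple[Any, ...]:
    """Mirrors the unpack-then-return shape of `_predict`."""
    net_outputs = fake_net_apply(debug)
    if debug:
        outs, hint_preds, hidden_states = net_outputs
        return outs, hint_preds, hidden_states
    outs, hint_preds = net_outputs
    return outs, hint_preds


print(len(predict(debug=False)))  # 2
print(len(predict(debug=True)))   # 3
```

One consequence of this design is that the return arity depends on how the model was constructed, so callers that unpack the result need to know whether `debug` was set.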
@@ -766,7 +781,9 @@ def _keep_in_algo(k, v):
masked_grads = grads
else:
masked_grads = {k: _keep_in_algo(k, v) for k, v in grads.items()}
flat_grads, treedef = jax.tree_util.tree_flatten(masked_grads)
flat_grads, treedef = jax.tree_util.tree_flatten(
masked_grads, is_leaf=lambda x: x is None
)
flat_opt_state = jax.tree_util.tree_map(
lambda _, x: x # pylint:disable=g-long-lambda
if isinstance(x, (np.ndarray, jax.Array))
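
The last hunk passes `is_leaf=lambda x: x is None` to `jax.tree_util.tree_flatten`. By default JAX treats `None` as an empty pytree node that contributes no leaves, so `None` entries vanish from the flattened list. The toy flattener below is a simplified model of that behaviour in plain Python, not JAX's implementation, and the motivation (keeping `None` gradients in position, presumably so the flattened list lines up with other structures) is an inference from the diff.

```python
def flatten(tree, is_leaf=None):
    """Toy flattener for nested dicts, lists and tuples.

    Mimics two pytree rules: `None` is an empty node by default, and an
    `is_leaf` predicate can force any value to be kept as a leaf.
    """
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if tree is None:
        return []  # Empty node: contributes no leaves.
    if isinstance(tree, dict):
        leaves = []
        for key in sorted(tree):  # Dict children are visited in key order.
            leaves.extend(flatten(tree[key], is_leaf))
        return leaves
    if isinstance(tree, (list, tuple)):
        leaves = []
        for child in tree:
            leaves.extend(flatten(child, is_leaf))
        return leaves
    return [tree]


masked_grads = {"a": 1.0, "b": None, "c": 3.0}
print(flatten(masked_grads))                                # [1.0, 3.0]
print(flatten(masked_grads, is_leaf=lambda x: x is None))   # [1.0, None, 3.0]
```

In this model the predicate keeps one leaf per entry, so the position of each `None` survives flattening.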