Skip to content

Commit

Permalink
Merge branch 'main' into include_additional_CIFs
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudioZeniMRSC committed Feb 24, 2025
2 parents 3cd34c5 + e1c3cca commit ab23939
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 89 deletions.
31 changes: 31 additions & 0 deletions .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Desktop (please complete the following information):**
- OS: [e.g. iOS]
- Version [e.g. 22]

**Additional context**
Add any other context about the problem here.
20 changes: 20 additions & 0 deletions .github/ISSUE_TEMPLATE/feature_request.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
57 changes: 48 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,25 @@ source .venv/bin/activate
uv pip install -e .
```

Note that our datasets and model checkpoints are provided inside this repo via [Git Large File Storage (LFS)](https://git-lfs.com/). To find out whether LFS is installed on your machine, run
Note that our datasets and model checkpoints are provided inside this repo via [Git Large File Storage (LFS)](https://git-lfs.com/).
To find out whether LFS is installed on your machine, run
```bash
git lfs --version
```
If this prints some version like `git-lfs/3.0.2 (GitHub; linux amd64; go 1.18.1)`, you can skip the following step.

### Apple Silicon
> [!WARNING]
> Running MatterGen on Apple Silicon is **experimental**. Use at your own risk.
> Further, you need to run `export PYTORCH_ENABLE_MPS_FALLBACK=1` before any training or generation run.


### Install Git LFS
If Git LFS was not installed before you cloned this repo, you can install it via:
```bash
sudo apt install git-lfs
git lfs install
```

### Apple Silicon
> [!WARNING]
> Running MatterGen on Apple Silicon is **experimental**. Use at your own risk.
> Further, you need to run `export PYTORCH_ENABLE_MPS_FALLBACK=1` before any training or generation run.
## Get started with a pre-trained model
We provide checkpoints of an unconditional base version of MatterGen as well as fine-tuned models for these properties:
* `mattergen_base`: unconditional base model
Expand All @@ -71,7 +70,10 @@ We provide checkpoints of an unconditional base version of MatterGen as well as
* `dft_mag_density_hhi_score`: fine-tuned model jointly conditioned on magnetic density from DFT and HHI score
* `chemical_system_energy_above_hull`: fine-tuned model jointly conditioned on chemical system and energy above hull from DFT

The checkpoints are located at `checkpoints/<model_name>` and are also available on [Hugging Face](https://huggingface.co/microsoft/mattergen).
The checkpoints are located at `checkpoints/<model_name>` and are also available on [Hugging Face](https://huggingface.co/microsoft/mattergen). By default, they are downloaded from Huggingface when requested. You can also manually download them from Git LFS via
```bash
git lfs pull -I checkpoints/<model_name> --exclude=""
```

> [!NOTE]
> The checkpoints provided were re-trained using this repository, i.e., are not identical to the ones used in the paper. Hence, results may slightly deviate from those in the publication.
Expand Down Expand Up @@ -142,6 +144,43 @@ This script will try to read structures from disk in the following precedence or
* If `$RESULTS_PATH` points to a directory, it will read all `.cif`, `.xyz`, or `.extxyz` files in the order they occur in `os.listdir`.

Here, we expect `energies.npy` to be a numpy array with the entries being `float` energies in the same order as the structures read from `$RESULTS_PATH`.

### Evaluate using your own reference dataset

> [!IMPORTANT]
> If you are planning to use MatterSim to evaluate the stability of the generated structures, then the reference dataset you provide must contain energies
> that are compatible with MatterSim, meaning they should be either DFT-computed energies calculated according to the Materials Project Compatbility scheme,
> or energies directly computed with MatterSim.
If you want to use your own custom dataset for evaluation, you first need to serialize and save it by doing so:

``` python
from mattergen.evaluation.reference.reference_dataset import ReferenceDataset
from mattergen.evaluation.reference.reference_dataset_serializer import LMDBGZSerializer


reference_dataset = ReferenceDataset.from_entries(name="my_reference_dataset", entries=entries)
LMDBGZSerializer().serialize(reference_dataset, "path_to_file.gz")
```

where `entries` is a list of `pymatgen.entries.computed_entries.ComputedStructureEntry` objects containing structure-energy pairs for each structure.

By default, we apply the MaterialsProject2020Compatibility energy correction scheme to all input structures during evaluation, and assume that the reference dataset
has already been pre-processed using the same compatibility scheme. Therefore, unless you have already done this, you should obtain the `entries` object for
your custom reference dataset in the following way:

``` python
from mattergen.evaluation.utils.vasprunlike import VasprunLike
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility

entries = []
for structure, energy in zip(structures, energies)
vasprun_like = VasprunLike(structure=structure, energy=energy)
entries.append(vasprun_like.get_computed_entry(
inc_structure=True, energy_correction_scheme=MaterialsProject2020Compatibility()
))
```

## Train MatterGen yourself
Before we can train MatterGen from scratch, we have to unpack and preprocess the dataset files.

Expand Down
4 changes: 2 additions & 2 deletions mattergen/common/gemnet/layers/basis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def real_sph_harm(
L_maxdegree: int, use_theta: bool, use_phi: bool = True, zero_m_only: bool = True
) -> List[List[Any]]:
"""
Computes formula strings of the the real part of the spherical harmonics up to degree L (excluded).
Computes formula strings of the real part of the spherical harmonics up to degree L (excluded).
Variables are either spherical coordinates phi and theta (or cartesian coordinates x,y,z) on the UNIT SPHERE.
Parameters
Expand All @@ -209,7 +209,7 @@ def real_sph_harm(
Returns
-------
Y_lm_real: list
Computes formula strings of the the real part of the spherical harmonics up
Computes formula strings of the real part of the spherical harmonics up
to degree L (where degree L is not excluded).
In total L^2 many sph harm exist up to degree L (excluded). However, if zero_m_only only is True then
the total count is reduced to be only L many.
Expand Down
2 changes: 1 addition & 1 deletion mattergen/common/utils/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def save_structures(output_path: Path, structures: Sequence[Structure]) -> None:
with ZipFile(output_path / GENERATED_CRYSTALS_ZIP_FILE_NAME, "w") as zip_obj:
for ix, ase_atom in enumerate(ase_atoms):
ase.io.write(f"/tmp/gen_{ix}.cif", ase_atom, format="cif")
zip_obj.write(f"/tmp/gen_{ix}.cif")
zip_obj.write(f"/tmp/gen_{ix}.cif", arcname=f"gen_{ix}.cif")
except IOError as e:
print(f"Got error {e} writing the generated structures to disk.")

Expand Down
9 changes: 5 additions & 4 deletions mattergen/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ def evaluate(
structures: List of structures to evaluate.
relax: Whether to relax the structures before evaluation. Note that if this is run, `energies` will be ignored.
energies: Energies of the structures if already relaxed and computed externally (e.g., from DFT).
reference_dataset: Reference dataset.
ordered_structure_matcher: Matcher for ordered structures.
disordered_structure_matcher: Matcher for disordered structures.
n_jobs: Number of parallel jobs.
reference: Reference dataset. If this is None, the default reference dataset will be used.
structure_matcher: Structure matcher to use for matching the structures.
save_as: Save the metrics as a JSON file.
potential_load_path: Path to the Machine Learning potential to use for relaxation.
device: Device to use for relaxation.
Returns:
metrics: a dictionary of metrics and their values.
Expand Down
74 changes: 1 addition & 73 deletions mattergen/evaluation/reference/reference_dataset_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tqdm.autonotebook import tqdm

from mattergen.evaluation.reference.reference_dataset import ReferenceDataset, ReferenceDatasetImpl
from mattergen.evaluation.utils.lmdb_utils import lmdb_get, lmdb_open, lmdb_read_metadata
from mattergen.evaluation.utils.lmdb_utils import lmdb_get, lmdb_open, lmdb_put, lmdb_read_metadata


def gzip_compress(file_path: str | os.PathLike, output_dir: str | os.PathLike) -> Path:
Expand All @@ -44,78 +44,6 @@ class LmdbNotFoundError(Exception):
pass


def lmdb_open(db_path: str | os.PathLike, readonly: bool = False) -> lmdb.Environment:
if readonly:
return lmdb.open(
str(db_path),
subdir=False,
readonly=True,
lock=False,
readahead=False,
meminit=False,
max_readers=1,
)
else:
return lmdb.open(
str(db_path),
map_size=1099511627776 * 2,
subdir=False,
meminit=False,
map_async=True,
)


def lmdb_read_metadata(db_path: str | os.PathLike, key: str, default=None) -> Any:
with lmdb_open(db_path, readonly=True) as db:
with db.begin() as txn:
result = lmdb_get(txn, key, default=default)
return result


def lmdb_get(
txn: lmdb.Transaction, key: str, default: Any = None, raise_if_missing: bool = True
) -> Any:
"""
Fetches a record from a database.
Args:
txn: LMDB transaction (use env.begin())
key: key of the data to be fetched.
default: default value to be used if the record doesn't exist.
raise_if_missing: raise LmdbNotFoundError if the record doesn't exist
and no default value was given.
Returns:
the value of the retrieved data.
"""
value = txn.get(key.encode("ascii"))
if value is None:
if default is None and raise_if_missing:
raise LmdbNotFoundError(
f"Key {key} not found in database but default was not provided."
)
return default
return pickle.loads(value)


def lmdb_put(txn: lmdb.Transaction, key: str, value: Any) -> bool:
"""
Stores a record in a database.
Args:
txn: LMDB transaction (use env.begin())
key: key of the data to be stored.
value: value of the data to be stored (needs to be picklable).
Returns:
True if it was written.
"""
return txn.put(
key.encode("ascii"),
pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL),
)


class LMDBGZSerializer():
def __init__(
self,
Expand Down
7 changes: 7 additions & 0 deletions mattergen/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mattergen.common.utils.eval_utils import load_structures
from mattergen.common.utils.globals import get_device
from mattergen.evaluation.evaluate import evaluate
from mattergen.evaluation.reference.reference_dataset_serializer import LMDBGZSerializer
from mattergen.evaluation.utils.structure_matcher import (
DefaultDisorderedStructureMatcher,
DefaultOrderedStructureMatcher,
Expand All @@ -26,6 +27,7 @@ def main(
potential_load_path: (
Literal["MatterSim-v1.0.0-1M.pth", "MatterSim-v1.0.0-5M.pth"] | None
) = None,
reference_dataset_path: str | None = None,
device: str = str(get_device()),
):
structures = load_structures(Path(structures_path))
Expand All @@ -35,13 +37,18 @@ def main(
if structure_matcher == "disordered"
else DefaultOrderedStructureMatcher()
)
reference = None
if reference_dataset_path:
reference = LMDBGZSerializer().deserialize(reference_dataset_path)

metrics = evaluate(
structures=structures,
relax=relax,
energies=energies,
structure_matcher=structure_matcher,
save_as=save_as,
potential_load_path=potential_load_path,
reference=reference,
device=device,
)
print(json.dumps(metrics, indent=2))
Expand Down

0 comments on commit ab23939

Please sign in to comment.