COMPLETE AND EFFICIENT GRAPH TRANSFORMERS FOR CRYSTAL MATERIAL PROPERTY PREDICTION
Crystal structures are characterized by atomic bases within a primitive unit cell that repeats along a regular lattice throughout 3D space. The periodic and infinite nature of crystals poses unique challenges for geometric graph representation learning. Specifically, constructing graphs that effectively capture the complete geometric information of crystals and handle chiral crystals remains an unsolved and challenging problem. In this paper, we introduce a novel approach that uses the periodic patterns of unit cells to establish a lattice-based representation for each atom, enabling efficient and expressive graph representations of crystals. Furthermore, we propose ComFormer, an SE(3) transformer designed specifically for crystalline materials. ComFormer has two variants: iComFormer, which employs invariant geometric descriptors (Euclidean distances and angles), and eComFormer, which uses equivariant vector representations. Experimental results demonstrate state-of-the-art predictive accuracy for the ComFormer variants on various tasks across three widely used crystal benchmarks.
A periodic crystal unit cell is represented as:
- atom types (composition): $A = (a_1,\ldots,a_N)$
- fractional coordinates: $F = (f_1,\ldots,f_N),\; f_i \in [0,1)^3$ (stacked as $F \in [0,1)^{3 \times N}$)
- lattice matrix: $L \in \mathbb{R}^{3 \times 3}$
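As a quick illustration of this representation, the sketch below maps fractional coordinates to Cartesian ones via $X = LF$ and wraps coordinates back into the cell. It assumes the column convention $L = [l_1, l_2, l_3]$, with each column a lattice vector, as in the DiffCSP paper. Libraries such as pymatgen store lattice vectors as rows instead, in which case the product becomes `F.T @ lattice`. The function names and the toy cell are illustrative only.

```python
import numpy as np

def frac_to_cart(F: np.ndarray, L: np.ndarray) -> np.ndarray:
    """Cartesian coordinates X = L F.

    F is 3 x N (one column per atom) and L = [l1, l2, l3] stores the
    lattice vectors as columns (assumed convention).
    """
    return L @ F

def wrap(F: np.ndarray) -> np.ndarray:
    """Map fractional coordinates back into [0, 1), i.e. into the unit cell."""
    return F - np.floor(F)

# Toy two-atom cubic cell (a = 4.0 Angstrom): one atom at the origin, one at the body centre.
L = 4.0 * np.eye(3)
F = np.array([[0.0, 0.5],
              [0.0, 0.5],
              [0.0, 0.5]])
X = frac_to_cart(F, L)
print(X[:, 1])                          # [2. 2. 2.]

# Adding integer lattice translations does not change the wrapped coordinates.
F_shifted = F + np.array([[1.0], [0.0], [-2.0]])
print(np.allclose(wrap(F_shifted), F))  # True
```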
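DiffCSP designs separate forward corruption processes for the lattice $L$ and the fractional coordinates $F$, while the composition $A$ is given and kept fixed. The lattice follows a DDPM with noise schedule $\beta_t \in (0,1)$, $\alpha_t = 1 - \beta_t$, and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$ for $t = 1, \ldots, T$.

Forward diffusion:

$$q(L_t \mid L_0) = \mathcal{N}\left(L_t \mid \sqrt{\bar\alpha_t}\, L_0,\ (1-\bar\alpha_t) I\right)$$

Reparameterized sampling:

$$L_t = \sqrt{\bar\alpha_t}\, L_0 + \sqrt{1-\bar\alpha_t}\, \epsilon_L, \qquad \epsilon_L \sim \mathcal{N}(0, I)$$

Reverse (ancestral) step, where $\mathcal{M}_t = (A, F_t, L_t)$ denotes the noisy crystal at step $t$:

$$p(L_{t-1} \mid \mathcal{M}_t) = \mathcal{N}\left(L_{t-1} \mid \mu(\mathcal{M}_t),\ \sigma^2(\mathcal{M}_t) I\right)$$

With mean (and variance):

$$\mu(\mathcal{M}_t) = \frac{1}{\sqrt{\alpha_t}}\left(L_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, \hat\epsilon_L(\mathcal{M}_t, t)\right), \qquad \sigma^2(\mathcal{M}_t) = \beta_t\, \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}$$

where $\hat\epsilon_L$ is the noise predicted by the denoising network.

Lattice denoising loss:

$$\mathcal{L}_L = \mathbb{E}_{\epsilon_L \sim \mathcal{N}(0, I),\, t \sim \mathcal{U}(1, T)}\left[\left\| \epsilon_L - \hat\epsilon_L(\mathcal{M}_t, t) \right\|_2^2\right]$$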
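A minimal NumPy sketch of the lattice branch follows. It assumes a linear $\beta_t$ schedule (the schedule in the DiffCSP configs may differ) and 0-based step indices. It draws $L_t$ with the reparameterization trick, forms a one-sample estimate of the lattice loss, and computes the mean and variance of the ancestral step. `make_schedule`, `q_sample_lattice`, and `reverse_mean_var` are hypothetical helper names, and the zero `eps_hat` merely stands in for a trained network.

```python
import numpy as np

def make_schedule(T: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
    """Linear beta schedule (an assumption for illustration)."""
    betas = np.linspace(beta_start, beta_end, T)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    return betas, alphas, alpha_bars

def q_sample_lattice(L0, t, alpha_bars, rng):
    """Reparameterized forward step: L_t = sqrt(abar_t) L0 + sqrt(1 - abar_t) eps_L."""
    eps = rng.standard_normal(L0.shape)
    abar = alpha_bars[t]
    return np.sqrt(abar) * L0 + np.sqrt(1.0 - abar) * eps, eps

def reverse_mean_var(Lt, eps_hat, t, betas, alphas, alpha_bars):
    """Mean and variance of the ancestral step p(L_{t-1} | M_t).

    Index t is 0-based, so alpha_bars[t - 1] plays the role of abar_{t-1};
    for the first step (t = 0) abar_prev = 1, which gives zero variance.
    """
    mean = (Lt - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_hat) / np.sqrt(alphas[t])
    abar_prev = alpha_bars[t - 1] if t > 0 else 1.0
    var = betas[t] * (1.0 - abar_prev) / (1.0 - alpha_bars[t])
    return mean, var

rng = np.random.default_rng(0)
betas, alphas, alpha_bars = make_schedule(T=1000)
L0 = 4.0 * np.eye(3)                    # clean lattice
t = 500
Lt, eps = q_sample_lattice(L0, t, alpha_bars, rng)
eps_hat = np.zeros_like(eps)            # stand-in for the network prediction
loss = np.sum((eps - eps_hat) ** 2)     # one-sample estimate of the lattice loss
mean, var = reverse_mean_var(Lt, eps_hat, t, betas, alphas, alpha_bars)
L_prev = mean + np.sqrt(var) * rng.standard_normal(L0.shape)   # one ancestral step
```

When `eps_hat` equals the true noise, the ancestral mean reduces to the closed-form DDPM posterior mean $\tilde\mu(L_t, L_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t} L_0 + \frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t} L_t$. This identity is a convenient sanity check for an implementation.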
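Because fractional coordinates are periodic ($f_i$ and $f_i + z$ with $z \in \mathbb{Z}^3$ describe the same site in the crystal), they are corrupted with Gaussian noise of level $\sigma_t$, which increases with $t$. The result is then wrapped back into the unit cell:

$$F_t = w(F_0 + \sigma_t \epsilon_F), \qquad w(F) = F - \lfloor F \rfloor, \qquad \epsilon_F \sim \mathcal{N}(0, I)$$

This implies the wrapped Normal transition density:

$$q(F_t \mid F_0) \propto \sum_{Z \in \mathbb{Z}^{3 \times N}} \exp\left(-\frac{\left\| F_t - F_0 + Z \right\|_F^2}{2\sigma_t^2}\right)$$

As $\sigma_t$ grows large, $q(F_t \mid F_0)$ approaches the uniform distribution on $[0,1)^{3 \times N}$, which serves as the prior at $t = T$.

Score-matching objective, where $\hat\epsilon_F(\mathcal{M}_t, t)$ is the network's score prediction for the noisy crystal $\mathcal{M}_t$:

$$\mathcal{L}_F = \mathbb{E}_{F_t \sim q(F_t \mid F_0),\, t \sim \mathcal{U}(1, T)}\left[\lambda_t \left\| \nabla_{F_t} \log q(F_t \mid F_0) - \hat\epsilon_F(\mathcal{M}_t, t) \right\|_2^2\right], \qquad \lambda_t = \mathbb{E}_{F_t}^{-1}\left[\left\| \nabla_{F_t} \log q(F_t \mid F_0) \right\|_2^2\right]$$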
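The wrapped Normal score has no closed form. However, the density factorizes over coordinates, and the sum over integer shifts converges quickly, so a truncated sum is enough in practice. The sketch below assumes truncation to $|z| \le 10$ per coordinate (`n_images`, a hypothetical parameter) and also estimates the weight $\lambda_t$ by Monte Carlo. It illustrates the objective; it is not the DiffCSP or PaddleMaterials code.

```python
import numpy as np

def wrapped_normal_score(Ft, F0, sigma, n_images=10):
    """Score grad_{F_t} log q(F_t | F_0) of the wrapped Normal.

    Each entry is handled independently (the density factorizes), and the
    infinite sum over integer shifts is truncated to z in [-n_images, n_images].
    """
    d = (np.asarray(Ft) - np.asarray(F0))[..., None]      # (..., 1)
    z = np.arange(-n_images, n_images + 1)                 # integer images
    logits = -((d + z) ** 2) / (2.0 * sigma**2)
    logits = logits - logits.max(axis=-1, keepdims=True)   # stable softmax
    w = np.exp(logits)
    w = w / w.sum(axis=-1, keepdims=True)
    return np.sum(w * (-(d + z) / sigma**2), axis=-1)

def sample_wrapped(F0, sigma, rng):
    """Forward corruption F_t = w(F_0 + sigma * eps_F)."""
    Ft = F0 + sigma * rng.standard_normal(F0.shape)
    return Ft - np.floor(Ft)

def estimate_lambda(F0, sigma, rng, n_samples=2000):
    """Monte Carlo estimate of lambda_t = 1 / E[||score||^2]."""
    sq = [np.sum(wrapped_normal_score(sample_wrapped(F0, sigma, rng), F0, sigma) ** 2)
          for _ in range(n_samples)]
    return 1.0 / np.mean(sq)

rng = np.random.default_rng(0)
F0 = rng.random((3, 4))                        # 3 x N clean fractional coordinates
sigma = 0.1
Ft = sample_wrapped(F0, sigma, rng)
target = wrapped_normal_score(Ft, F0, sigma)   # regression target for eps_F_hat
lam = estimate_lambda(F0, sigma, rng)
```

For small $\sigma_t$ the nearest image dominates, so the score approaches $-(F_t - F_0)/\sigma_t^2$ (with the minimum-image difference) and $\lambda_t$ approaches $\sigma_t^2 / (3N)$.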
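Sampling typically uses a predictor-corrector scheme for the fractional coordinates: an ancestral predictor combined with a Langevin corrector, both driven by the predicted score $\hat\epsilon_F$. The lattice is sampled with the ancestral step alone.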
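The corrector part of this scheme can be sketched as a Langevin update whose result is wrapped back onto the torus. In this sketch, the score comes from a toy narrow wrapped Normal around a hypothetical target `F_star`, approximated by its nearest periodic image, rather than from a trained $\hat\epsilon_F$. The step size `gamma = 0.1 * sigma**2` is an assumption; DiffCSP's own step-size rule is not reproduced here.

```python
import numpy as np

def wrap(F):
    """Map coordinates back into [0, 1)."""
    return F - np.floor(F)

def langevin_corrector_step(F, score_fn, step_size, rng):
    """One Langevin update on the torus [0, 1)^{3 x N}:
    F <- w(F + gamma * score(F) + sqrt(2 * gamma) * z),  z ~ N(0, I)."""
    z = rng.standard_normal(F.shape)
    return wrap(F + step_size * score_fn(F) + np.sqrt(2.0 * step_size) * z)

rng = np.random.default_rng(0)
F_star = rng.random((3, 8))             # hypothetical target coordinates
sigma = 0.05

def toy_score(F):
    """Nearest-image score of a narrow wrapped Normal centred at F_star."""
    d = F - F_star
    d = d - np.round(d)                 # minimum-image displacement
    return -d / sigma**2

F = rng.random(F_star.shape)            # start from the uniform prior
gamma = 0.1 * sigma**2                  # assumed step size
for _ in range(300):
    F = langevin_corrector_step(F, toy_score, gamma, rng)
```

Starting from the uniform prior, the iterates contract toward `F_star` (modulo lattice translations) and then fluctuate on the scale of `sigma`.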
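DiffCSP builds the denoising network $\phi(L_t, F_t, A, t)$ as an EGNN-style equivariant graph neural network. It operates on a fully connected graph over the $N$ atoms in the unit cell. Node features $h_i^{(0)}$ are initialized from embeddings of the atom type $a_i$ and the diffusion step $t$.

Message passing at layer $s = 1, \ldots, S$:

$$m_{ij}^{(s)} = \varphi_m\left(h_i^{(s-1)}, h_j^{(s-1)}, L^\top L, \psi_{\text{FT}}(f_j - f_i)\right)$$

$$m_i^{(s)} = \sum_{j=1}^{N} m_{ij}^{(s)}, \qquad h_i^{(s)} = h_i^{(s-1)} + \varphi_h\left(h_i^{(s-1)}, m_i^{(s)}\right)$$

The lattice enters only through the Gram matrix $L^\top L$, which is unchanged by any rotation or reflection $L \mapsto QL$ with $Q \in O(3)$.

Periodic Fourier features for relative fractional coordinates $f = f_j - f_i$, for each axis $c \in \{1, 2, 3\}$ and frequency $m$:

$$\psi_{\text{FT}}(f)[c, k] = \begin{cases} \sin(2\pi m f_c), & k = 2m \\ \cos(2\pi m f_c), & k = 2m + 1 \end{cases}$$

These features are periodic-translation invariant under wrapping. Translating every atom by the same vector changes $f_j - f_i$ only by integers after wrapping, and $\psi_{\text{FT}}$ has period 1 in each coordinate.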
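The following is a direct NumPy transcription of $\psi_{\text{FT}}$, with sine and cosine interleaved as in the formula. The number of frequencies `n_freq` is an assumed hyperparameter. The usage lines check the invariance claim numerically: translating two atoms by the same vector and wrapping leaves their features unchanged.

```python
import numpy as np

def fourier_features(frac_diff, n_freq=4):
    """psi_FT of relative fractional coordinates.

    frac_diff has shape (..., 3) and holds f_j - f_i. The result has shape
    (..., 3, 2 * n_freq) with [c, 2m] = sin(2 pi m f_c) and
    [c, 2m + 1] = cos(2 pi m f_c) for m = 0, ..., n_freq - 1.
    """
    m = np.arange(n_freq)
    angles = 2.0 * np.pi * np.asarray(frac_diff)[..., None] * m   # (..., 3, n_freq)
    feats = np.empty(angles.shape[:-1] + (2 * n_freq,))
    feats[..., 0::2] = np.sin(angles)
    feats[..., 1::2] = np.cos(angles)
    return feats

f_i = np.array([0.10, 0.80, 0.30])
f_j = np.array([0.90, 0.20, 0.35])
# A common translation followed by wrapping changes f_j - f_i only by integers,
# which the period-1 features cannot see.
shift = np.array([0.25, 0.50, 0.90])
f_i2, f_j2 = (f_i + shift) % 1.0, (f_j + shift) % 1.0
same = np.allclose(fourier_features(f_j - f_i), fourier_features(f_j2 - f_i2))
print(same)   # True
```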
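Outputs (noise/score predictions) after the final layer $S$, read out with MLPs $\varphi_L$ and $\varphi_F$:

$$\hat\epsilon_L = L\, \varphi_L\left(\frac{1}{N}\sum_{i=1}^{N} h_i^{(S)}\right), \qquad \hat\epsilon_F[:, i] = \varphi_F\left(h_i^{(S)}\right)$$

Here $\varphi_L$ outputs a $3 \times 3$ matrix, so $\hat\epsilon_L$ rotates with the lattice ($QL$ gives $Q\hat\epsilon_L$). By contrast, $\hat\epsilon_F$ is invariant, as a prediction for fractional coordinates should be.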
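To make the invariance and equivariance claims concrete, here is a toy single-layer version of this network in NumPy. Random linear maps followed by `tanh` stand in for the MLPs $\varphi_m$, $\varphi_h$, $\varphi_L$, and $\varphi_F$, and random vectors stand in for $h_i^{(0)}$. The sizes `H` and `K` are arbitrary, and fractional coordinates are stored one row per atom ($N \times 3$). It is a sketch for checking symmetry properties, not the DiffCSP architecture as implemented.

```python
import numpy as np

H, K = 8, 4                                 # toy hidden size and number of frequencies
rng = np.random.default_rng(0)
W_m = 0.1 * rng.standard_normal((2 * H + 9 + 6 * K, H))   # stands in for phi_m
W_h = 0.1 * rng.standard_normal((2 * H, H))               # stands in for phi_h
W_L = 0.1 * rng.standard_normal((H, 9))                   # stands in for phi_L
W_F = 0.1 * rng.standard_normal((H, 3))                   # stands in for phi_F

def psi_ft(d):
    """Fourier features flattened to (..., 6K), sin/cos interleaved as in the formula."""
    ang = 2.0 * np.pi * d[..., None] * np.arange(K)          # (..., 3, K)
    feats = np.stack([np.sin(ang), np.cos(ang)], axis=-1)    # (..., 3, K, 2)
    return feats.reshape(*d.shape[:-1], 6 * K)

def message_passing_layer(h, F, L):
    """h: (N, H) features, F: (N, 3) fractional coords, L: (3, 3) lattice vectors as columns."""
    N = h.shape[0]
    d = F[None, :, :] - F[:, None, :]                        # d[i, j] = f_j - f_i
    gram = np.broadcast_to((L.T @ L).ravel(), (N, N, 9))     # rotation-invariant L^T L
    h_i = np.broadcast_to(h[:, None, :], (N, N, H))
    h_j = np.broadcast_to(h[None, :, :], (N, N, H))
    m_ij = np.tanh(np.concatenate([h_i, h_j, gram, psi_ft(d)], axis=-1) @ W_m)
    m_i = m_ij.sum(axis=1)                                   # sum over j = 1..N
    return h + np.tanh(np.concatenate([h, m_i], axis=-1) @ W_h)   # residual update

def output_heads(h, L):
    eps_L = L @ (h.mean(axis=0) @ W_L).reshape(3, 3)         # L * phi_L(mean_i h_i)
    eps_F = h @ W_F                                          # phi_F(h_i), one row per atom
    return eps_L, eps_F

# Usage: 4 atoms with random node features standing in for h^(0).
h0 = rng.standard_normal((4, H))
F = rng.random((4, 3))
L = np.diag([4.0, 5.0, 6.0])
eps_L, eps_F = output_heads(message_passing_layer(h0, F, L), L)
```

Replacing `L` by `Q @ L` for an orthogonal `Q` leaves `eps_F` unchanged and maps `eps_L` to `Q @ eps_L`. Shifting all fractional coordinates by a common vector (then wrapping) also leaves `eps_F` unchanged. Both properties follow because the layer sees the lattice only through $L^\top L$ and positions only through $\psi_{\text{FT}}(f_j - f_i)$.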
- Perov-5: 18,928 perovskite structures; each unit cell contains 5 atoms (ABX$_3$-like), forming a structured CSP benchmark.
- Carbon-24: 10,153 carbon structures; unit cells contain 6-24 atoms with all-carbon composition, useful for assessing one-to-many structure diversity.
- MP-20: 45,231 inorganic crystals (Materials Project subset) with at most 20 atoms per unit cell; widely used for crystal generation and CSP.
- MPTS-52 (Materials Project Time Split): 40,476 crystals with up to 52 atoms per cell; chronological split for temporal generalization. A common split is 27,380 / 5,000 / 8,096 for train/val/test.
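A chronological split like MPTS-52's can be sketched as sorting by a timestamp and cutting the sorted list. The `year` field and the helper name are hypothetical. The sketch also assumes that training structures precede validation structures, which in turn precede test structures.

```python
def chronological_split(records, n_train, n_val, key="year"):
    """Sort records by a time field, then cut into train / val / test.

    `key` names a hypothetical per-record timestamp field. After sorting,
    no validation or test record predates any training record.
    """
    ordered = sorted(records, key=lambda r: r[key])
    train = ordered[:n_train]
    val = ordered[n_train:n_train + n_val]
    test = ordered[n_train + n_val:]
    return train, val, test

# Toy usage; for MPTS-52 the counts would be n_train=27_380 and n_val=5_000 (8,096 left for test).
records = [{"id": i, "year": 2000 + i % 20} for i in range(100)]
train, val, test = chronological_split(records, n_train=60, n_val=20)
```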
Recommended data fields for each sample:
- `atom_types`: length-$N$ atomic numbers or element indices
- `frac_coords`: $N \times 3$ fractional coordinates in $[0,1)$
- `lattice`: $3 \times 3$ lattice matrix

Optional fields include `material_id`, `spacegroup`, `energy`, and dataset split tags.
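Below is a hypothetical validation helper for records in this format. It checks only shapes, the $[0,1)$ range of fractional coordinates, and that the lattice is non-degenerate. The example record is cubic SrTiO$_3$, a 5-atom perovskite cell with $a = 3.905$ Å. Because the cell is cubic, the row/column convention of `lattice` does not matter here.

```python
import numpy as np

def validate_sample(sample: dict) -> None:
    """Raise ValueError if a record does not follow the recommended fields.

    Hypothetical helper: checks shapes, coordinate range and a non-degenerate lattice only.
    """
    atom_types = np.asarray(sample["atom_types"])
    if atom_types.ndim != 1:
        raise ValueError("atom_types must be a 1-D sequence of length N")
    n = atom_types.shape[0]
    frac = np.asarray(sample["frac_coords"], dtype=float)
    if frac.shape != (n, 3):
        raise ValueError(f"frac_coords must have shape ({n}, 3), got {frac.shape}")
    if np.any(frac < 0.0) or np.any(frac >= 1.0):
        raise ValueError("frac_coords must lie in [0, 1)")
    lattice = np.asarray(sample["lattice"], dtype=float)
    if lattice.shape != (3, 3):
        raise ValueError(f"lattice must have shape (3, 3), got {lattice.shape}")
    if abs(np.linalg.det(lattice)) < 1e-8:
        raise ValueError("lattice vectors are (nearly) linearly dependent")

# Cubic SrTiO3 perovskite (a = 3.905 Angstrom), 5 atoms per cell.
sample = {
    "atom_types": [38, 22, 8, 8, 8],          # Sr, Ti, O, O, O
    "frac_coords": [[0.0, 0.0, 0.0],
                    [0.5, 0.5, 0.5],
                    [0.5, 0.5, 0.0],
                    [0.5, 0.0, 0.5],
                    [0.0, 0.5, 0.5]],
    "lattice": (3.905 * np.eye(3)).tolist(),
    "spacegroup": 221,                         # optional field (Pm-3m)
}
validate_sample(sample)   # passes silently
```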
| Dataset | Train | Val | Test |
|---|---|---|---|
| MP-20 | 27136 | 9047 | 9046 |
| Model | Dataset | Match Rate (%) | RMS Dist | GPUs | Training Time | Config | Checkpoint / Log |
|---|---|---|---|---|---|---|---|
| diffcsp_mp20 | mp20 | 51.72 | 0.0591 | 1 | ~13.5 hours | diffcsp_mp20.yaml | checkpoint / log |
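To our understanding, DiffCSP's evaluation follows CDVAE. Each predicted structure is compared with its ground truth using pymatgen's `StructureMatcher` (`stol=0.5`, `angle_tol=10`, `ltol=0.3`). Match Rate is the percentage of matched test structures, and RMS Dist is the RMS displacement, normalized by $\sqrt[3]{V/N}$ and averaged over matched pairs. The sketch below covers only the aggregation step. It assumes the per-structure distances were already computed, for example with `StructureMatcher.get_rms_dist`, which returns `None` when no match is found. `csp_metrics` is a hypothetical helper name.

```python
import math
from typing import Optional, Sequence, Tuple

def csp_metrics(rms_dists: Sequence[Optional[float]]) -> Tuple[float, float]:
    """Aggregate per-structure results into (match rate in %, mean RMS dist).

    rms_dists[i] is the normalized RMS displacement between predicted and
    reference structure i, or None when the two structures do not match.
    """
    if not rms_dists:
        raise ValueError("need at least one test structure")
    matched = [r for r in rms_dists if r is not None]
    match_rate = 100.0 * len(matched) / len(rms_dists)
    mean_rms = sum(matched) / len(matched) if matched else math.nan
    return match_rate, mean_rms

# Toy example: 4 test structures, 2 matched.
rate, rms = csp_metrics([0.05, None, 0.07, None])
print(rate)   # 50.0
```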
```bash
# multi-gpu training (example with 4 GPUs)
python -m paddle.distributed.launch --gpus="0,1,2,3" structure_generation/train.py -c structure_generation/configs/diffcsp/diffcsp_mp20.yaml
# single-gpu training
python structure_generation/train.py -c structure_generation/configs/diffcsp/diffcsp_mp20.yaml
```

```bash
# Adjust program behavior on the fly using command-line parameters without modifying the configuration file directly.
# Example: Global.do_eval=True
python structure_generation/train.py -c structure_generation/configs/diffcsp/diffcsp_mp20.yaml Global.do_eval=True Global.do_train=False Global.do_test=False Trainer.pretrained_model_path='path/to/model.pdparams'
```

```bash
# Evaluate the model on the test dataset.
python structure_generation/train.py -c structure_generation/configs/diffcsp/diffcsp_mp20.yaml Global.do_eval=False Global.do_train=False Global.do_test=True Trainer.pretrained_model_path='path/to/model.pdparams'
```

```bash
# Predict crystal structures using a trained model.
# Mode 1: Use a pre-trained model (downloads automatically).
# Mode 2: Use a custom configuration file and checkpoint.
# Results are saved to the folder specified by --save_path (default: result).

# Mode 1: pre-trained model
python structure_generation/sample.py --model_name='diffcsp_mp20' --weights_name='latest.pdparams' --save_path='result_diffcsp_mp20/' --chemical_formula='LiMnO2'

# Mode 2: custom config + checkpoint
python structure_generation/sample.py --config_path='structure_generation/configs/diffcsp/diffcsp_mp20.yaml' --checkpoint_path='./output/diffcsp_mp20/checkpoints/latest.pdparams' --save_path='result_diffcsp_mp20/' --chemical_formula='LiMnO2'
```

```bibtex
@article{jiao2023crystal,
  title={Crystal structure prediction by joint equivariant diffusion},
  author={Jiao, Rui and Huang, Wenbing and Lin, Peijia and Han, Jiaqi and Chen, Pin and Lu, Yutong and Liu, Yang},
  journal={arXiv preprint arXiv:2309.04475},
  year={2023}
}
```
