🎇 FLARE: Fast Low-rank Attention Routing Engine

Abstract

The quadratic complexity of self-attention limits its applicability and scalability on large unstructured meshes. We introduce Fast Low-rank Attention Routing Engine (FLARE), a linear-complexity self-attention mechanism that routes attention through fixed-length latent sequences. Each attention head performs global communication among $N$ tokens by projecting the input sequence onto a fixed-length latent sequence of $M \ll N$ tokens using learnable query tokens. By routing attention through a bottleneck sequence, FLARE learns a low-rank form of attention that can be applied at $\mathcal{O}(NM)$ cost. FLARE not only scales to unprecedented problem sizes, but also delivers superior accuracy compared to state-of-the-art neural PDE surrogates across diverse benchmarks. We also release a new additive manufacturing dataset to spur further research. Our code is available at https://github.com/vpuri3/FLARE.py.

Highlights

  • Linear complexity token mixing. FLARE is an efficient self-attention mechanism designed to learn on long sequences such as point clouds. By replacing full self-attention with low-rank projections and reconstructions, FLARE achieves linear complexity in the number of points.
  • Superior accuracy. Across multiple PDE benchmarks, FLARE achieves superior predictive accuracy compared to leading neural surrogate models, despite operating with fewer parameters, and at much lower computational complexity.
  • Scalability. FLARE is built entirely from standard fused attention primitives, ensuring high GPU utilization and ease of integration into existing transformer architectures. As such, FLARE enables end-to-end training on unstructured meshes with one million points (see Figure below) without distributed computing (Luo et al., 2025) or memory offloading – the largest scale demonstrated for transformer-based PDE surrogates.

FLARE

FLARE is a simple yet powerful mechanism designed to break the scalability barrier in PDE surrogate learning. FLARE is built on the argument that projecting input sequences onto shorter latent sequences, and then unprojecting to the original sequence length, is equivalent to constructing a low-rank form of attention with rank at most equal to the number of latent tokens (see figure below).

Furthermore, we argue that multiple simultaneous low-rank projections can collectively capture a full attention pattern. Unlike Transolver, which shares projection weights across heads, or LNO, which applies only a single projection, our design allocates a distinct slice of the latent tokens to each head, resulting in a distinct projection matrix per head. This allows each head to learn independent attention relationships, opening up a key direction of scaling and exploration in which each head may specialize in distinct routing patterns.

FLARE Architecture
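
The rank bound can be checked directly. The following is a minimal sketch (with illustrative shapes, not code from this repository) showing that composing a projection onto $M$ latent tokens with the corresponding unprojection yields an effective $N \times N$ mixing matrix whose rank is at most $M$:

import torch

N, M, D = 256, 16, 32                    # sequence length, latent tokens, head dimension
q = torch.randn(M, D)                    # learnable latent queries for a single head
k = torch.randn(N, D)                    # keys derived from the input sequence

proj   = torch.softmax(q @ k.T, dim=-1)  # [M, N]: latent tokens gather from the inputs
unproj = torch.softmax(k @ q.T, dim=-1)  # [N, M]: inputs gather from the latent tokens
mix    = unproj @ proj                   # [N, N]: effective token-mixing matrix

print(torch.linalg.matrix_rank(mix))     # at most M = 16, far below N = 256

Since each head in FLARE carries its own latent queries, each head realizes a different low-rank mixing matrix.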

FLARE exhibits excellent scaling and can tackle problems with millions of tokens on a single GPU. The figure below reports the time and memory requirements of different attention schemes. On an input sequence of one million tokens, FLARE (red) is over $200\times$ faster than vanilla attention, while consuming marginally more memory. All models are implemented with flash attention (Dao et al., 2022), and the memory upper bound of a single H100 80GB GPU is depicted with a dashed line. Note that the curves for FLARE partially overlap.

FLARE scaling

The implementation of FLARE is straightforward and employs highly optimized fused self-attention kernels.

import torch.nn.functional as F

def flare_multihead_mixer(q, k, v):
    """
    Arguments:
        q: learnable latent query tensor [H, M, D]
        k: key tensor [B, H, N, D]
        v: value tensor [B, H, N, D]
    Returns:
        y: output tensor [B, H, N, D]
    """
    # Projection: the M latent queries attend to the N input tokens (N -> M).
    z = F.scaled_dot_product_attention(q, k, v, scale=1.0)
    # Unprojection: the N input tokens attend back to the M latent tokens (M -> N).
    y = F.scaled_dot_product_attention(k, q, z, scale=1.0)
    return y
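
A typical way to use the mixer in practice is to make the latent queries a learnable parameter, with each head owning its own slice. The wrapper below is a hypothetical sketch (class name and shapes are illustrative, not the repository's API), assuming flare_multihead_mixer from above is in scope:

import torch
import torch.nn as nn

class LatentMixer(nn.Module):
    """Hypothetical wrapper: each head owns its own learnable latent query tokens."""
    def __init__(self, num_heads, num_latents, head_dim):
        super().__init__()
        # Distinct latent queries per head -> distinct projection matrices per head.
        self.q = nn.Parameter(torch.randn(num_heads, num_latents, head_dim))

    def forward(self, k, v):  # k, v: [B, H, N, D]
        # q is shared across the batch, matching the [H, M, D] interface documented above.
        return flare_multihead_mixer(self.q, k, v)

B, H, N, D = 2, 8, 4096, 64
k, v = torch.randn(B, H, N, D), torch.randn(B, H, N, D)
y = LatentMixer(H, num_latents=32, head_dim=D)(k, v)   # y: [B, H, N, D]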

Benchmark Dataset of Additive Manufacturing (AM) Simulations

We simulate the laser powder bed fusion (LPBF) process on selected geometries from the Autodesk segmentation dataset (Lambourne et al., 2021) to generate a benchmark dataset for AM calculations. Several geometries are shown in the gallery below; the color indicates the Z (vertical) displacement field.

AM geometry gallery

🏗️ Codebase Architecture

This codebase implements the FLARE architecture and is built on the mlutils.py framework, which provides foundational ML training infrastructure with multi-GPU support, extensible trainer classes, and a callback system.

The project is organized into several key packages:

pdebench/ - Main PDE Benchmarking Framework

  • Models: Implementation of FLARE alongside state-of-the-art neural PDE surrogates
    • flare.py: Core FLARE architecture with linear complexity attention
    • transolver.py: Transolver baseline model
    • lno.py: Latent Neural Operator
    • transformer.py: Standard transformer architectures
    • gnot.py: Geometry-aware Neural Operator
    • perceiver.py: PerceiverIO architecture
  • Datasets: Comprehensive PDE dataset loading and preprocessing
    • utils.py: Dataset utilities and transformations
  • Callbacks: Training monitoring, evaluation, and visualization

am/ - Additive Manufacturing Specialization

  • Models: Specialized architectures for AM simulations
    • meshGNN.py: Graph neural networks for mesh data
  • Datasets: LPBF (Laser Powder Bed Fusion) data processing
    • sdf.py: Signed distance function utilities
    • extraction.py: Feature extraction from AM simulations
    • filtering.py: Data filtering and preprocessing
  • Visualization: 3D visualization tools for AM geometries

mlutils/ - Core ML Framework (from mlutils.py)

  • trainer.py: Distributed training with checkpointing and restart capabilities
  • callbacks.py: Extensible callback system for monitoring and analysis
  • utils.py: General ML utilities and helper functions

ablation/ - Performance Analysis Suite

  • Scaling experiments: scale_dml.py, time_memory_*.py
  • Architecture ablations: ablate_num_heads.py, ablate_num_layers.py, ablate_num_blocks.py
  • Memory and timing benchmarks with Flash Attention comparisons

🚀 Key Features

Scalable Training Infrastructure

  • Multi-GPU/multi-node training with torchrun
  • Automatic checkpointing and restart capabilities
  • Mixed precision training (FP16/FP32)
  • Comprehensive logging and monitoring
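
For example, these features combine into a single command roughly as follows (a hypothetical invocation built from the flags listed under --help in the Usage section; whether --restart resumes from the existing exp_name directory is an assumption):

uv run torchrun --nproc-per-node 2 -m pdebench --train true --restart true --mixed_precision true --dataset elasticity --exp_name flare_elas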

Flexible Model Zoo

  • FLARE and many of the state-of-the-art neural PDE surrogates
  • Modular architecture for easy experimentation

💻 Installation

Clone the repository and run the installation script:

git clone https://github.com/vpuri3/FLARE.py.git
cd FLARE.py
chmod +x scripts/install.sh
./scripts/install.sh

The installer will:

  • Set up Python 3.11 virtual environment with uv
  • Install PyTorch with CUDA support
  • Install all required dependencies
  • Optionally install Flash Attention for optimal performance
  • Optionally install LaTeX for publication-quality plots

📊 Datasets

This codebase supports a variety of PDE datasets. You can download them with the provided download script:

git clone https://github.com/vpuri3/FLARE.py.git
cd FLARE.py
chmod +x scripts/download_data.sh
./scripts/download_data.sh

🎯 Usage

Training

Single GPU training:

uv run python -m pdebench --train true --dataset elasticity --exp_name flare_elas --model_type 2 --epochs 100 ...

Multi-GPU training:

uv run torchrun --nproc-per-node 2 -m pdebench --train true --dataset darcy --exp_name flare_darcy --model_type 2 --epochs 100 ...

Training hyperparameters can be modified with the following command-line arguments:

$ uv run python -m pdebench --help
usage: __main__.py [-h] [--config CONFIG] [--print_config[=flags]] [--train {true,false}]
                   [--evaluate {true,false}] [--restart {true,false}] [--exp_name EXP_NAME]
                   [--seed SEED] [--dataset DATASET] [--num_workers NUM_WORKERS] [--epochs EPOCHS]
                   [--batch_size BATCH_SIZE] [--weight_decay WEIGHT_DECAY]
                   [--learning_rate LEARNING_RATE] [--schedule SCHEDULE]
                   [--one_cycle_pct_start ONE_CYCLE_PCT_START]
                   [--one_cycle_div_factor ONE_CYCLE_DIV_FACTOR]
                   [--one_cycle_final_div_factor ONE_CYCLE_FINAL_DIV_FACTOR]
                   [--one_cycle_three_phase {true,false}] [--opt_beta1 OPT_BETA1]
                   [--opt_beta2 OPT_BETA2] [--opt_eps OPT_EPS] [--clip_grad_norm CLIP_GRAD_NORM]
                   [--optimizer OPTIMIZER] [--mixed_precision {true,false}]
                   [--attn_backend ATTN_BACKEND] [--timing_only {true,false}] [--model_type MODEL_TYPE]
                   [--conv2d {true,false}] [--unified_pos {true,false}] [--act ACT]
                   [--channel_dim CHANNEL_DIM] [--num_blocks NUM_BLOCKS] [--num_heads NUM_HEADS]
                   [--num_latents NUM_LATENTS] [--num_layers_kv_proj NUM_LAYERS_KV_PROJ]
                   [--num_layers_mlp NUM_LAYERS_MLP] [--num_layers_in_out_proj NUM_LAYERS_IN_OUT_PROJ]
                   [--mlp_ratio MLP_RATIO] [--kv_proj_ratio KV_PROJ_RATIO]
                   [--in_out_proj_ratio IN_OUT_PROJ_RATIO] [--out_proj_ln {true,false}]

Each training run creates a directory in out/pdebench where it stores checkpoints, the configuration snapshot, and training diagnostics.

$ tree out/pdebench/ -L 2
out/pdebench
β”œβ”€β”€ flare_elas
β”‚   β”œβ”€β”€ ckpt01
β”‚   β”œβ”€β”€ ...
β”‚   β”œβ”€β”€ ckpt10
β”‚   β”œβ”€β”€ config.yaml
β”‚   β”œβ”€β”€ grad_norm.png
β”‚   β”œβ”€β”€ learning_rate.png
β”‚   β”œβ”€β”€ losses.png
β”‚   β”œβ”€β”€ rel_error.json
β”‚   └── model_stats.json
└── flare_darcy
    β”œβ”€β”€ ckpt01
    β”œβ”€β”€ ...
    β”œβ”€β”€ ckpt10
    β”œβ”€β”€ config.yaml
    β”œβ”€β”€ grad_norm.png
    β”œβ”€β”€ learning_rate.png
    β”œβ”€β”€ losses.png
    β”œβ”€β”€ rel_error.json
    └── model_stats.json

Evaluation

Load and evaluate a trained model:

uv run python -m pdebench --evaluate true --exp_name flare_elas

Configuration

All experiments are managed through YAML configuration files with comprehensive command-line override support. Results are automatically organized in the out/ directory with:

  • Model checkpoints
  • Training logs and metrics
  • Evaluation results and visualizations
  • Configuration snapshots
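
For example, a saved snapshot can be reused and selectively overridden from the command line (a sketch assuming the --config flag accepts a previously written config.yaml, as the --config and --print_config options in the help output suggest; the path is illustrative):

uv run python -m pdebench --config out/pdebench/flare_elas/config.yaml --learning_rate 1e-4 --epochs 200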

📊 Dataset Details

PDE Benchmarks

  • Supports multiple standard PDE benchmark datasets
  • Scalable data loading for large mesh datasets
  • Flexible preprocessing and augmentation pipelines

Additive Manufacturing Dataset

  • New benchmark dataset with LPBF simulations
  • Generated on Autodesk segmentation geometries
  • Includes displacement fields and thermal histories

🔄 Reproducibility

We are committed to ensuring the reproducibility of our research results. Our main results can be reproduced by running the script:

chmod +x ./out/pdebench/run_comp.sh
./out/pdebench/run_comp.sh

🔬 Research Applications

  • Neural PDE Surrogates: Fast approximation of expensive PDE solvers
  • Point Cloud Processing: Large-scale geometric deep learning
  • Scientific Computing: Scalable transformer architectures for irregular data

BibTeX

@misc{puri2025flare,
      title={{FLARE}: {F}ast {L}ow-rank {A}ttention {R}outing {E}ngine}, 
      author={Vedant Puri and Aditya Joglekar and Kevin Ferguson and Yu-hsuan Chen and Yongjie Jessica Zhang and Levent Burak Kara},
      year={2025},
      eprint={2508.12594},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2508.12594}, 
}
