
High Memory Consumption (~80GB) with Batch Size=1 on SOC Dataset: Is this expected? #76

@WendyYW

Hello,

I am using this project to train on an SOC (spin-orbit coupling) dataset and have run into unexpectedly high GPU memory (VRAM) usage.

Current Behavior

I am running the training on a node with 8x NVIDIA H100 (80GB) GPUs. Even with batch_size=1, GPU memory usage is close to the limit, at approximately 80,305 MiB per GPU.

Environment:

Hardware: NVIDIA H100 80GB HBM3 (x8)

CUDA Version: 12.4

Dataset: SOC Data

Questions

Expected Usage: Is this memory usage expected? Given the complexity of SOC data, is it normal for a single sample to consume ~80 GB of VRAM?

Config Check: Could you please check my configuration (included below)? Are there any specific parameters (e.g., related to layer size or precision) that I should adjust to reduce the memory footprint?

Optimization: Any suggestions on how to optimize this to allow for a larger batch size?
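A rough back-of-envelope sketch of why a single SOC sample can produce a large graph. This is not HamGNN's own memory accounting. Two values from the config below scale the problem: `cutoff: 26.0` sets how many neighbours (edges) each atom gets, which for a roughly homogeneous solid grows like r³, and `soc_switch: True` doubles the orbital basis. The atomic density (0.05 atoms/Å³) and the 8 Å comparison cutoff are hypothetical placeholders, and the per-pair block layout, (2·nao_max)×(2·nao_max) complex with SOC, is an assumption about how the SOC Hamiltonian is stored.

```python
import math


def expected_neighbours(cutoff: float, density: float) -> float:
    """Mean number of atoms inside a sphere of radius `cutoff` (angstrom)
    for a homogeneous number density `density` (atoms per cubic angstrom)."""
    return density * (4.0 / 3.0) * math.pi * cutoff ** 3


def pair_block_reals(nao_max: int, soc: bool) -> int:
    """Real-valued entries in one atom-pair Hamiltonian block.

    ASSUMPTION: with SOC the block is (2*nao_max) x (2*nao_max) complex
    (spinor basis); without SOC it is nao_max x nao_max real.
    """
    if soc:
        n = 2 * nao_max
        return 2 * n * n  # real and imaginary parts
    return nao_max * nao_max


DENSITY = 0.05        # HYPOTHETICAL atoms / A^3; replace with your structures' value
CUTOFF_CONFIG = 26.0  # `cutoff` from representation_nets.HamGNN_pre
CUTOFF_SMALL = 8.0    # HYPOTHETICAL smaller cutoff, for comparison only

n_big = expected_neighbours(CUTOFF_CONFIG, DENSITY)
n_small = expected_neighbours(CUTOFF_SMALL, DENSITY)
cutoff_ratio = n_big / n_small  # equals (26/8)^3; independent of DENSITY

print(f"neighbours per atom at {CUTOFF_CONFIG} A: {n_big:.0f}")
print(f"neighbours per atom at {CUTOFF_SMALL} A: {n_small:.0f}")
print(f"edge-count ratio: {cutoff_ratio:.1f}")
print("reals per pair block (SOC, no SOC):",
      pair_block_reals(26, soc=True), pair_block_reals(26, soc=False))
```

Under these assumptions the edge count and the per-edge block size multiply, which is one reason `cutoff` and the SOC settings are natural first places to look.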

Thank you for your help!

Configuration

Below is the configuration file I am using. I suspect some of these parameters may be contributing to the high memory usage.

dataset_params:
  batch_size: 1
  split_file: null
  test_ratio: 0.1
  train_ratio: 0.8
  val_ratio: 0.1
  graph_data_path: ./test_dir
  num_workers: 16
losses_metrics:
  losses:
  - loss_weight: 27.211
    metric: mae
    prediction: hamiltonian
    target: hamiltonian
  metrics:
  - metric: mae
    prediction: hamiltonian
    target: hamiltonian
optim_params:
  lr: 0.01
  lr_decay: 0.5
  lr_patience: 5
  gradient_clip_val: 0.0
  max_epochs: 6000
  min_epochs: 100
  stop_patience: 30

output_nets:
  output_module: HamGNN_out
  HamGNN_out:
    ham_only: true # true: Only the Hamiltonian H is computed; 'false': Fit both H and S
    ham_type: openmx # openmx: fit openmx Hamiltonian; abacus: fit abacus Hamiltonian
    nao_max: 26 # The maximum number of atomic orbitals in the data set, which can be 14, 19 or 27
    add_H0: True # Generally true, the complete Hamiltonian is predicted as the sum of H_scf plus H_nonscf (H0)
    symmetrize: True # if set to true, the Hermitian symmetry constraint is imposed on the Hamiltonian
    calculate_band_energy: False # Whether to calculate the energy bands to train the model
    num_k: 5 # When calculating the energy bands, the number of K points to use
    band_num_control: 8 # `dict`: controls how many orbitals are considered for each atom in energy bands; `int`: [vbm-num, vbm+num]; `null`: all bands
    k_path: null # `auto`: Automatically determine the k-point path; `null`: random k-point path; `list`: list of k-point paths provided by the user
    soc_switch: True # if true, fit the SOC Hamiltonian
    nonlinearity_type: gate # norm or gate
    # spin constrained
    spin_constrained: False
    collinear_spin: False
    minMagneticMoment: 0.5

profiler_params:
  progress_bar_refresh_rat: 1
  train_dir: ./train_output_1w #The folder for saving training information and prediction results. This directory can be read by tensorboard to monitor the training process.

representation_nets:
  # Network parameters usually do not need to be changed.
  HamGNN_pre:
    cutoff: 26.0
    cutoff_func: cos
    edge_sh_normalization: component
    edge_sh_normalize: true
    ######## Irreps set 1 (crystal): ################
    irreps_edge_sh: 0e + 1o + 2e + 3o + 4e + 5o
    irreps_node_features: 64x0e+64x0o+32x1o+16x1e+12x2o+25x2e+18x3o+9x3e+4x4o+9x4e+4x5o+4x5e+2x6e
    num_layers: 3
    num_radial: 64
    num_types: 96
    rbf_func: bessel
    set_features: true
    radial_MLP: [64, 64]
    use_corr_prod: False
    correlation: 2
    num_hidden_features: 16
    use_kan: False
    radius_scale: 1.01
    build_internal_graph: False

setup:
  GNN_Net: HamGNNpre
  accelerator: null
  ignore_warnings: true
  checkpoint_path: null # Path to the model weights file
  load_from_checkpoint: False
  resume: False
  num_gpus: 8 # null: use cpu; [i]: use the ith GPU device
  precision: 32
  property: hamiltonian
  stage: fit # fit: training; test: inference
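To put the "layer size" question into numbers, the irreps strings under `representation_nets` can be converted to feature dimensions. This is a standalone sketch that assumes the usual e3nn notation (`<mult>x<l><parity>`, each irrep of degree l having 2l+1 components); it is not e3nn's own parser, although `e3nn.o3.Irreps(s).dim` should report the same totals if e3nn is installed.

```python
import re


def irreps_dim(irreps: str) -> int:
    """Total dimension of an e3nn-style irreps string such as '64x0e+32x1o'.

    Each irrep of degree l contributes 2l+1 components; a term without a
    multiplicity prefix (e.g. '1o') counts once.
    """
    total = 0
    for term in irreps.replace(" ", "").split("+"):
        m = re.fullmatch(r"(?:(\d+)x)?(\d+)([eo])", term)
        if m is None:
            raise ValueError(f"cannot parse irrep term: {term!r}")
        mult = int(m.group(1)) if m.group(1) else 1
        degree = int(m.group(2))
        total += mult * (2 * degree + 1)
    return total


# Strings copied from representation_nets.HamGNN_pre above.
NODE = "64x0e+64x0o+32x1o+16x1e+12x2o+25x2e+18x3o+9x3e+4x4o+9x4e+4x5o+4x5e+2x6e"
EDGE_SH = "0e + 1o + 2e + 3o + 4e + 5o"

print("irreps_node_features dim:", irreps_dim(NODE))
print("irreps_edge_sh dim:", irreps_dim(EDGE_SH))
```

For this config the node features come to 877 components per atom (up to l = 6, plausibly because coupling two f-orbital blocks, l = 3 with l = 3, yields degrees up to 6) and the edge spherical harmonics to 36 components per edge.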
