
Small equivariant example #5

Open
DeNeutoy opened this issue Sep 12, 2023 · 11 comments

@DeNeutoy

Hi @yilunliao,

Thanks for the nice codebase - I am adapting it for another purpose, and I was running into some issues when checking that the outputs are actually equivariant. Are there any init flags that must be set in a certain way to guarantee equivariance?

I have a snippet equivalent to this:

import torch
from e3nn import o3
from torch_geometric.data import Data
from nets.equiformer_v2.equiformer_v2_oc20 import EquiformerV2_OC20

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
pos = torch.randn(10, 3)
data = Data(pos=pos, edge_index=edge_index)

R = o3.rand_matrix()  # random rotation matrix (already a torch.Tensor)

model = EquiformerV2_OC20(
        num_layers=2,
        attn_hidden_channels=16,
        ffn_hidden_channels=16,
        sphere_channels=16,
        edge_channels=16,
        alpha_drop=0.0,  # turn off dropout for equivariance
        drop_path_rate=0.0,  # turn off drop path for equivariance
    )
model.eval()  # model set to eval, as noted below

energy1, forces1 = model(data)
rotated_pos = torch.matmul(pos, R)
data.pos = rotated_pos
energy2, forces2 = model(data)

assert energy1 == energy2
assert torch.allclose(forces1, torch.matmul(forces2, R), atol=1.0e-3)

and the energies are equal, but the forces do not transform correctly under rotation. I've turned off all dropout and set the model to eval - just wondering if there are any other tricks needed to retain genuinely equivariant behaviour. Thanks!

@yilunliao
Member

Hi @DeNeutoy

Can you let me know how large the difference is?
After scanning the code, can you also make sure you compare this:

 assert torch.allclose(torch.matmul(forces1, R), forces2, atol=1.0e-3)

not this:

 assert torch.allclose(forces1, torch.matmul(forces2, R), atol=1.0e-3)

I think this is because forces2 is computed from rotated input positions, so you have to rotate forces2 back (or rotate forces1 forward) before comparing the two force outputs.
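
To illustrate the bookkeeping with a toy, trivially equivariant "force field" (toy_forces below is made up for this check and has nothing to do with EquiformerV2):

import torch
from e3nn import o3

# F_i = sum_j (x_j - x_i): equivariant by construction.
def toy_forces(pos):
    return pos.sum(dim=0, keepdim=True) - pos.shape[0] * pos

pos = torch.randn(10, 3)
R = o3.rand_matrix()

f1 = toy_forces(pos)
f2 = toy_forces(pos @ R)  # input right-multiplied by R, as in your snippet

# Hence f2 == f1 @ R: rotate f1 forward (or f2 back) before comparing.
assert torch.allclose(f1 @ R, f2, atol=1e-6)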

Best

@DeNeutoy
Author

DeNeutoy commented Sep 12, 2023

Thanks for the response @yilunliao - sorry, I was rotating forces1 and comparing as you suggested; the bug was in my snippet, not my actual code. I dug into this a little, and when I look at the nodes which are not correct, I see this:

print((forces1 == forces2).all(-1))  # forces have shape (N, 3)
tensor([False, False, False,  True,  True,  True,  True,  True,  True,  True])

indicating that only nodes which have edges are affected. I then confirmed this by modifying the input graph; only changing the receivers affects this, e.g.:

senders = torch.tensor([0, 1, 2, 1, 2, 0])
receivers = torch.tensor([1, 0, 1, 2, 0, 4])
# Rerun, get (note node 4, the new receiver, is now False):
tensor([False, False, False,  True, False,  True,  True,  True,  True,  True])

but changing the senders doesn't:

senders = torch.tensor([0, 1, 2, 1, 2, 4])
receivers = torch.tensor([1, 0, 1, 2, 0, 0])
# Rerun, get:
tensor([False, False, False,  True,  True,  True,  True,  True,  True,  True])

I was wondering if this might alert you to something? I then started stepping through the code: the embeddings are equal up until the edge-degree embedding, but if I remove it, they still end up unequal after the TransBlockV2 stack.

For the node indices which don't match, the absolute difference is large:

print(torch.abs(forces1 - forces2)[:3].mean())
1.0051

This is with a completely untrained model, although I wouldn't expect that to make a difference.

Any help is much appreciated!

@yilunliao
Member

@DeNeutoy

I see.

Can you make sure the results of the edge-degree embeddings satisfy the equivariance constraint?
Since this embedding only uses rotation and applies linear layers to the m = 0 components, it is strictly equivariant and can be easily tested.
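
In the meantime, here is a generic check along those lines - a minimal sketch using e3nn's spherical harmonics, which the edge-degree embedding builds on, rather than the embedding module itself:

import torch
from e3nn import o3

lmax = 4
irreps_sh = o3.Irreps.spherical_harmonics(lmax)

edge_vec = torch.randn(16, 3)
R = o3.rand_matrix()
D = irreps_sh.D_from_matrix(R)  # block-diagonal Wigner-D up to lmax

sh1 = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True)
sh2 = o3.spherical_harmonics(irreps_sh, edge_vec @ R.T, normalize=True)

# Equivariance: Y(R x) == Y(x) @ D^T
assert torch.allclose(sh2, sh1 @ D.T, atol=1e-5)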

I have limited bandwidth until the weekend or next week, but I will look into this and provide another example. (Ping me if you have not heard from me.)

@yilunliao
Member

@DeNeutoy

Sorry for the late reply.

Have you figured out the issue?
If not, can you please update me on what you think the problem is?

I have an upcoming deadline, so I may be slow to respond, but I will make sure we find the reason.

Best

@DeNeutoy
Author

Hi @yilunliao ,

I haven't, unfortunately. I tried looking into the edge-degree embeddings, but it's not as simple as checking a rotation of the input vectors - the edge-degree embedding outputs an SO3_Embedding object, which internally has a _rotate method defined by the SO3_Rotation objects and Wigner matrices set up in the model's forward pass. So it was unclear to me how to "unrotate" the embeddings.

If you had a small example, that would be helpful - but I understand if this is difficult to produce. These things are quite complex!

@yilunliao
Member

Hi @DeNeutoy .

Here is how we rotate the embedding back to the original coordinate frame after the SO(2) linear layers:
https://github.com/atomicarchitects/equiformer_v2/blob/main/nets/equiformer_v2/so3.py#L452
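
For reference, a minimal sketch of what "rotating back" means at the irreps level, using e3nn's Wigner-D matrices directly rather than the SO3_Rotation machinery (the irreps below are illustrative):

import torch
from e3nn import o3

# Features carrying a few irreps, analogous to a node embedding.
irreps = o3.Irreps("2x0e + 2x1o + 1x2e")
x = torch.randn(5, irreps.dim)

R = o3.rand_matrix()
D = irreps.D_from_matrix(R)  # block-diagonal Wigner-D matrix

x_rot = x @ D.T     # embedding expressed in the rotated frame
x_back = x_rot @ D  # D is orthogonal, so D^T is its inverse

assert torch.allclose(x_back, x, atol=1e-6)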

Sure, I can provide a simple example to test that, but I will do it next weekend due to an upcoming deadline.

Best

@BurgerAndreas

Hi @yilunliao,

A small example would be extremely useful for building on your codebase!
Could you help us out?

@pavlo-melnyk

Hello @yilunliao,

Thank you for your great work on Equiformers and the codebases.

When testing the equivariance and invariance under O(3) actions for forces and energy predictions, respectively, I observed inconsistencies with EquiformerV2_OC20.
Namely, the outputs of the network are not equivariant/invariant.

What is odd is that the equi-/invariance error seems to grow after (successful) training (on my custom dataset):

Before training

Invariance error:
(energy1 - energy2).abs() = 0.0019094645977020264

Equivariance error:
(forces1 - forces2 @ R.T).abs().max() =        0.040991928428411484

just in case:
(forces1 - forces2 @ R).abs().max() =        0.03977788984775543

After training (the errors vary depending on the seed):

Invariance error:
(energy1 - energy2).abs() = 4.39306640625

Equivariance error:
(forces1 - forces2 @ R.T).abs().max() =        2.744657278060913

just in case:
(forces1 - forces2 @ R).abs().max() =        2.7617990970611572

Here's a snippet of the code I use to test the equivariance for a random input to a randomly initialized Equiformer:

import torch

from equiformer_v2.nets.equiformer_v2.equiformer_v2_oc20 import EquiformerV2_OC20
from e3nn import o3
from types import SimpleNamespace
from pytorch_lightning import seed_everything


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

seed_everything(42, workers=True)

model_config = {
    "use_pbc": True, 
    "regress_forces": True,
    "otf_graph": True,  
    "max_neighbors": 20,  
    "max_radius": 12.0,  
    "max_num_elements": 23,  # this is what I use for the custom dataset with the highest atomic number being 22; I tried different values when inputting a random sample
    "num_layers": 8,
    "sphere_channels": 128,
    "attn_hidden_channels": 64,
    "num_heads": 8,
    "attn_alpha_channels": 64,
    "attn_value_channels": 16,
    "ffn_hidden_channels": 128,
    "norm_type": "layer_norm_sh",
    "lmax_list": [4],
    "mmax_list": [2],
    "grid_resolution": 18,
    "num_sphere_samples": 128,
    "edge_channels": 128,
    "use_atom_edge_embedding": True,
    "share_atom_edge_embedding": False,
    "distance_function": "gaussian",
    "num_distance_basis": 512,
    "attn_activation": "silu",
    "use_s2_act_attn": False,
    "use_attn_renorm": True,
    "ffn_activation": "silu",
    "use_gate_act": False,
    "use_grid_mlp": True,
    "use_sep_s2_act": True,
    "alpha_drop": 0.1,
    "drop_path_rate": 0.1,
    "proj_drop": 0.0,
    "weight_init": "uniform",
}
model = EquiformerV2_OC20(None, None, None, **model_config).to(device) # also tried with the standard params in the constructor, i.e., without model_config

N = 90 # num atoms

data = {
    "pos": torch.randn(N, 3),
    "pbc": torch.tensor([[True, True, True]]), 
    "atomic_numbers": torch.ones(N).long(), 
    "cell": torch.randn(1, 3, 3), 
    "natoms": torch.tensor([N]),
    "batch": torch.zeros(N).long(),
}

data = {k: v.to(device) for k, v in data.items()}

# convert the input to the right format:
data = SimpleNamespace(**data)

# equivariance test:
pos = data.pos
R = o3.rand_matrix().to(device)  # already a torch.Tensor
model.eval()
with torch.no_grad():
    energy1, forces1 = model(data)
    
    rotated_pos = torch.matmul(pos, R)
    data.pos = rotated_pos
    energy2, forces2 = model(data)

print(energy1, energy2)

print(f"\nInvariance error:\n(energy1 - energy2).abs() = {(energy1 - energy2).abs()}")

print(f"\nEquivariance error:\n(forces1 - forces2 @ R.T).abs().max() = \
       {(forces1 - torch.matmul(forces2, R.transpose(-1, -2)).detach() ).abs().max()}")

print(f"\njust in case:\n(forces1 - forces2 @ R).abs().max() = \
       {(forces1 - torch.matmul(forces2, R).detach() ).abs().max()}")

resulting in

Invariance error:
(energy1 - energy2).abs() = 0.0036049485206604004

Equivariance error:
(forces1 - forces2 @ R.T).abs().max() =        0.0455518402159214

just in case:
(forces1 - forces2 @ R).abs().max() =        0.03853069245815277

I understand that some numerical errors in equi-/invariance are expected, but the above examples, especially before vs. after training, seem significant.

I have tried different hyperparameters for the model but couldn't pinpoint the cause.

Could you clarify whether this behaviour is expected or whether there might be an issue in my setup?

Any guidance would be greatly appreciated.

Best,
Pavló

@yilunliao
Member

Hi @pavlo-melnyk

Thanks for your question.

At first glance, I am not 100% sure what the problem could be.

I think you also need to rotate cell in data, in addition to pos.
For simplicity, I would suggest turning off periodic boundary conditions with pbc = torch.tensor([[False, False, False]]).

Let me take a deeper look into this and get back to you later this week.

@pavlo-melnyk

pavlo-melnyk commented Feb 12, 2025

Hi @yilunliao,

Thank you for pointing this out. You're absolutely correct.
My bad: I totally forgot that the cell is also rotation-dependent.

I've now verified it:

# SO(3)-equivariance test:
pos = data.pos
cell = data.cell
R = o3.rand_matrix().to(device)  # det(R) = +1
model.eval()
with torch.no_grad():
    energy1, forces1 = model(data)
    
    rotated_pos = torch.matmul(pos, R)
    rotated_cell = torch.matmul(cell, R)
    data.pos = rotated_pos
    data.cell = rotated_cell

    energy2, forces2 = model(data)

For a random sample and a randomly initialized Equiformer, the output is

Invariance error:
(energy1 - energy2).abs() = 4.842877388000488e-08

Equivariance error:
(forces1 - forces2 @ R.T).abs().max() = 7.970724254846573e-06

and for a sample from my dataset and a trained Equiformer:

Invariance error:
(energy1 - energy2).abs() = 0.000152587890625

Equivariance error:
(forces1 - forces2 @ R.T).abs().max() =  0.0003859400749206543

which is reasonable.

Thanks a lot!

@yilunliao
Member

Hi @pavlo-melnyk

Great.

If you compare relative errors (e.g., torch.max(torch.abs((forces1 - forces2 @ R.T) / (forces1 + 1e-6)))), I guess the error should be negligible for the trained model.
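
For instance (a hypothetical helper wrapping the formula above; max_relative_error is not part of the codebase):

import torch

def max_relative_error(forces1, forces2, R, eps=1e-6):
    # Largest element-wise relative error between the reference forces
    # and the back-rotated forces, per the expression above.
    return torch.max(torch.abs((forces1 - forces2 @ R.T) / (forces1 + eps))).item()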
