From 88c13246a6f60ec4bc844dffd21c8159b9b983e7 Mon Sep 17 00:00:00 2001
From: Alberto Cattaneo <albertoc@graphcore.ai>
Date: Mon, 22 Jan 2024 16:32:45 +0000
Subject: [PATCH] change RotatE relation initialization

---
 besskge/embedding.py | 16 ++++++++++++++++
 besskge/scoring.py   |  3 ++-
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/besskge/embedding.py b/besskge/embedding.py
index 86244da..91ed36e 100644
--- a/besskge/embedding.py
+++ b/besskge/embedding.py
@@ -44,6 +44,22 @@ def init_xavier_norm(embedding_table: torch.Tensor, gain: float = 1.0) -> torch.
     )
 
 
+def init_uniform_rotation(embedding_table: torch.Tensor) -> torch.Tensor:
+    r"""
+    Initialize tensor with each entry being a uniformly distributed
+    phase between 0 and :math:`2 \pi`.
+    To be used for initialization of relation embedding tables
+    in RotatE scoring function.
+
+    :param embedding_table:
+        Tensor of embedding parameters to initialize.
+
+    :return:
+        Initialized tensor.
+    """
+    return torch.rand_like(embedding_table) * 2 * np.pi
+
+
 def init_KGE_uniform(
     embedding_table: torch.Tensor, b: float = 1.0, divide_by_embedding_size: bool = True
 ) -> torch.Tensor:
diff --git a/besskge/scoring.py b/besskge/scoring.py
index 7c0b985..063d4a7 100644
--- a/besskge/scoring.py
+++ b/besskge/scoring.py
@@ -15,6 +15,7 @@
     init_KGE_normal,
     init_KGE_uniform,
     init_uniform_norm,
+    init_uniform_rotation,
     init_xavier_norm,
     initialize_entity_embedding,
     initialize_relation_embedding,
@@ -369,7 +370,7 @@ def __init__(
             init_KGE_uniform
         ],
         relation_initializer: Union[torch.Tensor, List[Callable[..., torch.Tensor]]] = [
-            init_KGE_uniform
+            init_uniform_rotation
         ],
         inverse_relations: bool = False,
     ) -> None: