Skip to content

Commit 3ab63a7

Browse files
committed
cache atom types & bond types once computed
1 parent e9c02e1 commit 3ab63a7

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

torchdrug/data/dataset.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,30 @@ class KnowledgeGraphDataset(torch_data.Dataset, core.Configurable):
414414
The whole dataset contains one knowledge graph.
415415
"""
416416

417+
def load_triplet(self, triplets, entity_vocab=None, relation_vocab=None, inv_entity_vocab=None,
418+
inv_relation_vocab=None):
419+
"""
420+
Load the dataset from triplets.
421+
The mapping between indexes and tokens is specified through either vocabularies or inverse vocabularies.
422+
423+
Parameters:
424+
triplets (array_like): triplets of shape :math:`(n, 3)`
425+
entity_vocab (dict of str, optional): maps entity indexes to tokens
426+
relation_vocab (dict of str, optional): maps relation indexes to tokens
427+
inv_entity_vocab (dict of str, optional): maps tokens to entity indexes
428+
inv_relation_vocab (dict of str, optional): maps tokens to relation indexes
429+
"""
430+
entity_vocab, inv_entity_vocab = self._standarize_vocab(entity_vocab, inv_entity_vocab)
431+
relation_vocab, inv_relation_vocab = self._standarize_vocab(relation_vocab, inv_relation_vocab)
432+
433+
num_node = len(entity_vocab) if entity_vocab else None
434+
num_relation = len(relation_vocab) if relation_vocab else None
435+
self.graph = data.Graph(triplets, num_node=num_node, num_relation=num_relation)
436+
self.entity_vocab = entity_vocab
437+
self.relation_vocab = relation_vocab
438+
self.inv_entity_vocab = inv_entity_vocab
439+
self.inv_relation_vocab = inv_relation_vocab
440+
417441
def load_tsv(self, tsv_file, verbose=0):
418442
"""
419443
Load the dataset from a tsv file.
@@ -483,30 +507,6 @@ def load_tsvs(self, tsv_files, verbose=0):
483507
self.load_triplet(triplets, inv_entity_vocab=inv_entity_vocab, inv_relation_vocab=inv_relation_vocab)
484508
self.num_samples = num_samples
485509

486-
def load_triplet(self, triplets, entity_vocab=None, relation_vocab=None, inv_entity_vocab=None,
487-
inv_relation_vocab=None):
488-
"""
489-
Load the dataset from triplets.
490-
The mapping between indexes and tokens is specified through either vocabularies or inverse vocabularies.
491-
492-
Parameters:
493-
triplets (array_like): triplets of shape :math:`(n, 3)`
494-
entity_vocab (dict of str, optional): maps entity indexes to tokens
495-
relation_vocab (dict of str, optional): maps relation indexes to tokens
496-
inv_entity_vocab (dict of str, optional): maps tokens to entity indexes
497-
inv_relation_vocab (dict of str, optional): maps tokens to relation indexes
498-
"""
499-
entity_vocab, inv_entity_vocab = self._standarize_vocab(entity_vocab, inv_entity_vocab)
500-
relation_vocab, inv_relation_vocab = self._standarize_vocab(relation_vocab, inv_relation_vocab)
501-
502-
num_node = len(entity_vocab) if entity_vocab else None
503-
num_relation = len(relation_vocab) if relation_vocab else None
504-
self.graph = data.Graph(triplets, num_node=num_node, num_relation=num_relation)
505-
self.entity_vocab = entity_vocab
506-
self.relation_vocab = relation_vocab
507-
self.inv_entity_vocab = inv_entity_vocab
508-
self.inv_relation_vocab = inv_relation_vocab
509-
510510
def _standarize_vocab(self, vocab, inverse_vocab):
511511
if vocab is not None:
512512
if isinstance(vocab, dict):
@@ -609,7 +609,7 @@ def round_to_boundary(i):
609609

610610
if key_lengths is not None:
611611
assert lengths is None
612-
key2count = torch.bincount(keys)
612+
key2count = keys.bincount()
613613
key_offset = 0
614614
lengths = []
615615
for key_length in key_lengths:

0 commit comments

Comments
 (0)