@@ -414,6 +414,30 @@ class KnowledgeGraphDataset(torch_data.Dataset, core.Configurable):
414414 The whole dataset contains one knowledge graph.
415415 """
416416
417+ def load_triplet (self , triplets , entity_vocab = None , relation_vocab = None , inv_entity_vocab = None ,
418+ inv_relation_vocab = None ):
419+ """
420+ Load the dataset from triplets.
421+ The mapping between indexes and tokens is specified through either vocabularies or inverse vocabularies.
422+
423+ Parameters:
424+ triplets (array_like): triplets of shape :math:`(n, 3)`
425+ entity_vocab (dict of str, optional): maps entity indexes to tokens
426+ relation_vocab (dict of str, optional): maps relation indexes to tokens
427+ inv_entity_vocab (dict of str, optional): maps tokens to entity indexes
428+ inv_relation_vocab (dict of str, optional): maps tokens to relation indexes
429+ """
430+ entity_vocab , inv_entity_vocab = self ._standarize_vocab (entity_vocab , inv_entity_vocab )
431+ relation_vocab , inv_relation_vocab = self ._standarize_vocab (relation_vocab , inv_relation_vocab )
432+
433+ num_node = len (entity_vocab ) if entity_vocab else None
434+ num_relation = len (relation_vocab ) if relation_vocab else None
435+ self .graph = data .Graph (triplets , num_node = num_node , num_relation = num_relation )
436+ self .entity_vocab = entity_vocab
437+ self .relation_vocab = relation_vocab
438+ self .inv_entity_vocab = inv_entity_vocab
439+ self .inv_relation_vocab = inv_relation_vocab
440+
417441 def load_tsv (self , tsv_file , verbose = 0 ):
418442 """
419443 Load the dataset from a tsv file.
@@ -483,30 +507,6 @@ def load_tsvs(self, tsv_files, verbose=0):
483507 self .load_triplet (triplets , inv_entity_vocab = inv_entity_vocab , inv_relation_vocab = inv_relation_vocab )
484508 self .num_samples = num_samples
485509
486- def load_triplet (self , triplets , entity_vocab = None , relation_vocab = None , inv_entity_vocab = None ,
487- inv_relation_vocab = None ):
488- """
489- Load the dataset from triplets.
490- The mapping between indexes and tokens is specified through either vocabularies or inverse vocabularies.
491-
492- Parameters:
493- triplets (array_like): triplets of shape :math:`(n, 3)`
494- entity_vocab (dict of str, optional): maps entity indexes to tokens
495- relation_vocab (dict of str, optional): maps relation indexes to tokens
496- inv_entity_vocab (dict of str, optional): maps tokens to entity indexes
497- inv_relation_vocab (dict of str, optional): maps tokens to relation indexes
498- """
499- entity_vocab , inv_entity_vocab = self ._standarize_vocab (entity_vocab , inv_entity_vocab )
500- relation_vocab , inv_relation_vocab = self ._standarize_vocab (relation_vocab , inv_relation_vocab )
501-
502- num_node = len (entity_vocab ) if entity_vocab else None
503- num_relation = len (relation_vocab ) if relation_vocab else None
504- self .graph = data .Graph (triplets , num_node = num_node , num_relation = num_relation )
505- self .entity_vocab = entity_vocab
506- self .relation_vocab = relation_vocab
507- self .inv_entity_vocab = inv_entity_vocab
508- self .inv_relation_vocab = inv_relation_vocab
509-
510510 def _standarize_vocab (self , vocab , inverse_vocab ):
511511 if vocab is not None :
512512 if isinstance (vocab , dict ):
@@ -609,7 +609,7 @@ def round_to_boundary(i):
609609
610610 if key_lengths is not None :
611611 assert lengths is None
612- key2count = torch .bincount (keys )
612+ key2count = keys .bincount ()
613613 key_offset = 0
614614 lengths = []
615615 for key_length in key_lengths :
0 commit comments