diff --git a/graphsage/model.py b/graphsage/model.py index aeca282..1825b55 100644 --- a/graphsage/model.py +++ b/graphsage/model.py @@ -76,8 +76,8 @@ def run_cora(): agg2 = MeanAggregator(lambda nodes : enc1(nodes).t(), cuda=False) enc2 = Encoder(lambda nodes : enc1(nodes).t(), enc1.embed_dim, 128, adj_lists, agg2, base_model=enc1, gcn=True, cuda=False) - enc1.num_samples = 5 - enc2.num_samples = 5 + enc1.num_sample = 5 + enc2.num_sample = 5 graphsage = SupervisedGraphSage(7, enc2) # graphsage.cuda() @@ -148,8 +148,8 @@ def run_pubmed(): agg2 = MeanAggregator(lambda nodes : enc1(nodes).t(), cuda=False) enc2 = Encoder(lambda nodes : enc1(nodes).t(), enc1.embed_dim, 128, adj_lists, agg2, base_model=enc1, gcn=True, cuda=False) - enc1.num_samples = 10 - enc2.num_samples = 25 + enc1.num_sample = 10 + enc2.num_sample = 25 graphsage = SupervisedGraphSage(3, enc2) # graphsage.cuda()