diff --git a/node2vec_spark/pom.xml b/node2vec_spark/pom.xml
index b958576..92beef2 100644
--- a/node2vec_spark/pom.xml
+++ b/node2vec_spark/pom.xml
@@ -17,7 +17,7 @@
     <spark.version>2.4.3</spark.version>
     <config.version>1.4.0</config.version>
     <java.version>1.7</java.version>
-    <scala.version>2.10</scala.version>
+    <scala.version>2.11</scala.version>
@@ -124,6 +124,15 @@
      <artifactId>guava</artifactId>
      <version>19.0</version>
    </dependency>
+
+    <dependency>
+      <groupId>neo4j-contrib</groupId>
+      <artifactId>neo4j-spark-connector</artifactId>
+      <version>2.4.5-M1</version>
+      <scope>system</scope>
+      <systemPath>${project.basedir}/src/main/resources/neo4j-spark-connector-full-2.4.5-M1.jar</systemPath>
+    </dependency>
+
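Note: `system` scope puts the bundled connector jar on the compile classpath only; it is not packaged into the application jar, which is why run_command.sh below also hands the same jar to spark-submit via --jars.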
diff --git a/node2vec_spark/run_command.sh b/node2vec_spark/run_command.sh
new file mode 100755
index 0000000..f29414f
--- /dev/null
+++ b/node2vec_spark/run_command.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+
+if [ "$#" -lt 2 ]; then
+ echo "Illegal number of parameters"
+ echo "Usage: ./run_command.sh "
+ exit 1
+fi
+
+neo_query="$1"
+out_name="$2"
+
+. /root/env/bin/activate
+
+export JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.252.b09-2.el7_8.x86_64/jre
+
+
+DIR=$(dirname "$0")
+
+time spark-submit --num-executors=384 --driver-memory 12g --executor-memory 24g --jars /opt/sparkx/neo4j-spark-connector-full-2.4.5-M1.jar --conf spark.neo4j.user=neo4j --conf spark.neo4j.password=test --class com.navercorp.Main $DIR/target/node2vec-0.0.1-SNAPSHOT.jar --cmd neo2vec --dim 10 --p 100.0 --q 100.0 --walkLength 5 --output "$out_name" --input "kirokhayeh.bin" --neoQuery "$neo_query"
+
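Example invocation (the query and output name here are illustrative): `./run_command.sh 'MATCH (a)-[r]->(b) RETURN id(a) as source, id(b) as target, 1.0 as value' out.emb`. The first argument is forwarded to the connector as --neoQuery; the second becomes the --output path.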
diff --git a/node2vec_spark/src/main/resources/neo4j-spark-connector-2.4.5-M1-sources.jar b/node2vec_spark/src/main/resources/neo4j-spark-connector-2.4.5-M1-sources.jar
new file mode 100644
index 0000000..213a0f8
Binary files /dev/null and b/node2vec_spark/src/main/resources/neo4j-spark-connector-2.4.5-M1-sources.jar differ
diff --git a/node2vec_spark/src/main/resources/neo4j-spark-connector-2.4.5-M1.jar b/node2vec_spark/src/main/resources/neo4j-spark-connector-2.4.5-M1.jar
new file mode 100644
index 0000000..5c8d650
Binary files /dev/null and b/node2vec_spark/src/main/resources/neo4j-spark-connector-2.4.5-M1.jar differ
diff --git a/node2vec_spark/src/main/resources/neo4j-spark-connector-full-2.4.5-M1.jar b/node2vec_spark/src/main/resources/neo4j-spark-connector-full-2.4.5-M1.jar
new file mode 100644
index 0000000..def51fa
Binary files /dev/null and b/node2vec_spark/src/main/resources/neo4j-spark-connector-full-2.4.5-M1.jar differ
diff --git a/node2vec_spark/src/main/scala/com/navercorp/Main.scala b/node2vec_spark/src/main/scala/com/navercorp/Main.scala
index f3494e5..20f8f40 100644
--- a/node2vec_spark/src/main/scala/com/navercorp/Main.scala
+++ b/node2vec_spark/src/main/scala/com/navercorp/Main.scala
@@ -8,7 +8,7 @@ import com.navercorp.lib.AbstractParams
 object Main {
   object Command extends Enumeration {
     type Command = Value
-    val node2vec, randomwalk, embedding = Value
+    val node2vec, randomwalk, embedding, neo2vec = Value
   }

   import Command._
@@ -26,6 +26,7 @@ object Main {
                     degree: Int = 30,
                     indexed: Boolean = true,
                     nodePath: String = null,
+                    neoQuery: String = "MATCH (a)-[r]->(b) RETURN id(a) as source, id(b) as target, 1.0 as value",
                     input: String = null,
                     output: String = null,
                     cmd: Command = Command.node2vec) extends AbstractParams[Params] with Serializable
@@ -36,6 +37,9 @@ object Main {
opt[Int]("walkLength")
.text(s"walkLength: ${defaultParams.walkLength}")
.action((x, c) => c.copy(walkLength = x))
+ opt[Int]("dim")
+ .text(s"dim: ${defaultParams.dim}")
+ .action((x, c) => c.copy(dim = x))
opt[Int]("numWalks")
.text(s"numWalks: ${defaultParams.numWalks}")
.action((x, c) => c.copy(numWalks = x))
@@ -60,6 +64,9 @@ object Main {
opt[String]("nodePath")
.text("Input node2index file path: empty")
.action((x, c) => c.copy(nodePath = x))
+ opt[String]("neoQuery")
+ .text("Query for fetching graph from Neo4j")
+ .action((x, c) => c.copy(neoQuery = x))
opt[String]("input")
.required()
.text("Input edge file path: empty")
@@ -93,7 +100,7 @@ object Main {
     parser.parse(args, defaultParams).map { param =>
       val conf = new SparkConf().setAppName("Node2Vec")
       val context: SparkContext = new SparkContext(conf)
-
+
       Node2vec.setup(context, param)

       param.cmd match {
@@ -105,7 +112,12 @@
       case Command.randomwalk => Node2vec.load()
         .initTransitionProb()
         .randomWalk()
         .saveRandomPath()
+      case Command.neo2vec => Node2vec.loadNeo()
+        .initTransitionProb()
+        .randomWalk()
+        .embedding()
+        .save()
       case Command.embedding => {
         val randomPaths = Word2vec.setup(context, param).read(param.input)
         Word2vec.fit(randomPaths).save(param.output)
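Note that --input is still marked required, so a neo2vec run must pass a placeholder input path even though the graph is fetched from Neo4j; run_command.sh does this with --input "kirokhayeh.bin".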
diff --git a/node2vec_spark/src/main/scala/com/navercorp/Node2vec.scala b/node2vec_spark/src/main/scala/com/navercorp/Node2vec.scala
index 07ec21a..f0c6208 100644
--- a/node2vec_spark/src/main/scala/com/navercorp/Node2vec.scala
+++ b/node2vec_spark/src/main/scala/com/navercorp/Node2vec.scala
@@ -10,6 +10,8 @@
import org.apache.spark.graphx.{EdgeTriplet, Graph, _}
import com.navercorp.graph.{GraphOps, EdgeAttr, NodeAttr}
+import org.neo4j.spark._
+
object Node2vec extends Serializable {
lazy val logger: Logger = LoggerFactory.getLogger(getClass.getName);
@@ -28,6 +30,45 @@
     this
   }

+  def loadNeo(): this.type = {
+    val neo = Neo4j(context)
+    val graphQuery = config.neoQuery
+    println(s"Using graph query: $graphQuery")
+
+    // Load the edge list from Neo4j via the connector as a GraphX graph (edge attr = weight).
+    val graph: Graph[Long, Double] = neo.rels(graphQuery).partitions(10).batch(200).loadGraph
+
+    val bcMaxDegree = context.broadcast(config.degree)
+    val bcEdgeCreator = config.directed match {
+      case true => context.broadcast(GraphOps.createDirectedEdge)
+      case false => context.broadcast(GraphOps.createUndirectedEdge)
+    }
+
+    val inputEdges = graph.edges
+
+    // TODO optimize by removing the graph recomposition redundancy
+    // Rebuild per-node adjacency lists, keeping at most config.degree highest-weight neighbors.
+    indexedNodes = inputEdges.flatMap { e =>
+      bcEdgeCreator.value.apply(e.srcId, e.dstId, e.attr)
+    }.reduceByKey(_ ++ _).map { case (nodeId, neighbors: Array[(VertexId, Double)]) =>
+      var neighbors_ = neighbors
+      if (neighbors_.length > bcMaxDegree.value) {
+        neighbors_ = neighbors_.sortWith { case (left, right) => left._2 > right._2 }.slice(0, bcMaxDegree.value)
+      }
+
+      (nodeId, NodeAttr(neighbors = neighbors_.distinct))
+    }.repartition(200).cache
+
+    // Re-emit one GraphX Edge per retained neighbor; transition probabilities are attached by initTransitionProb().
+    indexedEdges = indexedNodes.flatMap { case (srcId, clickNode) =>
+      clickNode.neighbors.map { case (dstId, weight) =>
+        Edge(srcId, dstId, EdgeAttr())
+      }
+    }.repartition(200).cache
+
+    this
+  }
+
   def load(): this.type = {
     val bcMaxDegree = context.broadcast(config.degree)
     val bcEdgeCreator = config.directed match {
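The loadNeo() path above can be smoke-tested on its own. A minimal sketch (the object name and query are illustrative, not part of this patch), assuming the Neo4j bolt URL and credentials are supplied through the same spark.neo4j.* conf keys used in run_command.sh:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.graphx.Graph
    import org.neo4j.spark.Neo4j

    object NeoLoadCheck {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("NeoLoadCheck"))
        // Rows must expose source, target and value, matching the default neoQuery in Main.scala.
        val query = "MATCH (a)-[r]->(b) RETURN id(a) as source, id(b) as target, 1.0 as value"
        // partitions/batch mirror the values hardcoded in loadNeo().
        val graph: Graph[Long, Double] = Neo4j(sc).rels(query).partitions(10).batch(200).loadGraph
        println(s"vertices=${graph.vertices.count()}, edges=${graph.edges.count()}")
        sc.stop()
      }
    }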