diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/AutoBenchYAML.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/AutoBenchYAML.java index f8aa81575..815d1e555 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/AutoBenchYAML.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/AutoBenchYAML.java @@ -193,7 +193,7 @@ public static void main(String[] args) throws IOException { // Write CSV data try (FileWriter writer = new FileWriter(outputFile)) { // Write CSV header - writer.write("dataset,QPS,QPS StdDev,Mean Latency,Recall@10,Index Construction Time\n"); + writer.write("dataset,QPS,QPS StdDev,Mean Latency,Recall@10,Index Construction Time,Avg Nodes Visited\n"); // Write one row per dataset with average metrics for (Map.Entry entry : statsByDataset.entrySet()) { @@ -205,7 +205,8 @@ public static void main(String[] args) throws IOException { writer.write(datasetStats.getQpsStdDev() + ","); writer.write(datasetStats.getAvgLatency() + ","); writer.write(datasetStats.getAvgRecall() + ","); - writer.write(datasetStats.getIndexConstruction() + "\n"); + writer.write(datasetStats.getIndexConstruction() + ","); + writer.write(datasetStats.getAvgNodesVisited() + "\n"); } } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java index a4d62645f..c74a482eb 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java @@ -568,7 +568,8 @@ public static List runAllAndCollectResults( ThroughputBenchmark.createDefault()), LatencyBenchmark.createDefault(), CountBenchmark.createDefault(), - AccuracyBenchmark.createDefault() + AccuracyBenchmark.createDefault(), + CountBenchmark.createDefault() ); QueryTester tester = new QueryTester(benchmarks); for (int topK : topKGrid.keySet()) { diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizer.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizer.java index dba6064ab..60b9a80f1 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizer.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizer.java @@ -36,18 +36,16 @@ public static class SummaryStats { private final double indexConstruction; private final int totalConfigurations; private final double qpsStdDev; + private final double avgNodesVisited; - public SummaryStats(double avgRecall, double avgQps, double avgLatency, double indexConstruction, int totalConfigurations) { - this(avgRecall, avgQps, avgLatency, indexConstruction, totalConfigurations, 0.0); - } - - public SummaryStats(double avgRecall, double avgQps, double avgLatency, double indexConstruction, int totalConfigurations, double qpsStdDev) { + public SummaryStats(double avgRecall, double avgQps, double avgLatency, double indexConstruction, int totalConfigurations, double qpsStdDev, double avgNodesVisited) { this.avgRecall = avgRecall; this.avgQps = avgQps; this.avgLatency = avgLatency; this.indexConstruction = indexConstruction; this.totalConfigurations = totalConfigurations; this.qpsStdDev = qpsStdDev; + this.avgNodesVisited = avgNodesVisited; } public double getAvgRecall() { @@ -70,6 +68,8 @@ public int getTotalConfigurations() { public double getQpsStdDev() { return qpsStdDev; } + public double getAvgNodesVisited() { return avgNodesVisited; } + @Override public String toString() { return String.format( @@ -77,8 +77,9 @@ public String toString() { " Average Recall@k: %.4f%n" + " Average QPS: %.2f (± %.2f)%n" + " Average Latency: %.2f ms%n" + - " Index Construction Time: %.2f", - totalConfigurations, avgRecall, avgQps, qpsStdDev, avgLatency, indexConstruction); + " Index Construction Time: %.2f%n" + + " Average Nodes Visited: %.2f", + totalConfigurations, avgRecall, avgQps, qpsStdDev, avgLatency, indexConstruction, avgNodesVisited); } } @@ -89,7 +90,7 @@ public String toString() { */ public static SummaryStats summarize(List results) { if (results == null || results.isEmpty()) { - return new SummaryStats(0, 0, 0, 0, 0, 0); + return new SummaryStats(0, 0, 0, 0, 0, 0, 0); } double totalRecall = 0; @@ -97,11 +98,13 @@ public static SummaryStats summarize(List results) { double totalLatency = 0; double indexConstruction = 0; double totalQpsStdDev = 0; + double totalNodesVisited = 0; int recallCount = 0; int qpsCount = 0; int latencyCount = 0; int qpsStdDevCount = 0; + int nodesVisitedCount = 0; for (BenchResult result : results) { if (result.metrics == null) continue; @@ -135,6 +138,13 @@ public static SummaryStats summarize(List results) { } indexConstruction = extractIndexConstructionMetric(result.metrics); + + // Extract nodes visited metric (format is "Avg Visited") + Double nodesVisited = extractNodesVisitedMetric(result.metrics); + if (nodesVisited != null) { + totalNodesVisited += nodesVisited; + nodesVisitedCount++; + } } // Calculate averages, handling cases where some metrics might not be present @@ -142,11 +152,12 @@ public static SummaryStats summarize(List results) { double avgQps = qpsCount > 0 ? totalQps / qpsCount : 0; double avgLatency = latencyCount > 0 ? totalLatency / latencyCount : 0; double avgQpsStdDev = qpsStdDevCount > 0 ? totalQpsStdDev / qpsStdDevCount : 0; + double avgNodesVisited = nodesVisitedCount > 0 ? totalNodesVisited / nodesVisitedCount : 0; // Count total valid configurations as the maximum count of any metric int totalConfigurations = Math.max(Math.max(recallCount, qpsCount), latencyCount); - return new SummaryStats(avgRecall, avgQps, avgLatency, indexConstruction, totalConfigurations, avgQpsStdDev); + return new SummaryStats(avgRecall, avgQps, avgLatency, indexConstruction, totalConfigurations, avgQpsStdDev, avgNodesVisited); } private static Double extractIndexConstructionMetric(Map metrics) { @@ -235,7 +246,28 @@ private static Double extractQpsStdDevMetric(Map metrics) { } return null; } - + + /** + * Extract an average nodes visited metric from the metrics map + * @param metrics Map of metrics + * @return The average nodes visited value as a Double, or null if not found + */ + private static Double extractNodesVisitedMetric(Map metrics) { + // Try exact match first + Double value = extractMetric(metrics, "Avg Visited"); + if (value != null) return value; + + // Look for any key containing "Avg Visited" case insensitive + for (Map.Entry entry : metrics.entrySet()) { + if (entry.getKey().contains("Avg Visited")) { + return convertToDouble(entry.getValue()); + } + } + + return null; + } + + /** * Extract a specific metric from the metrics map * @param metrics Map of metrics diff --git a/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizerTest.java b/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizerTest.java index 6168d5dca..b36088c8a 100644 --- a/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizerTest.java +++ b/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/BenchmarkSummarizerTest.java @@ -115,8 +115,8 @@ public void testSummarizeWithNullList() { @Test public void testSummaryStatsToString() { // Create a SummaryStats instance - SummaryStats stats = new SummaryStats(0.85, 1200.0, 5.2, 1000000, 4); - + SummaryStats stats = new SummaryStats(0.85, 1200.0, 5.2, 1000000, 4, 0.2, 100) +; // Verify toString output String expected = String.format( "Benchmark Summary (across %d configurations):%n" + diff --git a/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/SummarizerTest.java b/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/SummarizerTest.java index 3dbf7f403..c3d698bab 100644 --- a/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/SummarizerTest.java +++ b/jvector-examples/src/test/java/io/github/jbellis/jvector/example/util/SummarizerTest.java @@ -121,7 +121,7 @@ private static void testSummaryStatsToString() { System.out.println("\nTest: SummaryStats toString method"); // Create a SummaryStats instance - SummaryStats stats = new SummaryStats(0.85, 1200.0, 5.2, 1000000, 4); + SummaryStats stats = new SummaryStats(0.85, 1200.0, 5.2, 1000000, 4, 0.2, 100); // Verify toString output String expected = String.format( diff --git a/visualize_benchmarks.py b/visualize_benchmarks.py index c8709d5a7..903c3448a 100644 --- a/visualize_benchmarks.py +++ b/visualize_benchmarks.py @@ -30,7 +30,7 @@ # Define metrics where higher values are better and lower values are better HIGHER_IS_BETTER = ["QPS", "Recall@10"] -LOWER_IS_BETTER = ["Mean Latency", "Index Build Time"] +LOWER_IS_BETTER = ["Mean Latency", "Index Build Time", "Average Nodes Visited"] class BenchmarkData: