diff --git a/src/main/java/evaluation/listeners/MetricsGameListener.java b/src/main/java/evaluation/listeners/MetricsGameListener.java index 964ba24b3..46f0593af 100644 --- a/src/main/java/evaluation/listeners/MetricsGameListener.java +++ b/src/main/java/evaluation/listeners/MetricsGameListener.java @@ -45,6 +45,7 @@ public class MetricsGameListener implements IGameListener { // Destination directory for the reports String destDir = "metrics/out/"; //by default + boolean firstReport; public MetricsGameListener() { } @@ -61,6 +62,7 @@ public MetricsGameListener(IDataLogger.ReportDestination logTo, IDataLogger.Repo reportDestinations = Collections.singletonList(logTo); this.reportTypes = Arrays.asList(dataTypes); this.metrics = new LinkedHashMap<>(); + this.firstReport = true; for (AbstractMetric m : metrics) { m.setDataLogger(new DataTableSaw(m)); //todo this logger needs to be read from JSON this.metrics.put(m.getName(), m); @@ -128,7 +130,7 @@ public void report() { // of redundant directories if (!(reportTypes.size() == 1 && reportTypes.contains(RawDataPerEvent))) for (AbstractMetric metric : metrics.values()) { - metric.report(destDir, reportTypes, reportDestinations); + metric.report(destDir, reportTypes, reportDestinations, !firstReport); } // We also create raw data files for groups of metrics responding to the same event @@ -142,9 +144,15 @@ public void report() { } if (!eventMetrics.isEmpty()) { IDataLogger dataLogger = new DataTableSaw(eventMetrics, event, eventToIndexingColumn(event)); - dataLogger.getDefaultProcessor().processRawDataToFile(dataLogger, destDir); + dataLogger.getDefaultProcessor().processRawDataToFile(dataLogger, destDir, !firstReport); } } + //Clean the data. We don't want to keep this in memory; instead we append after every reporting. + for (AbstractMetric metric : metrics.values()) { + IDataLogger dataLogger = metric.getDataLogger(); + dataLogger.flush(); + } + firstReport = false; } } diff --git a/src/main/java/evaluation/metrics/AbstractMetric.java b/src/main/java/evaluation/metrics/AbstractMetric.java index c4b0862f0..91cd14542 100644 --- a/src/main/java/evaluation/metrics/AbstractMetric.java +++ b/src/main/java/evaluation/metrics/AbstractMetric.java @@ -200,7 +200,10 @@ public boolean filterByEventTypeWhenReporting() { * @param reportTypes - list of report types to produce * @param reportDestinations - list of report destinations to produce */ - public void report(String folderName, List reportTypes, List reportDestinations) + public void report(String folderName, + List reportTypes, + List reportDestinations, + boolean append) { //DataProcessor with compatibility assertion: IDataProcessor dataProcessor = getDataProcessor(); @@ -217,7 +220,7 @@ public void report(String folderName, List reportTypes, if (reportType == IDataLogger.ReportType.RawData) { if (reportDestination == IDataLogger.ReportDestination.ToFile || reportDestination == IDataLogger.ReportDestination.ToBoth) { - dataProcessor.processRawDataToFile(dataLogger, folderName); + dataProcessor.processRawDataToFile(dataLogger, folderName, append); } if (reportDestination == IDataLogger.ReportDestination.ToConsole || reportDestination == IDataLogger.ReportDestination.ToBoth) { dataProcessor.processRawDataToConsole(dataLogger); diff --git a/src/main/java/evaluation/metrics/IDataLogger.java b/src/main/java/evaluation/metrics/IDataLogger.java index 147e146fc..3a65dedf7 100644 --- a/src/main/java/evaluation/metrics/IDataLogger.java +++ b/src/main/java/evaluation/metrics/IDataLogger.java @@ -45,6 +45,8 @@ default void reset() {} */ IDataProcessor getDefaultProcessor(); + void flush(); + IDataLogger copy(); IDataLogger emptyCopy(); IDataLogger create(); diff --git a/src/main/java/evaluation/metrics/IDataProcessor.java b/src/main/java/evaluation/metrics/IDataProcessor.java index c4330930c..600d55358 100644 --- a/src/main/java/evaluation/metrics/IDataProcessor.java +++ b/src/main/java/evaluation/metrics/IDataProcessor.java @@ -26,7 +26,7 @@ public interface IDataProcessor * @param logger - logger that contains the raw data * @param folderName - name of the folder to save the file to */ - void processRawDataToFile(IDataLogger logger, String folderName); + void processRawDataToFile(IDataLogger logger, String folderName, boolean append); /** diff --git a/src/main/java/evaluation/metrics/TournamentMetric.java b/src/main/java/evaluation/metrics/TournamentMetric.java index fd9c148ed..c31840506 100644 --- a/src/main/java/evaluation/metrics/TournamentMetric.java +++ b/src/main/java/evaluation/metrics/TournamentMetric.java @@ -20,9 +20,12 @@ public class TournamentMetric extends AbstractMetric { AbstractMetric wrappedMetric; + private boolean firstReport; + public TournamentMetric(AbstractMetric metric) { super(metric.getEventTypes()); this.wrappedMetric = metric; + this.firstReport = true; } /** @@ -118,7 +121,7 @@ public void report(String folderName, List reportTypes, if (reportType == IDataLogger.ReportType.RawData) { if (reportDestination == IDataLogger.ReportDestination.ToFile || reportDestination == IDataLogger.ReportDestination.ToBoth) { - dataProcessor.processRawDataToFile(logger, folder); + dataProcessor.processRawDataToFile(logger, folder, !firstReport); } if (reportDestination == IDataLogger.ReportDestination.ToConsole || reportDestination == IDataLogger.ReportDestination.ToBoth) { dataProcessor.processRawDataToConsole(logger); @@ -140,5 +143,6 @@ public void report(String folderName, List reportTypes, } } } + firstReport = false; } } diff --git a/src/main/java/evaluation/metrics/tablessaw/DataTableSaw.java b/src/main/java/evaluation/metrics/tablessaw/DataTableSaw.java index 40a0e9e9c..bc98037ad 100644 --- a/src/main/java/evaluation/metrics/tablessaw/DataTableSaw.java +++ b/src/main/java/evaluation/metrics/tablessaw/DataTableSaw.java @@ -120,6 +120,11 @@ public IDataProcessor getDefaultProcessor() { return new TableSawDataProcessor(); } + @Override + public void flush() { + this.data = data.emptyCopy(); + } + @Override public IDataLogger copy() { return new DataTableSaw(metric, data.copy()); @@ -297,4 +302,5 @@ record = true; } } } + } \ No newline at end of file diff --git a/src/main/java/evaluation/metrics/tablessaw/TableSawDataProcessor.java b/src/main/java/evaluation/metrics/tablessaw/TableSawDataProcessor.java index 0d6f9f9f1..ac1e335c8 100644 --- a/src/main/java/evaluation/metrics/tablessaw/TableSawDataProcessor.java +++ b/src/main/java/evaluation/metrics/tablessaw/TableSawDataProcessor.java @@ -7,6 +7,7 @@ import evaluation.summarisers.TAGNumericStatSummary; import tech.tablesaw.api.*; import tech.tablesaw.columns.Column; +import tech.tablesaw.io.csv.CsvWriteOptions; import tech.tablesaw.plotly.Plot; import tech.tablesaw.plotly.api.LinePlot; import tech.tablesaw.plotly.components.*; @@ -26,9 +27,23 @@ public class TableSawDataProcessor implements IDataProcessor { @Override - public void processRawDataToFile(IDataLogger logger, String folderName) { + public void processRawDataToFile(IDataLogger logger, String folderName, boolean append) { DataTableSaw dts = (DataTableSaw) logger; - dts.data.write().csv(folderName + "/" + dts.data.name() + ".csv"); + String filename = folderName + "/" + dts.data.name() + ".csv"; + if(!append) { + dts.data.write().csv(filename); + } else { + try { + File file = new File(filename); + boolean headerNeeded = !file.exists(); + Writer w = new FileWriter(file, true); + CsvWriteOptions.Builder options = CsvWriteOptions.builder(w); + options.header(headerNeeded); + dts.data.write().csv(options.build()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } } @Override diff --git a/src/test/java/evaluation/TunableParametersTest.java b/src/test/java/evaluation/TunableParametersTest.java index 0baba64b2..5718219cb 100644 --- a/src/test/java/evaluation/TunableParametersTest.java +++ b/src/test/java/evaluation/TunableParametersTest.java @@ -179,7 +179,7 @@ public void toJSONWithDefaults() { assertEquals(67, json.get("maxTreeDepth")); assertEquals(0.56, params.getParameterValue("rolloutPolicyParams.temperature")); assertEquals(0.56, ((JSONObject) json.get("rolloutPolicyParams")).get("temperature")); - assertEquals(Math.sqrt(2), (Double) json.get("K"), 0.002); + assertEquals(1.0, (Double) json.get("K"), 0.002); assertEquals(false, json.get("useMASTAsActionHeuristic")); assertEquals("BUDGET_FM_CALLS", json.get("budgetType"));