diff --git a/io.openems.edge.predictor.lstm/src/io/openems/edge/predictor/lstm/PredictorLstmImpl.java b/io.openems.edge.predictor.lstm/src/io/openems/edge/predictor/lstm/PredictorLstmImpl.java index 1ae7b5f61ed..00408adf21d 100644 --- a/io.openems.edge.predictor.lstm/src/io/openems/edge/predictor/lstm/PredictorLstmImpl.java +++ b/io.openems.edge.predictor.lstm/src/io/openems/edge/predictor/lstm/PredictorLstmImpl.java @@ -137,26 +137,14 @@ protected Prediction createNewPrediction(ChannelAddress channelAddress) { var dayPlus1SeasonalityFuture = CompletableFuture .supplyAsync(() -> this.predictSeasonality(channelAddress, now.plusDays(1), hyperParameters)); - // var combinePrerequisites = CompletableFuture.allOf(seasonalityFuture, - // trendFuture); - try { - // TODO combinePrerequisites.get(); - - // Current day prediction - var currentDayPredicted = combine(trendFuture.get(), seasonalityFuture.get()); - - // Next Day prediction - var plus1DaySeasonalityPrediction = dayPlus1SeasonalityFuture.get(); - // Concat current and Nextday + CompletableFuture.allOf(seasonalityFuture, trendFuture, dayPlus1SeasonalityFuture).join(); + var currentDayPredicted = combine(trendFuture.join(), seasonalityFuture.join()); + var plus1DaySeasonalityPrediction = dayPlus1SeasonalityFuture.join(); var actualPredicted = concatenateList(currentDayPredicted, plus1DaySeasonalityPrediction); - var baseTimeOfPrediction = now.withMinute(getMinute(now, hyperParameters)).withSecond(0).withNano(0); - - return Prediction.from(this.sum, channelAddress, // - baseTimeOfPrediction, // - averageInChunks(actualPredicted)); + return Prediction.from(this.sum, channelAddress, baseTimeOfPrediction, averageInChunks(actualPredicted)); } catch (Exception e) { throw new RuntimeException("Error in getting prediction execution", e); diff --git a/io.openems.edge.predictor.lstm/src/io/openems/edge/predictor/lstm/train/LstmTrain.java b/io.openems.edge.predictor.lstm/src/io/openems/edge/predictor/lstm/train/LstmTrain.java index 67e1e3053db..40a444c14b7 100644 --- a/io.openems.edge.predictor.lstm/src/io/openems/edge/predictor/lstm/train/LstmTrain.java +++ b/io.openems.edge.predictor.lstm/src/io/openems/edge/predictor/lstm/train/LstmTrain.java @@ -93,14 +93,13 @@ public void run() { // Get the validationDate var validationDate = this.getDate(validateMap); - /** - * TODO Read an save model.adapt method ReadAndSaveModels.adapt(hyperParameters, - * validateBatchData, validateBatchDate); - */ - new TrainAndValidateBatch(// - constantScaling(removeNegatives(trainingData), 1), trainingDate, // - constantScaling(removeNegatives(validationData), 1), validationDate, // - hyperParameters); + // --- Adapt the model based on validation data before training --- + // We use the preprocessed validation data for adaptation. + ReadAndSaveModels.adapt(hyperParameters, constantScaling(removeNegatives(validationData), 1), validationDate); + + // Perform training and validation in batch + new TrainAndValidateBatch(constantScaling(removeNegatives(trainingData), 1), trainingDate, + constantScaling(removeNegatives(validationData), 1), validationDate, hyperParameters); this.parent._setLastTrainedTime(hyperParameters.getLastTrainedDate().toInstant().toEpochMilli()); this.parent._setModelError(Collections.min(hyperParameters.getRmsErrorSeasonality()));