Labels: API design, enhancement
Issue to discuss the M layer design. For context, see design document here: sktime/enhancement-proposals#39
My current proposed design, based on the v1 and v2 metadata layer designs. Long-term state:
- starts off as the current "metadata" class, but named like the model, e.g., TFT. The current lightning network is renamed TFT_NN
- has tags and get_test_params etc., similar to the current "metadata" class
- __init__ has all args of two objects: the loader (D2, e.g., DecoderEncoderModule) and the network (e.g., TFT_NN), minus data
- method get_loader_class gets the loader class (e.g., the class DecoderEncoderModule); get_loader(data: TimeSeries) produces a loader object, an instance of the get_loader_class return
- method get_nn_class returns the nn class (e.g., TFT_NN); get_nn(loader) gets an instance of the nn class
- finally, there is a method init(data), which calls the above in sequence and produces a pair of loader and nn, as if the two get methods had been called one after the other. __call__ dispatches to init
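To make the list above concrete, here is a minimal sketch of how such a class could be laid out. It uses plain-Python stand-ins rather than the real pytorch-forecasting classes: the constructor signatures of DecoderEncoderModule and TFT_NN, the helper _kwargs_for, and the (nn, loader) return order (taken from the usage vignette) are all assumptions for illustration, not the agreed design.

```python
import inspect


class DecoderEncoderModule:
    """Stand-in for the loader (D2 layer); not the real class."""

    def __init__(self, data, max_encoder_length=30, batch_size=32):
        self.data = data
        self.max_encoder_length = max_encoder_length
        self.batch_size = batch_size


class TFT_NN:
    """Stand-in for the renamed lightning network; not the real class."""

    def __init__(self, loader, hidden_size=64, num_layers=2):
        self.loader = loader
        self.hidden_size = hidden_size
        self.num_layers = num_layers


class TFT:
    """Model-named config class: stores loader and network args, minus data."""

    def __init__(self, max_encoder_length=30, batch_size=32,
                 hidden_size=64, num_layers=2):
        self.max_encoder_length = max_encoder_length
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def get_loader_class(self):
        return DecoderEncoderModule

    def get_nn_class(self):
        return TFT_NN

    def _kwargs_for(self, cls, skip):
        # Hypothetical helper: pick the stored args that cls.__init__ accepts.
        names = inspect.signature(cls.__init__).parameters
        return {k: getattr(self, k) for k in names if k not in ("self", *skip)}

    def get_loader(self, data):
        cls = self.get_loader_class()
        return cls(data, **self._kwargs_for(cls, skip=("data",)))

    def get_nn(self, loader):
        cls = self.get_nn_class()
        return cls(loader, **self._kwargs_for(cls, skip=("loader",)))

    def init(self, data):
        # Same result as calling get_loader and then get_nn by hand.
        loader = self.get_loader(data)
        net = self.get_nn(loader)
        return net, loader

    def __call__(self, data):
        return self.init(data)
```

With this layout, a different model would only swap the two get_*_class returns and the argument list of __init__; the get_loader, get_nn, init and __call__ logic could live in a shared base class.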
So, a usage vignette could look like:
from lightning.pytorch import Trainer
from torch import nn

from pytorch_forecasting import TimeSeries
from pytorch_forecasting.metrics import MAE, SMAPE
from pytorch_forecasting.models import TFT

dataset = TimeSeries(...)

# model configuration: all loader and network args, minus data
model_cfg = TFT(
    max_encoder_length=30,
    max_prediction_length=1,
    batch_size=32,
    loss=nn.MSELoss(),
    logging_metrics=[MAE(), SMAPE()],
    optimizer="adam",
    hidden_size=64,
    num_layers=2,
    attention_head_size=4,
)

# calling the config object dispatches to init(data)
net, loader = model_cfg(dataset)

trainer = Trainer(
    max_epochs=5,
    accelerator="auto",
    devices=1,
    enable_progress_bar=True,
    log_every_n_steps=10,
)
trainer.fit(net, loader)

etc.

The only things that change for other models are the model class and its args/values, i.e., how model_cfg is constructed.
In sktime, we would add the trainer as an arg to __init__, and sktime fit(data) would do self.trainer.fit(*self.model_cfg(data)), with some potential conversion for data (or we could allow TimeSeries as an mtype).
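A rough sketch of that sktime-side wiring, with stand-ins so that it runs without lightning or sktime installed. ConfigForecaster, RecordingTrainer and ToyModelCfg are hypothetical names; a real estimator would subclass the sktime base class and use a lightning Trainer, and the sketch assumes the trainer is invoked through its fit method as in the vignette.

```python
class RecordingTrainer:
    """Stand-in for lightning's Trainer; only records what fit receives."""

    def __init__(self):
        self.fit_calls = []

    def fit(self, net, loader):
        self.fit_calls.append((net, loader))


class ToyModelCfg:
    """Stand-in for a model config such as TFT; calling it returns (net, loader)."""

    def __init__(self, hidden_size=64):
        self.hidden_size = hidden_size

    def __call__(self, data):
        loader = list(data)
        net = {"hidden_size": self.hidden_size}
        return net, loader


class ConfigForecaster:
    """Hypothetical wrapper: model config and trainer are both __init__ args."""

    def __init__(self, model_cfg, trainer):
        self.model_cfg = model_cfg
        self.trainer = trainer

    def fit(self, data):
        # Any conversion of data to TimeSeries would happen here (omitted).
        net, loader = self.model_cfg(data)
        self.trainer.fit(net, loader)
        self.net_ = net  # keep the fitted network for later prediction
        return self


trainer = RecordingTrainer()
forecaster = ConfigForecaster(ToyModelCfg(hidden_size=8), trainer)
forecaster.fit([1, 2, 3])
```

Keeping the trainer as a constructor argument means the training setup (epochs, accelerator, logging) stays configurable per estimator instance, separate from the model config.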