
[API] M layer design #1870

@fkiraly


Issue to discuss the M layer design. For context, see the design document here: sktime/enhancement-proposals#39

My current proposed design, based on the v1 and v2 metadata layer designs, is as follows (long-term state):

  • starts off as the current "metadata" class, but is named after the model, e.g., TFT. The current lightning network is renamed to TFT_NN.
  • has tags, get_test_params, etc., similar to the current "metadata" class.
  • __init__ takes all args of two objects, minus data: the loader (D2, e.g., DecoderEncoderModule) and the network (e.g., TFT_NN).
  • method get_loader_class returns the loader class (e.g., the class DecoderEncoderModule); get_loader(data: TimeSeries) produces a loader object, an instance of the class returned by get_loader_class.
  • method get_nn_class returns the nn class (e.g., TFT_NN); get_nn(loader) produces an instance of the nn class.
  • finally, the method init(data) calls the above in sequence and returns a pair of nn and loader, as if the two get methods had been called in sequence.
  • __call__ dispatches to init.
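The bullet points above can be sketched in plain Python to check that the pieces fit together. This is a minimal sketch under stated assumptions, not the pytorch-forecasting implementation: `ModelPkg`, `ToyLoader`, `ToyNN` and `Toy` are hypothetical stand-ins (for the base class, DecoderEncoderModule, TFT_NN and TFT respectively), `__init__` takes `**kwargs` instead of explicit args for brevity, and routing each arg to the loader or the network by matching constructor signatures is an assumed mechanism, not part of the proposal.

```python
import inspect


def _init_params(klass):
    """Names of klass.__init__ params, minus self and the first input (data/loader)."""
    names = list(inspect.signature(klass.__init__).parameters)
    return set(names[2:])


class ModelPkg:
    """Hypothetical base class for the proposed M layer object (e.g. TFT)."""

    def __init__(self, **kwargs):
        # accept the union of loader args and network args, minus data
        known = _init_params(self.get_loader_class()) | _init_params(self.get_nn_class())
        unknown = set(kwargs) - known
        if unknown:
            raise TypeError(f"unexpected args: {sorted(unknown)}")
        self._kwargs = kwargs

    @classmethod
    def get_loader_class(cls):
        raise NotImplementedError

    @classmethod
    def get_nn_class(cls):
        raise NotImplementedError

    def _args_for(self, klass):
        # assumption: route each stored arg to every class whose __init__ accepts it
        params = _init_params(klass)
        return {k: v for k, v in self._kwargs.items() if k in params}

    def get_loader(self, data):
        loader_cls = self.get_loader_class()
        return loader_cls(data, **self._args_for(loader_cls))

    def get_nn(self, loader):
        nn_cls = self.get_nn_class()
        return nn_cls(loader, **self._args_for(nn_cls))

    def init(self, data):
        # same result as calling get_loader, then get_nn, in sequence
        loader = self.get_loader(data)
        return self.get_nn(loader), loader

    def __call__(self, data):
        return self.init(data)


class ToyLoader:
    """Hypothetical stand-in for a D2 loader such as DecoderEncoderModule."""

    def __init__(self, data, max_encoder_length=10, batch_size=32):
        self.data = data
        self.max_encoder_length = max_encoder_length
        self.batch_size = batch_size
        self.n_features = len(data[0])  # metadata the network reads


class ToyNN:
    """Hypothetical stand-in for a network such as TFT_NN."""

    def __init__(self, loader, hidden_size=16):
        self.input_size = loader.n_features  # taken from the loader
        self.hidden_size = hidden_size


class Toy(ModelPkg):
    """Hypothetical model class, playing the role of TFT."""

    @classmethod
    def get_loader_class(cls):
        return ToyLoader

    @classmethod
    def get_nn_class(cls):
        return ToyNN


# net.input_size == 2, net.hidden_size == 8, loader.max_encoder_length == 5
net, loader = Toy(max_encoder_length=5, hidden_size=8)([[1.0, 2.0], [3.0, 4.0]])
```

Since the network is built from the loader, anything the network needs to know about the data (here `n_features`) can come from the loader rather than being repeated in the model's args.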

So, a usage vignette could look like:

```python
from lightning.pytorch import Trainer
from torch import nn

from pytorch_forecasting import TimeSeries
from pytorch_forecasting.metrics import MAE, SMAPE
from pytorch_forecasting.models import TFT

dataset = TimeSeries(...)
model_cfg = TFT(
    max_encoder_length=30,
    max_prediction_length=1,
    batch_size=32,
    loss=nn.MSELoss(),
    logging_metrics=[MAE(), SMAPE()],
    optimizer="adam",
    hidden_size=64,
    num_layers=2,
    attention_head_size=4,
)

# calling the model object dispatches to init(dataset), returning (nn, loader)
net, loader = model_cfg(dataset)

trainer = Trainer(
    max_epochs=5,
    accelerator="auto",
    devices=1,
    enable_progress_bar=True,
    log_every_n_steps=10,
)

trainer.fit(net, loader)
```

etc.

For other models, the only things that change are the model class and the args/values passed to it when creating model_cfg.

In sktime, we would add the trainer as an arg to __init__, and the sktime fit(data) would do self.trainer.fit(*self.model_cfg(data)) (with some potential conversion of data, or we could allow TimeSeries as an mtype).
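A minimal sketch of that sktime-side wiring, assuming only that `model_cfg(data)` returns `(net, loader)` and that the trainer exposes `fit(model, train_dataloaders)` as Lightning's `Trainer` does. `PytorchForecastingAdapter`, `RecordingTrainer` and `toy_cfg` are hypothetical names; a real sktime forecaster would subclass `BaseForecaster` and implement `_fit(y, X, fh)`, and data conversion is left out here.

```python
class PytorchForecastingAdapter:
    """Hypothetical sktime-style wrapper; not the actual sktime API."""

    def __init__(self, model_cfg, trainer):
        self.model_cfg = model_cfg  # e.g. a TFT(...) object from the M layer
        self.trainer = trainer      # e.g. a lightning.pytorch.Trainer

    def fit(self, data):
        # model_cfg(data) -> (net, loader), unpacked into trainer.fit;
        # fitted objects use sktime's trailing-underscore convention
        self.net_, self.loader_ = self.model_cfg(data)
        self.trainer.fit(self.net_, self.loader_)
        return self


class RecordingTrainer:
    """Stand-in for a Lightning Trainer: records what fit() receives."""

    def __init__(self):
        self.calls = []

    def fit(self, model, train_dataloaders=None):
        self.calls.append((model, train_dataloaders))


def toy_cfg(data):
    """Stand-in for a model object such as TFT(...): returns (net, loader)."""
    return "net", list(data)
```

With these stand-ins, `PytorchForecastingAdapter(toy_cfg, RecordingTrainer()).fit([1, 2, 3])` passes `("net", [1, 2, 3])` to the trainer. Keeping the trainer as a constructor arg means all training settings stay visible in the estimator's params, in line with sktime's convention of configuring estimators through `__init__`.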

Labels: API design (API design & software architecture), enhancement (New feature or request)
