-
Couldn't load subscription status.
- Fork 1
PR brushing up abstract trainer class #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…clarity and facilitate easier subclassing.
… and functionality
…stopping criteria
…iner as logging will be made obligatory.
…it for better clarity on the minimal trainer
…Trainer for improved integration
…_split method and update AbstractTrainer to support specification of custom DataLoaders.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
| def default_random_split( | ||
| dataset: Dataset, | ||
| **kwargs | ||
| ) -> Tuple[DataLoader, DataLoader, DataLoader]: | ||
| """ | ||
| Randomly split a dataset into train, validation, and test sets. | ||
| :param dataset: The dataset to split. | ||
| :param train_frac: Fraction of data to use for training (default: 0.7). | ||
| :param val_frac: Fraction of data to use for validation (default: 0.15). | ||
| :param test_frac: Fraction of data to use for testing (default: remaining). | ||
| :param batch_size: Batch size for the DataLoaders (default: 4). | ||
| :param shuffle: Whether to shuffle the data in the DataLoaders | ||
| (default: True). | ||
| :return: A tuple of DataLoaders for (train, val, test) splits. | ||
| """ | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solid
| train_frac = kwargs.get("train_frac", 0.7) | ||
| val_frac = kwargs.get("val_frac", 0.15) | ||
| test_frac = kwargs.get( | ||
| "test_frac", 1.0 - train_frac - val_frac | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there hold out data that is removed prior to calling this function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Double checking these functions are supposed to be ... for now?
Previously logger support is achieved by wrapping
AbstractTrainerwith yet another layer of abstract class, which is redundant, this PR adds logger support toAbstractTrainerdirectly along side a few minor refactoring (no change to functionality).Refactors:
AbstractTrainer.py:data_split.pyAdds:
AbstractTrainer.py:trainmethod support to logger classTrainerProtocol.py: