Skip to content

Conversation

@Atharva9621
Copy link

@Atharva9621 Atharva9621 commented Oct 20, 2025

Add support for custom loss and metrics in model_sweep

Fixes #544

  • Custom loss, metrics, and optimizers can now be passed to model_sweep in the same way as tabular_model.fit() through custom_fit_params.
  • custom_fit_params expects a dictionary specifying the custom loss, metrics, or optimizer.
  • Minimal code changes; fully backward compatible.
  • Updated corresponding tests.

Example usage

class CustomLoss(nn.Module):
      def __init__(self):
          super(CustomLoss, self).__init__()
  
      def forward(self, inputs, targets):
          loss = torch.mean((inputs - targets) ** 4)
          return 100*loss.mean()

def custom_metric(y_hat, y):
    return (y_hat - y).mean()

sweep_df, best_model = model_sweep(
    task="regression",
    train=train,
    test=val,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list="lite",
    custom_fit_params = {
        "loss": CustomLoss(),
        "metrics": [custom_metric],
        "metrics_prob_inputs": [True],
        "optimizer": torch.optim.Adagrad,
    }
)

📚 Documentation preview 📚: https://pytorch-tabular--587.org.readthedocs.build/en/587/

@dosubot dosubot bot added size:M This PR changes 30-99 lines, ignoring generated files. enhancement New feature or request labels Oct 20, 2025
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds support for custom loss functions, metrics, and optimizers to the model_sweep function, making it consistent with the TabularModel.fit() API.

Key Changes

  • Added custom_fit_params parameter to model_sweep function that accepts custom loss, metrics, and optimizer specifications
  • Updated validation logic to ensure rank_metric is "loss" when custom metrics are provided
  • Enhanced test coverage with a new test case (test_model_compare_custom) demonstrating custom fit parameters

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
src/pytorch_tabular/tabular_model_sweep.py Added custom_fit_params parameter to model_sweep and _validate_args, with validation logic and documentation; unpacks params when calling prepare_model
tests/test_common.py Updated _run_model_compare to accept and forward custom_fit_params; added new test case with custom loss, metrics, and optimizer

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

suppress_lightning_logger (bool, optional): If True, will suppress the lightning logger. Defaults to True.
custom_fit_params (dict, optional): A dict specifying custom loss, metrics and optimizer.
The behviour of these custom parameters is similar to those passed through the `fit` method
Copy link

Copilot AI Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in documentation: "behviour" should be "behaviour".

Suggested change
The behviour of these custom parameters is similar to those passed through the `fit` method
The behaviour of these custom parameters is similar to those passed through the `fit` method

Copilot uses AI. Check for mistakes.
assert len(comp_df) == 3
else:
assert len(comp_df) == len(model_list)
if custom_fit_params.get("metric", None) == fake_metric:
Copy link

Copilot AI Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug in test assertion: should check "metrics" (plural) instead of "metric" (singular). The custom_fit_params dictionary uses the key "metrics" (line 1277), so this condition will never be true, making this assertion ineffective.

Suggested change
if custom_fit_params.get("metric", None) == fake_metric:
if fake_metric in custom_fit_params.get("metrics", []):

Copilot uses AI. Check for mistakes.
else:
assert len(comp_df) == len(model_list)
if custom_fit_params.get("metric", None) == fake_metric:
assert "test_fake_metric" in comp_df.columns()
Copy link

Copilot AI Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: columns() is being called as a method, but pandas DataFrame's columns is a property, not a method. This should be comp_df.columns instead of comp_df.columns().

Suggested change
assert "test_fake_metric" in comp_df.columns()
assert "test_fake_metric" in comp_df.columns

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request size:M This PR changes 30-99 lines, ignoring generated files.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Help: custom loss for model_sweep

1 participant