Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/external_hyrax_example/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .example_module import greetings, meaning
from .example_model import ExampleModel

__all__ = ["greetings", "meaning"]
__all__ = ["ExampleModel"]
3 changes: 3 additions & 0 deletions src/external_hyrax_example/default_config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[model]
[model.ExampleModel]
layer = 10
35 changes: 35 additions & 0 deletions src/external_hyrax_example/example_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
from hyrax.models.model_registry import hyrax_model


@hyrax_model
class ExampleModel(nn.Module):
"""Simple example of an externally defined model for testing and demonstration
purposes."""

def __init__(self, config, data_sample=None):
"""Basic initialization with architecture definition"""
super().__init__()
Copy link

Copilot AI Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing input validation for data_sample parameter. The code assumes data_sample is not None and has the expected nested structure, but could fail if data_sample is None or missing the 'data'/'image' keys.

Suggested change
super().__init__()
super().__init__()
# Input validation for data_sample
if (
data_sample is None
or not isinstance(data_sample, dict)
or "data" not in data_sample
or not isinstance(data_sample["data"], dict)
or "image" not in data_sample["data"]
or not hasattr(data_sample["data"]["image"], "shape")
):
raise ValueError(
"data_sample must be a dict with structure: {'data': {'image': <array with .shape>}}"
)

Copilot uses AI. Check for mistakes.
channels, width, height = data_sample["data"]["image"].shape
self.config = config
layer = self.config["model"]["ExampleModel"]["layer"]
self.linear = nn.Linear(channels * width * height, layer)

def forward(self, x):
"""Standard PyTorch forward method"""
return self.linear(x)
Copy link

Copilot AI Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The forward method expects a flattened tensor but receives raw input. The tensor needs to be flattened before passing to the linear layer. Should be return self.linear(x.view(x.size(0), -1)).

Suggested change
return self.linear(x)
return self.linear(x.view(x.size(0), -1))

Copilot uses AI. Check for mistakes.

def train_step(self, batch):
"""The innermost logic in the training loop"""
x, y = batch
y_pred = self(x)
Comment on lines +25 to +26
Copy link

Copilot AI Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent tensor handling between train_step and forward methods. The train_step method passes x directly to forward(), but forward() expects a flattened tensor. This will cause a shape mismatch error.

Copilot uses AI. Check for mistakes.
loss = nn.functional.mse_loss(y_pred, y)
return loss

@staticmethod
def to_tensor(data_dict):
"""Method that converts the data in dictionary into the form the model expects"""
image = data_dict["data"]["image"][0]
label = data_dict["data"]["label"]
return (torch.tensor(image, dtype=torch.float32), torch.tensor(label, dtype=torch.float32))
Comment on lines +33 to +35
Copy link

Copilot AI Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing input validation and unclear indexing logic. The code assumes data_dict has the expected structure and that 'image' is indexable, but there's no validation. The [0] indexing is also unclear - consider adding a comment explaining why only the first element is taken.

Suggested change
image = data_dict["data"]["image"][0]
label = data_dict["data"]["label"]
return (torch.tensor(image, dtype=torch.float32), torch.tensor(label, dtype=torch.float32))
# Validate input structure
if not isinstance(data_dict, dict):
raise ValueError("Expected data_dict to be a dict.")
if "data" not in data_dict or not isinstance(data_dict["data"], dict):
raise ValueError("Expected data_dict to have a 'data' key containing a dict.")
data = data_dict["data"]
if "image" not in data or "label" not in data:
raise ValueError("Expected 'data' dict to contain 'image' and 'label' keys.")
image = data["image"]
# Assume 'image' is a batch; take the first sample
if not hasattr(image, "__getitem__") or len(image) == 0:
raise ValueError("'image' must be a non-empty indexable object (e.g., list, array).")
image_sample = image[0]
label = data["label"]
return (torch.tensor(image_sample, dtype=torch.float32), torch.tensor(label, dtype=torch.float32))

Copilot uses AI. Check for mistakes.
35 changes: 0 additions & 35 deletions src/external_hyrax_example/example_module.py

This file was deleted.

13 changes: 0 additions & 13 deletions tests/external_hyrax_example/test_example_module.py

This file was deleted.

Loading