-
Notifications
You must be signed in to change notification settings - Fork 0
Adding a basic example model for experimentation and testing. #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,3 @@ | ||
| from .example_module import greetings, meaning | ||
| from .example_model import ExampleModel | ||
|
|
||
| __all__ = ["greetings", "meaning"] | ||
| __all__ = ["ExampleModel"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| [model] | ||
| [model.ExampleModel] | ||
| layer = 10 |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,35 @@ | ||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||||||
| from hyrax.models.model_registry import hyrax_model | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| @hyrax_model | ||||||||||||||||||||||||||||||||||||||
| class ExampleModel(nn.Module): | ||||||||||||||||||||||||||||||||||||||
| """Simple example of an externally defined model for testing and demonstration | ||||||||||||||||||||||||||||||||||||||
| purposes.""" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def __init__(self, config, data_sample=None): | ||||||||||||||||||||||||||||||||||||||
| """Basic initialization with architecture definition""" | ||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||
| channels, width, height = data_sample["data"]["image"].shape | ||||||||||||||||||||||||||||||||||||||
| self.config = config | ||||||||||||||||||||||||||||||||||||||
| layer = self.config["model"]["ExampleModel"]["layer"] | ||||||||||||||||||||||||||||||||||||||
| self.linear = nn.Linear(channels * width * height, layer) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def forward(self, x): | ||||||||||||||||||||||||||||||||||||||
| """Standard PyTorch forward method""" | ||||||||||||||||||||||||||||||||||||||
| return self.linear(x) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
| return self.linear(x) | |
| return self.linear(x.view(x.size(0), -1)) |
Copilot
AI
Oct 1, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inconsistent tensor handling between train_step and forward methods. The train_step method passes x directly to forward(), but forward() expects a flattened tensor. This will cause a shape mismatch error.
Copilot
AI
Oct 1, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing input validation and unclear indexing logic. The code assumes data_dict has the expected structure and that 'image' is indexable, but there's no validation. The [0] indexing is also unclear - consider adding a comment explaining why only the first element is taken.
| image = data_dict["data"]["image"][0] | |
| label = data_dict["data"]["label"] | |
| return (torch.tensor(image, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)) | |
| # Validate input structure | |
| if not isinstance(data_dict, dict): | |
| raise ValueError("Expected data_dict to be a dict.") | |
| if "data" not in data_dict or not isinstance(data_dict["data"], dict): | |
| raise ValueError("Expected data_dict to have a 'data' key containing a dict.") | |
| data = data_dict["data"] | |
| if "image" not in data or "label" not in data: | |
| raise ValueError("Expected 'data' dict to contain 'image' and 'label' keys.") | |
| image = data["image"] | |
| # Assume 'image' is a batch; take the first sample | |
| if not hasattr(image, "__getitem__") or len(image) == 0: | |
| raise ValueError("'image' must be a non-empty indexable object (e.g., list, array).") | |
| image_sample = image[0] | |
| label = data["label"] | |
| return (torch.tensor(image_sample, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)) |
This file was deleted.
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing input validation for data_sample parameter. The code assumes data_sample is not None and has the expected nested structure, but could fail if data_sample is None or missing the 'data'/'image' keys.