-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Fix/random_crop validation behavior #21871
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
utsab345
wants to merge
15
commits into
keras-team:master
Choose a base branch
from
utsab345:fix/randomcrop-validation-behavior
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+232
−123
Open
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
2b65f89
Simplify save_img: remove _format, normalize jpg→jpeg, add RGBA→RGB h…
utsab345 2a5bb21
Merge branch 'keras-team:master' into master
utsab345 6e0340b
Merge branch 'keras-team:master' into master
utsab345 7b84f95
Merge branch 'keras-team:master' into master
utsab345 323ca8d
Merge branch 'keras-team:master' into master
utsab345 1fa75a0
Merge branch 'keras-team:master' into master
utsab345 80f4d4e
Merge branch 'keras-team:master' into master
utsab345 7c052bc
Merge remote-tracking branch 'upstream/master'
utsab345 73314e2
Merge branch 'keras-team:master' into master
utsab345 198512f
Merge branch 'keras-team:master' into master
utsab345 3af1c52
Merge branch 'keras-team:master' into master
utsab345 95ae553
Fix JAX flash attention mask tracing
utsab345 3968450
Merge branch 'keras-team:master' into master
utsab345 1577691
Fix RandomCrop validation behavior and >= condition (closes #21868)
utsab345 4b94a3b
Fix RandomCrop validation behavior: center crop vs resize
utsab345 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,16 +50,23 @@ class RandomCrop(BaseImagePreprocessingLayer): | |
| """ | ||
|
|
||
| def __init__( | ||
| self, height, width, seed=None, data_format=None, name=None, **kwargs | ||
| self, | ||
| height, | ||
| width, | ||
| seed=None, | ||
| data_format=None, | ||
| name=None, | ||
| center_crop=True, | ||
| **kwargs, | ||
| ): | ||
| super().__init__(name=name, **kwargs) | ||
| self.height = height | ||
| self.width = width | ||
| self.seed = ( | ||
| seed if seed is not None else backend.random.make_default_seed() | ||
| ) | ||
| self.seed = seed if seed is not None else backend.random.make_default_seed() | ||
| self.generator = SeedGenerator(seed) | ||
| self.data_format = backend.standardize_data_format(data_format) | ||
| # New flag to control validation behavior: center crop if True, otherwise resize. | ||
| self.center_crop = center_crop | ||
|
|
||
| if self.data_format == "channels_first": | ||
| self.height_axis = -2 | ||
|
|
@@ -92,7 +99,7 @@ def get_random_transformation(self, data, training=True, seed=None): | |
| f"height and width. Received: images.shape={input_shape}" | ||
| ) | ||
|
|
||
| if training and input_height > self.height and input_width > self.width: | ||
| if training and input_height >= self.height and input_width >= self.width: | ||
| h_start = self.backend.cast( | ||
| self.backend.random.uniform( | ||
| (), | ||
|
|
@@ -112,70 +119,83 @@ def get_random_transformation(self, data, training=True, seed=None): | |
| "int32", | ||
| ) | ||
| else: | ||
| crop_height = int(float(input_width * self.height) / self.width) | ||
| crop_height = max(min(input_height, crop_height), 1) | ||
| crop_width = int(float(input_height * self.width) / self.height) | ||
| crop_width = max(min(input_width, crop_width), 1) | ||
| h_start = int(float(input_height - crop_height) / 2) | ||
| w_start = int(float(input_width - crop_width) / 2) | ||
| # Validation (training=False) behavior based on self.center_crop flag | ||
| if self.center_crop: | ||
| # Center crop | ||
| h_start = self.backend.cast((input_height - self.height) / 2, "int32") | ||
| w_start = self.backend.cast((input_width - self.width) / 2, "int32") | ||
| else: | ||
| # Direct resize: set offsets to zero; cropping will be bypassed later | ||
| h_start = self.backend.cast(0, "int32") | ||
| w_start = self.backend.cast(0, "int32") | ||
|
Comment on lines
+122
to
+130
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The formulas that were there before were correct; the result was just not used. Previously, it would resize while keeping the aspect ratio and then center crop the extra. But this:

h_start = self.backend.cast((input_height - self.height) / 2, "int32")
w_start = self.backend.cast((input_width - self.width) / 2, "int32")

will give you negative offsets whenever the input dimension is smaller than the target size. |
||
|
|
||
| return h_start, w_start | ||
|
|
||
| def transform_images(self, images, transformation, training=True): | ||
| if training: | ||
| images = self.backend.cast(images, self.compute_dtype) | ||
| crop_box_hstart, crop_box_wstart = transformation | ||
| crop_height = self.height | ||
| crop_width = self.width | ||
|
|
||
| # If we are in validation mode and center_crop is False, skip cropping and directly resize. | ||
| if not training and not self.center_crop: | ||
| # Direct resize to target size | ||
| images = self.backend.image.resize( | ||
| images, | ||
| size=(self.height, self.width), | ||
| data_format=self.data_format, | ||
| ) | ||
| images = self.backend.cast(images, self.compute_dtype) | ||
| crop_box_hstart, crop_box_wstart = transformation | ||
| crop_height = self.height | ||
| crop_width = self.width | ||
| return images | ||
|
|
||
| if self.data_format == "channels_last": | ||
| if len(images.shape) == 4: | ||
| images = images[ | ||
| :, | ||
| crop_box_hstart : crop_box_hstart + crop_height, | ||
| crop_box_wstart : crop_box_wstart + crop_width, | ||
| :, | ||
| ] | ||
| else: | ||
| images = images[ | ||
| crop_box_hstart : crop_box_hstart + crop_height, | ||
| crop_box_wstart : crop_box_wstart + crop_width, | ||
| :, | ||
| ] | ||
| if self.data_format == "channels_last": | ||
| if len(images.shape) == 4: | ||
| images = images[ | ||
| :, | ||
| crop_box_hstart : crop_box_hstart + crop_height, | ||
| crop_box_wstart : crop_box_wstart + crop_width, | ||
| :, | ||
| ] | ||
| else: | ||
| if len(images.shape) == 4: | ||
| images = images[ | ||
| :, | ||
| :, | ||
| crop_box_hstart : crop_box_hstart + crop_height, | ||
| crop_box_wstart : crop_box_wstart + crop_width, | ||
| ] | ||
| else: | ||
| images = images[ | ||
| :, | ||
| crop_box_hstart : crop_box_hstart + crop_height, | ||
| crop_box_wstart : crop_box_wstart + crop_width, | ||
| ] | ||
|
|
||
| shape = self.backend.shape(images) | ||
| new_height = shape[self.height_axis] | ||
| new_width = shape[self.width_axis] | ||
| if ( | ||
| not isinstance(new_height, int) | ||
| or not isinstance(new_width, int) | ||
| or new_height != self.height | ||
| or new_width != self.width | ||
| ): | ||
| # Resize images if size mismatch or | ||
| # if size mismatch cannot be determined | ||
| # (in the case of a TF dynamic shape). | ||
| images = self.backend.image.resize( | ||
| images, | ||
| size=(self.height, self.width), | ||
| data_format=self.data_format, | ||
| ) | ||
| # Resize may have upcasted the outputs | ||
| images = self.backend.cast(images, self.compute_dtype) | ||
| images = images[ | ||
| crop_box_hstart : crop_box_hstart + crop_height, | ||
| crop_box_wstart : crop_box_wstart + crop_width, | ||
| :, | ||
| ] | ||
| else: | ||
| if len(images.shape) == 4: | ||
| images = images[ | ||
| :, | ||
| :, | ||
| crop_box_hstart : crop_box_hstart + crop_height, | ||
| crop_box_wstart : crop_box_wstart + crop_width, | ||
| ] | ||
| else: | ||
| images = images[ | ||
| :, | ||
| crop_box_hstart : crop_box_hstart + crop_height, | ||
| crop_box_wstart : crop_box_wstart + crop_width, | ||
| ] | ||
| # Resize if the cropped image doesn't match target size | ||
| shape = self.backend.shape(images) | ||
| new_height = shape[self.height_axis] | ||
| new_width = shape[self.width_axis] | ||
| if ( | ||
| not isinstance(new_height, int) | ||
| or not isinstance(new_width, int) | ||
| or new_height != self.height | ||
| or new_width != self.width | ||
| ): | ||
| # Resize images if size mismatch or | ||
| # if size mismatch cannot be determined | ||
| # (in the case of a TF dynamic shape). | ||
| images = self.backend.image.resize( | ||
| images, | ||
| size=(self.height, self.width), | ||
| data_format=self.data_format, | ||
| ) | ||
| # Resize may have upcasted the outputs | ||
| images = self.backend.cast(images, self.compute_dtype) | ||
| return images | ||
|
|
||
| def transform_labels(self, labels, transformation, training=True): | ||
|
|
@@ -199,58 +219,57 @@ def transform_bounding_boxes( | |
| } | ||
| """ | ||
|
|
||
| if training: | ||
| h_start, w_start = transformation | ||
| if not self.backend.is_tensor(bounding_boxes["boxes"]): | ||
| bounding_boxes = densify_bounding_boxes( | ||
| bounding_boxes, backend=self.backend | ||
| ) | ||
| boxes = bounding_boxes["boxes"] | ||
| # Convert to a standard xyxy as operations are done xyxy by default. | ||
| boxes = convert_format( | ||
| boxes=boxes, | ||
| source=self.bounding_box_format, | ||
| target="xyxy", | ||
| height=self.height, | ||
| width=self.width, | ||
| # Apply transformation for both training and validation | ||
| h_start, w_start = transformation | ||
| if not self.backend.is_tensor(bounding_boxes["boxes"]): | ||
| bounding_boxes = densify_bounding_boxes( | ||
| bounding_boxes, backend=self.backend | ||
| ) | ||
| h_start = self.backend.cast(h_start, boxes.dtype) | ||
| w_start = self.backend.cast(w_start, boxes.dtype) | ||
| if len(self.backend.shape(boxes)) == 3: | ||
| boxes = self.backend.numpy.stack( | ||
| [ | ||
| self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0), | ||
| ], | ||
| axis=-1, | ||
| ) | ||
| else: | ||
| boxes = self.backend.numpy.stack( | ||
| [ | ||
| self.backend.numpy.maximum(boxes[:, 0] - h_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, 1] - w_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, 2] - h_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, 3] - w_start, 0), | ||
| ], | ||
| axis=-1, | ||
| ) | ||
|
|
||
| # Convert to user defined bounding box format | ||
| boxes = convert_format( | ||
| boxes=boxes, | ||
| source="xyxy", | ||
| target=self.bounding_box_format, | ||
| height=self.height, | ||
| width=self.width, | ||
| boxes = bounding_boxes["boxes"] | ||
| # Convert to a standard xyxy as operations are done xyxy by default. | ||
| boxes = convert_format( | ||
| boxes=boxes, | ||
| source=self.bounding_box_format, | ||
| target="xyxy", | ||
| height=self.height, | ||
| width=self.width, | ||
| ) | ||
| h_start = self.backend.cast(h_start, boxes.dtype) | ||
| w_start = self.backend.cast(w_start, boxes.dtype) | ||
| if len(self.backend.shape(boxes)) == 3: | ||
| boxes = self.backend.numpy.stack( | ||
| [ | ||
| self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0), | ||
| ], | ||
| axis=-1, | ||
| ) | ||
| else: | ||
| boxes = self.backend.numpy.stack( | ||
| [ | ||
| self.backend.numpy.maximum(boxes[:, 0] - h_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, 1] - w_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, 2] - h_start, 0), | ||
| self.backend.numpy.maximum(boxes[:, 3] - w_start, 0), | ||
| ], | ||
| axis=-1, | ||
| ) | ||
|
|
||
| return { | ||
| "boxes": boxes, | ||
| "labels": bounding_boxes["labels"], | ||
| } | ||
| return bounding_boxes | ||
| # Convert to user defined bounding box format | ||
| boxes = convert_format( | ||
| boxes=boxes, | ||
| source="xyxy", | ||
| target=self.bounding_box_format, | ||
| height=self.height, | ||
| width=self.width, | ||
| ) | ||
|
|
||
| return { | ||
| "boxes": boxes, | ||
| "labels": bounding_boxes["labels"], | ||
| } | ||
|
|
||
| def transform_segmentation_masks( | ||
| self, segmentation_masks, transformation, training=True | ||
|
|
@@ -271,6 +290,7 @@ def get_config(self): | |
| "width": self.width, | ||
| "seed": self.seed, | ||
| "data_format": self.data_format, | ||
| "center_crop": self.center_crop, | ||
| } | ||
| ) | ||
| return config | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please revert this file. Unrelated to random_crop.