Skip to content

Commit dee77c9

Browse files
authored
Merge pull request #829 from danforthcenter/update_color_correction
Update color correction
2 parents 4d69e78 + 4680a34 commit dee77c9

10 files changed

+11
-3
lines changed

plantcv/plantcv/transform/color_correction.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def get_color_matrix(rgb_img, mask):
3636
if len(np.shape(mask)) != 2:
3737
fatal_error("Input mask is not an gray-scale image.")
3838

39+
# convert to float and normalize to work with values between 0-1
40+
rgb_img = rgb_img.astype(np.float64)/255
41+
3942
# create empty color_matrix
4043
color_matrix = np.zeros((len(np.unique(mask))-1, 4))
4144

@@ -205,6 +208,8 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
205208
# split transformation_matrix
206209
red, green, blue, red2, green2, blue2, red3, green3, blue3 = np.split(transformation_matrix, 9, 1)
207210

211+
# convert img to float to avoid integer overflow, normalize between 0-1
212+
source_img = source_img.astype(np.float64)/255
208213
# find linear, square, and cubic values of source_img color channels
209214
source_b, source_g, source_r = cv2.split(source_img)
210215
source_b2 = np.square(source_b)
@@ -226,9 +231,10 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
226231
bgr = [b, g, r]
227232
corrected_img = cv2.merge(bgr)
228233

229-
# round corrected_img elements to be within range and of the correct data type
230-
corrected_img = np.rint(corrected_img)
231-
corrected_img[np.where(corrected_img > 255)] = 255
234+
# return values of the image to the 0-255 range
235+
corrected_img = 255*np.clip(corrected_img, 0, 1)
236+
corrected_img = np.floor(corrected_img)
237+
# cast back to unsigned int
232238
corrected_img = corrected_img.astype(np.uint8)
233239

234240
if params.debug == "print":
@@ -237,6 +243,8 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
237243
elif params.debug == "plot":
238244
# If debug is plot, print a horizontal view of source_img, corrected_img, and target_img to the plotting device
239245
# plot horizontal comparison of source_img, corrected_img (with rounded elements) and target_img
246+
# cast source_img back to unsigned int between 0-255 for visualization
247+
source_img = (255*source_img).astype(np.uint8)
240248
plot_image(np.hstack([source_img, corrected_img, target_img]))
241249

242250
# return corrected_img

tests/data/matrix_b1.npz

20 Bytes
Binary file not shown.

tests/data/matrix_b2.npz

20 Bytes
Binary file not shown.

tests/data/matrix_m1.npz

20 Bytes
Binary file not shown.

tests/data/matrix_m2.npz

20 Bytes
Binary file not shown.

tests/data/source1_matrix.npz

20 Bytes
Binary file not shown.

tests/data/source2_matrix.npz

20 Bytes
Binary file not shown.

tests/data/source_corrected.png

-758 Bytes
Loading

tests/data/target_matrix.npz

20 Bytes
Binary file not shown.

tests/data/transformation_matrix1.npz

20 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)