@@ -36,6 +36,9 @@ def get_color_matrix(rgb_img, mask):
36
36
if len (np .shape (mask )) != 2 :
37
37
fatal_error ("Input mask is not an gray-scale image." )
38
38
39
+ # convert to float and normalize to work with values between 0-1
40
+ rgb_img = rgb_img .astype (np .float64 )/ 255
41
+
39
42
# create empty color_matrix
40
43
color_matrix = np .zeros ((len (np .unique (mask ))- 1 , 4 ))
41
44
@@ -205,6 +208,8 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
205
208
# split transformation_matrix
206
209
red , green , blue , red2 , green2 , blue2 , red3 , green3 , blue3 = np .split (transformation_matrix , 9 , 1 )
207
210
211
+ # convert img to float to avoid integer overflow, normalize between 0-1
212
+ source_img = source_img .astype (np .float64 )/ 255
208
213
# find linear, square, and cubic values of source_img color channels
209
214
source_b , source_g , source_r = cv2 .split (source_img )
210
215
source_b2 = np .square (source_b )
@@ -226,9 +231,10 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
226
231
bgr = [b , g , r ]
227
232
corrected_img = cv2 .merge (bgr )
228
233
229
- # round corrected_img elements to be within range and of the correct data type
230
- corrected_img = np .rint (corrected_img )
231
- corrected_img [np .where (corrected_img > 255 )] = 255
234
+ # return values of the image to the 0-255 range
235
+ corrected_img = 255 * np .clip (corrected_img , 0 , 1 )
236
+ corrected_img = np .floor (corrected_img )
237
+ # cast back to unsigned int
232
238
corrected_img = corrected_img .astype (np .uint8 )
233
239
234
240
if params .debug == "print" :
@@ -237,6 +243,8 @@ def apply_transformation_matrix(source_img, target_img, transformation_matrix):
237
243
elif params .debug == "plot" :
238
244
# If debug is plot, print a horizontal view of source_img, corrected_img, and target_img to the plotting device
239
245
# plot horizontal comparison of source_img, corrected_img (with rounded elements) and target_img
246
+ # cast source_img back to unsigned int between 0-255 for visualization
247
+ source_img = (255 * source_img ).astype (np .uint8 )
240
248
plot_image (np .hstack ([source_img , corrected_img , target_img ]))
241
249
242
250
# return corrected_img
0 commit comments