Skip to content

Commit 432fcd2

Browse files
authored
Merge pull request #1648 from danforthcenter/fix-mask-kmeans-output-type
Fix mask kmeans output type
2 parents abe7927 + 5f97cee commit 432fcd2

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

docs/kmeans_classifier.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
The first function (`pcv.predict_kmeans`) takes a target image and uses a trained kmeans model produced by [`pcv.learn.train_kmeans`](train_kmeans.md) to classify regions of the target image by the trained clusters. The second function (`pcv.mask_kmeans`) takes a list of clusters and produces the combined mask from clusters of interest. The target and training images may be in grayscale or RGB image format.
44

5-
**plantcv.kmeans_classifier.predict_kmeans**(img, model_path="./kmeansout.fit", patch_size=10)
5+
**plantcv.predict_kmeans**(img, model_path="./kmeansout.fit", patch_size=10)
66

77
**outputs** An image with regions colored and labeled according to cluster assignment
88

@@ -18,7 +18,7 @@ The first function (`pcv.predict_kmeans`) takes a target image and uses a traine
1818
- **Example use below**
1919

2020

21-
**plantcv.kmeans_classifier.mask_kmeans**(labeled_img, k, cat_list=None)
21+
**plantcv.mask_kmeans**(labeled_img, k, cat_list=None)
2222

2323
**outputs** Either a combined mask of the requestedlist of clusters or a dictionary of each cluster as a separate mask with keys corresponding to the cluster number
2424

plantcv/plantcv/kmeans_classifier.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def predict_kmeans(img, model_path="./kmeansout.fit", patch_size=10):
4848
reshape_params = [[h - 2*mg + 1, w - 2*mg + 1], [h - 2*mg, w - 2*mg]]
4949
# Takes care of even vs odd numbered patch size reshaping
5050
labeled = train_labels.reshape(reshape_params[patch_size % 2][0], reshape_params[patch_size % 2][1])
51+
labeled = labeled.astype("uint8")
5152
_debug(visual=labeled, filename=os.path.join(params.debug_outdir, "_labeled_img.png"))
5253
return labeled
5354

@@ -70,7 +71,7 @@ def mask_kmeans(labeled_img, k, cat_list=None):
7071
mask_dict = {}
7172
L = [*range(k)]
7273
for i in L:
73-
mask_light = np.where(labeled_img == i, 255, 0)
74+
mask_light = np.where(labeled_img == i, 255, 0).astype("uint8")
7475
_debug(visual=mask_light, filename=os.path.join(params.debug_outdir, "_kmeans_mask_"+str(i)+".png"))
7576
mask_dict[str(i)] = mask_light
7677
return mask_dict
@@ -80,9 +81,9 @@ def mask_kmeans(labeled_img, k, cat_list=None):
8081
params.debug = None
8182
for idx, i in enumerate(cat_list):
8283
if idx == 0:
83-
mask_light = np.where(labeled_img == i, 255, 0)
84+
mask_light = np.where(labeled_img == i, 255, 0).astype("uint8")
8485
else:
85-
mask_light = pcv.logical_or(mask_light, np.where(labeled_img == i, 255, 0))
86+
mask_light = pcv.logical_or(mask_light, np.where(labeled_img == i, 255, 0).astype("uint8"))
8687
params.debug = debug
8788
_debug(visual=mask_light, filename=os.path.join(params.debug_outdir, "_kmeans_combined_mask.png"))
8889
return mask_light

0 commit comments

Comments
 (0)