Skip to content

Commit f84dda5

Browse files
committed
Report sequences with uni-valued labels when training
1 parent 422fe7d commit f84dda5

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

gecco/crf/__init__.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,11 @@ def fit(
355355
# extract features and labels
356356
feats: List[Dict[str, bool]] = extract_features(sequence)
357357
labels: List[str] = extract_labels(sequence)
358+
if all(label == "0" for label in labels):
359+
raise ValueError(f"only negative labels found in sequence {sequence[0].source.id!r}")
360+
elif all(label == "1" for label in labels):
361+
raise ValueError(f"only positive labels found in sequence {sequence[0].source.id!r}")
362+
358363
# check we have as many observations as we have labels
359364
if len(feats) != len(labels):
360365
raise ValueError("different number of features and labels found, something is wrong")
@@ -366,12 +371,6 @@ def fit(
366371
training_features.append(feats[win])
367372
training_labels.append(labels[win])
368373

369-
# check labels
370-
if all(label == "1" for y in training_labels for label in y):
371-
raise ValueError("only positives labels found, something is wrong.")
372-
elif all(label == "0" for y in training_labels for label in y):
373-
raise ValueError("only negative labels found, something is wrong.")
374-
375374
# fit the model
376375
self.model = model = sklearn_crfsuite.CRF(**self._options)
377376
model.fit(training_features, training_labels)

0 commit comments

Comments
 (0)