Skip to content

Commit b7454c2

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
Update fastText predict_proba output to conform to expected format
1 parent 881a66a commit b7454c2

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

examples/text-classification/fasttext/fasttext.ipynb

+5-2
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@
625625
"\n",
626626
"import fasttext\n",
627627
"import pickle\n",
628-
"\n",
628+
"import numpy as np\n",
629629
"\n",
630630
"from pathlib import Path\n",
631631
"from typing import List\n",
@@ -646,7 +646,10 @@
646646
" def predict_proba(self, input_data_df: pd.DataFrame):\n",
647647
" \"\"\"Makes predictions with the model. Returns the class probabilities.\"\"\"\n",
648648
" text_column = input_data_df.columns[0]\n",
649-
" return input_data_df[text_column].apply(self._predict_row)\n",
649+
" \n",
650+
" preds = input_data_df[text_column].apply(self._predict_row)\n",
651+
" \n",
652+
" return np.stack(preds.values)\n",
650653
"\n",
651654
" def _predict_row(self, text: str) -> List[float]:\n",
652655
" text = text.replace(\"\\n\",\" \")\n",

0 commit comments

Comments
 (0)