2 changes: 1 addition & 1 deletion deeppavlov/core/data/utils.py
@@ -465,7 +465,7 @@ def flatten_str_batch(batch: Union[str, Iterable]) -> Union[list, chain]:
['a', 'b', 'c', 'd']

"""
- if isinstance(batch, str):
+ if isinstance(batch, str) or isinstance(batch, int) or isinstance(batch, float):
return [batch]
else:
return chain(*[flatten_str_batch(sample) for sample in batch])
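What this change enables, as a standalone sketch (the nested batch below and the collapsed isinstance tuple are illustrative, not part of the PR): numeric leaves are now wrapped in a one-element list just like strings, so label batches containing int or float values flatten instead of raising TypeError.

from itertools import chain
from typing import Iterable, Union

def flatten_str_batch(batch: Union[str, int, float, Iterable]) -> Union[list, chain]:
    # Scalars (str, int, float) become one-element lists; iterables are flattened recursively.
    if isinstance(batch, (str, int, float)):
        return [batch]
    return chain(*[flatten_str_batch(sample) for sample in batch])

print(list(flatten_str_batch([['a', 1], [2.5, 'd']])))  # ['a', 1, 2.5, 'd']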
5 changes: 4 additions & 1 deletion deeppavlov/core/trainers/nn_trainer.py
@@ -20,6 +20,8 @@
from pathlib import Path
from typing import List, Tuple, Union, Optional, Iterable

+ from tqdm import tqdm
+
from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.common.registry import register
from deeppavlov.core.data.data_learning_iterator import DataLearningIterator
@@ -279,7 +281,8 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None:
while True:
impatient = False
self._send_event(event_name='before_train')
- for x, y_true in iterator.gen_batches(self.batch_size, data_type='train'):
+ log.info('Model training started')
+ for x, y_true in tqdm(iterator.gen_batches(self.batch_size, data_type='train')):
self.last_result = self._chainer.train_on_batch(x, y_true)
if self.last_result is None:
self.last_result = {}
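In isolation, the wrapped loop behaves roughly as below (gen_batches here is a made-up stand-in for DataLearningIterator.gen_batches): tqdm consumes the batch generator and prints a progress bar; since a generator has no length, the bar shows only the batch count and rate unless a total is supplied.

from tqdm import tqdm

def gen_batches(batch_size, n_samples=100):
    # Made-up stand-in for DataLearningIterator.gen_batches: yields (x, y_true) pairs.
    for start in range(0, n_samples, batch_size):
        batch = list(range(start, start + batch_size))
        yield batch, batch

for x, y_true in tqdm(gen_batches(10)):
    pass  # self._chainer.train_on_batch(x, y_true) runs here in the real trainer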
42 changes: 25 additions & 17 deletions deeppavlov/dataset_readers/basic_classification_reader.py
@@ -35,6 +35,7 @@ class BasicClassificationDatasetReader(DatasetReader):
@overrides
def read(self, data_path: str, url: str = None,
format: str = "csv", class_sep: str = None,
+ float_labels: bool = False,
*args, **kwargs) -> dict:
"""
Read dataset from data_path directory.
@@ -48,6 +49,8 @@ def read(self, data_path: str, url: str = None,
format: extension of files. Set of Values: ``"csv", "json"``
class_sep: string separator of labels in column with labels
sep (str): delimiter for ``"csv"`` files. Default: None -> only one class per sample
+ float_labels (bool): if True and class_sep is not None, all class labels are converted to float
+ quotechar (str): character used to quote fields in ``"csv"`` files
header (int): row number to use as the column names
names (array): list of column names to use
orient (str): indication of expected JSON string format
@@ -80,7 +83,7 @@ def read(self, data_path: str, url: str = None,
file = Path(data_path).joinpath(file_name)
if file.exists():
if format == 'csv':
- keys = ('sep', 'header', 'names')
+ keys = ('sep', 'header', 'names', 'quotechar')
options = {k: kwargs[k] for k in keys if k in kwargs}
df = pd.read_csv(file, **options)
elif format == 'json':
@@ … @@

x = kwargs.get("x", "text")
y = kwargs.get('y', 'labels')
- if isinstance(x, list):
-     if class_sep is None:
-         # each sample is a tuple ("text", "label")
-         data[data_type] = [([row[x_] for x_ in x], str(row[y]))
-                            for _, row in df.iterrows()]
-     else:
-         # each sample is a tuple ("text", ["label", "label", ...])
-         data[data_type] = [([row[x_] for x_ in x], str(row[y]).split(class_sep))
-                            for _, row in df.iterrows()]
- else:
-     if class_sep is None:
-         # each sample is a tuple ("text", "label")
-         data[data_type] = [(row[x], str(row[y])) for _, row in df.iterrows()]
-     else:
-         # each sample is a tuple ("text", ["label", "label", ...])
-         data[data_type] = [(row[x], str(row[y]).split(class_sep)) for _, row in df.iterrows()]
+ data[data_type] = []
+ i = 0
+ prev_n_classes = 0  # used to detect rows with an inconsistent number of classes
+ for _, row in df.iterrows():
+     if isinstance(x, list):
+         sample = [row[x_] for x_ in x]
+     else:
+         sample = row[x]
+     label = str(row[y])
+     if class_sep:
+         label = str(row[y]).split(class_sep)
+         if prev_n_classes == 0:
+             prev_n_classes = len(label)
+         assert len(label) == prev_n_classes, f"Wrong number of classes at row {i}"
+         if float_labels:
+             label = [float(k) for k in label]
+     if sample == sample and label == label:  # not NaN
+         data[data_type].append((sample, label))
+     else:
+         log.warning(f'Skipping NaN value in file {file} at row {i}')
+     i += 1
else:
log.warning("Cannot find {} file".format(file))

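A hypothetical end-to-end use of the new reader options (the directory, file and column names, and the '|' separator are invented, and the call assumes the reader's usual convention of looking for train.csv and friends under data_path): sep and quotechar are forwarded to pandas.read_csv, class_sep splits the label column into a list, and float_labels converts each class to float.

from deeppavlov.dataset_readers.basic_classification_reader import BasicClassificationDatasetReader

# ./my_data/train.csv (hypothetical contents):
#   text,labels
#   "cheap, effective",1.0|0.0
#   "complete waste",0.0|1.0
data = BasicClassificationDatasetReader().read(
    data_path='./my_data',
    format='csv',
    sep=',',
    quotechar='"',       # forwarded to pandas.read_csv, so the comma inside quotes survives
    class_sep='|',       # split the labels column into a list of classes
    float_labels=True,   # convert each class label to float
    x='text',
    y='labels',
)
# data['train'][0] would then be ('cheap, effective', [1.0, 0.0])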