# data_module.py
import random
from pathlib import Path
from typing import Union

import numpy as np
import pytorch_lightning as pl
import soundfile as sf
import torch
import torch.utils.data


class Dataset(torch.utils.data.Dataset):
    """Paired src/tgt waveform dataset; files in the two directories are matched by stem."""

    def __init__(self, data_src_dir: Union[str, Path], data_tgt_dir: Union[str, Path], is_train: bool, cut_len: int = 16000 * 4):
        self.is_train = is_train
        self.cut_len = cut_len
        self.data_src_dir = Path(data_src_dir)
        self.data_tgt_dir = Path(data_tgt_dir)
        # every src wav is expected to have a same-named tgt wav
        self.wav_names = [p.stem for p in self.data_src_dir.glob('*.wav')]

    def normalize_src_tgt(self, src, tgt, eps=1e-8):
        # use the reference channel (first channel) of src to normalize all channels,
        # so the relative scale between src and tgt is preserved
        # src: [num_mics, T]; tgt: [num_mics, T] or [T]
        norm_factor = src[0, :].std() + eps
        src = src / norm_factor
        tgt = tgt / norm_factor
        return src, tgt

    def load_wav(self, path):
        # load as float32 and return channel-first: [num_channels, T]
        wav, sr = sf.read(path, dtype='float32')
        assert sr == 16000, f'expected 16 kHz audio, got {sr} Hz'
        if wav.ndim == 1:
            return wav[None, :]  # mono -> [1, T]
        elif wav.ndim == 2:
            return wav.T  # soundfile returns [T, num_channels]
        else:
            raise ValueError(f'wav ndim > 2: {wav.ndim}')

    def __len__(self):
        return len(self.wav_names)

    def __getitem__(self, idx):
        name = self.wav_names[idx]
        src = self.load_wav(self.data_src_dir / (name + '.wav'))
        tgt = self.load_wav(self.data_tgt_dir / (name + '.wav'))
        length = src.shape[-1]
        assert length == tgt.shape[-1], 'src/tgt length mismatch'
        if not self.is_train:
            # evaluation: keep the full-length signal, no channel shuffling
            tgt = tgt[0]  # [length], reference channel
            src, tgt = self.normalize_src_tgt(src, tgt)
            return src, tgt, length, name
        # training: randomly permute the channel order
        indices = torch.randperm(src.shape[0]).numpy()
        src = src[indices]
        tgt = tgt[indices]
        tgt = tgt[0]  # [length], reference channel after permutation
        if length < self.cut_len:
            # too short: tile the signal up to cut_len
            src = np.pad(src, ((0, 0), (0, self.cut_len - length)), mode='wrap')
            tgt = np.pad(tgt, (0, self.cut_len - length), mode='wrap')
        else:
            # long enough: randomly cut a cut_len segment
            wav_start = random.randint(0, length - self.cut_len)
            src = src[:, wav_start: wav_start + self.cut_len]
            tgt = tgt[wav_start: wav_start + self.cut_len]
        src, tgt = self.normalize_src_tgt(src, tgt)
        length = self.cut_len
        return src, tgt, length, name


class DataModule(pl.LightningDataModule):
    """Lightning wrapper that builds the train/val/test Dataset instances."""

    def __init__(
        self,
        train_src_dir,
        train_tgt_dir,
        val_src_dir,
        val_tgt_dir,
        test_src_dir,
        test_tgt_dir,
        batch_size,
        cut_len,
        num_workers,
    ):
        super().__init__()
        self.train_src_dir = train_src_dir
        self.train_tgt_dir = train_tgt_dir
        self.val_src_dir = val_src_dir
        self.val_tgt_dir = val_tgt_dir
        self.test_src_dir = test_src_dir
        self.test_tgt_dir = test_tgt_dir
        self.batch_size = batch_size
        self.cut_len = cut_len
        self.num_workers = num_workers

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.train_dataset = Dataset(self.train_src_dir, self.train_tgt_dir, is_train=True, cut_len=self.cut_len)
            self.val_dataset = Dataset(self.val_src_dir, self.val_tgt_dir, is_train=False, cut_len=self.cut_len)
        if stage == 'test' or stage is None:
            self.test_dataset = Dataset(self.test_src_dir, self.test_tgt_dir, is_train=False, cut_len=self.cut_len)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        # eval clips keep their original (variable) length, so batch_size must be 1
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=1, shuffle=False)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=1, shuffle=False)
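

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original module).
    # The directory paths below are hypothetical placeholders: each should
    # contain 16 kHz .wav files, with src/tgt pairs sharing the same stem.
    dm = DataModule(
        train_src_dir='data/train/noisy',  # hypothetical path
        train_tgt_dir='data/train/clean',  # hypothetical path
        val_src_dir='data/val/noisy',      # hypothetical path
        val_tgt_dir='data/val/clean',      # hypothetical path
        test_src_dir='data/test/noisy',    # hypothetical path
        test_tgt_dir='data/test/clean',    # hypothetical path
        batch_size=4,
        cut_len=16000 * 4,  # 4-second training segments at 16 kHz
        num_workers=0,
    )
    dm.setup('fit')
    src, tgt, length, name = next(iter(dm.train_dataloader()))
    # default_collate stacks the numpy arrays into tensors:
    # src: [4, num_mics, 64000], tgt: [4, 64000]
    print(src.shape, tgt.shape, length, name)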