# util.py
import logging
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch


def show_ts_image(ts_data, title='Time Series Data'):
    """Plot a 1-D time series (tensor or array) against its sample index."""
    ts_list = ts_data.squeeze().tolist()  # drop singleton dims, convert to a plain list
    time_list = range(len(ts_list))       # implicit time axis: one step per sample
    plt.plot(time_list, ts_list)
    plt.xlabel('Time')
    plt.ylabel('Value')
    plt.title(title)
    plt.show()
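

# A minimal usage sketch for show_ts_image. The _demo_* name and the synthetic
# sine-wave input are illustrative additions, not part of the original module.
def _demo_show_ts_image():
    t = torch.linspace(0, 2 * np.pi, steps=200)  # 200 evenly spaced time points
    show_ts_image(torch.sin(t), title='Sine wave')  # any squeezable 1-D tensor works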


def show_freq_image(freq_data, title='Frequency Spectrum'):
    """Plot a magnitude spectrum, assuming the data was sampled at 1000 Hz."""
    N = len(freq_data)                 # length of the frequency data
    sample_rate = 1000                 # assume a sample rate of 1000 Hz
    freq_resolution = sample_rate / N  # Hz per frequency bin
    freq_list = np.arange(N) * freq_resolution  # bin frequencies; avoids a float off-by-one from arange with a step
    plt.plot(freq_list, freq_data)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude')
    plt.title(title)
    plt.show()
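

# A minimal usage sketch for show_freq_image, assuming the input is the magnitude
# of an FFT. Note the function hard-codes a 1000 Hz sample rate, so the 60 Hz tone
# below is only labeled correctly because it is sampled at that rate.
def _demo_show_freq_image():
    sample_rate = 1000                        # must match the rate assumed inside show_freq_image
    t = np.arange(sample_rate) / sample_rate  # 1 second of samples
    signal = np.sin(2 * np.pi * 60 * t)       # 60 Hz sine wave
    magnitude = np.abs(np.fft.fft(signal))    # full (two-sided) spectrum, length == len(signal)
    show_freq_image(magnitude, title='60 Hz tone')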


def _logger(logger_name, level=logging.DEBUG):
    """
    Return a custom logger with the given name and level.

    Note: the logger name is also used as the log file path.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(level)
    # format_string = ("%(asctime)s — %(name)s — %(levelname)s — %(funcName)s:"
    #                  "%(lineno)d — %(message)s")
    format_string = "%(message)s"
    log_format = logging.Formatter(format_string)
    # Create and add the console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_format)
    logger.addHandler(console_handler)
    # Create and add the file handler (the logger name doubles as the file path)
    file_handler = logging.FileHandler(logger_name, mode='a')
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)
    return logger
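

# A usage sketch for _logger; the 'experiment.log' name is illustrative. Because
# the logger name doubles as the log file path, it is both the logger name and
# the output file here. Be aware that repeated calls with the same name attach
# duplicate handlers, so messages would then be emitted more than once.
def _demo_logger():
    logger = _logger('experiment.log', level=logging.INFO)
    logger.info('Training started')  # printed to stdout and appended to ./experiment.log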


def collate_fn(data, max_len=None):
    """Build mini-batch tensors from a list of (X, y) tuples, zero-padding
    variable-length sequences to a common length.
    Args:
        data: len(batch_size) list of tuples (X, y).
            - X: torch tensor of shape (seq_length, feat_dim); variable seq_length.
            - y: torch tensor of shape (num_labels,): class indices or numerical targets
                (for classification or regression, respectively). num_labels > 1 for multi-task models
        max_len: global fixed sequence length. Used for architectures requiring fixed length input,
            where the batch length cannot vary dynamically. Longer sequences are clipped, shorter are padded with 0s
    Returns:
        X: (batch_size, padded_length, feat_dim) torch tensor of features (input)
        targets: (batch_size, num_labels) torch tensor of labels (output)
        padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep vector at this position, 0 means padding
    """
    batch_size = len(data)
    features, labels = zip(*data)

    # Stack and pad features (convert 2D to 3D tensors, i.e. add batch dimension)
    lengths = [X.shape[0] for X in features]  # original sequence length for each time series
    if max_len is None:
        max_len = max(lengths)
    X = torch.zeros(batch_size, max_len, features[0].shape[-1])  # (batch_size, padded_length, feat_dim)
    for i in range(batch_size):
        end = min(lengths[i], max_len)
        X[i, :end, :] = features[i][:end, :]

    targets = torch.stack(labels, dim=0)  # (batch_size, num_labels)

    padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.int16),
                                 max_len=max_len)  # (batch_size, padded_length) boolean tensor, "1" means keep

    return X, targets, padding_masks
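

# A minimal sketch of collate_fn on a toy batch of variable-length sequences;
# the shapes and labels below are illustrative. In practice this function is
# typically passed as the `collate_fn` argument of a torch.utils.data.DataLoader.
def _demo_collate_fn():
    feat_dim = 3
    batch = [
        (torch.randn(5, feat_dim), torch.tensor([0])),  # sequence of length 5
        (torch.randn(8, feat_dim), torch.tensor([1])),  # sequence of length 8
    ]
    X, targets, padding_masks = collate_fn(batch)
    assert X.shape == (2, 8, feat_dim)      # padded to the longest sequence in the batch
    assert targets.shape == (2, 1)          # one label per sample
    assert padding_masks.shape == (2, 8)    # True at real time steps, False at padding
    assert padding_masks[0, 5:].sum() == 0  # positions past length 5 are padding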


def padding_mask(lengths, max_len=None):
    """
    Used to mask padded positions: creates a (batch_size, max_len) boolean mask from a tensor of sequence lengths,
    where 1 means keep element at this position (time step)
    """
    batch_size = lengths.numel()
    max_len = max_len or lengths.max()  # trick works because of overloading of 'or' operator for non-boolean types
    return (torch.arange(0, max_len, device=lengths.device)
            .type_as(lengths)
            .repeat(batch_size, 1)
            .lt(lengths.unsqueeze(1)))
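

# A worked example of padding_mask: for lengths [2, 4] and max_len 5, the mask
# marks the first `length` positions of each row as True.
def _demo_padding_mask():
    mask = padding_mask(torch.tensor([2, 4]), max_len=5)
    # tensor([[ True,  True, False, False, False],
    #         [ True,  True,  True,  True, False]])
    print(mask)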


def set_requires_grad(model, dict_, requires_grad=True):
    """Enable or disable gradients for the model parameters whose names appear in dict_."""
    for name, param in model.named_parameters():
        if name in dict_:
            param.requires_grad = requires_grad
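

# A usage sketch for set_requires_grad: freezing a named subset of parameters,
# e.g. before fine-tuning. The two-layer model here is illustrative only.
def _demo_set_requires_grad():
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    frozen = {'0.weight', '0.bias'}  # parameter names of the first layer
    set_requires_grad(model, frozen, requires_grad=False)
    assert not model[0].weight.requires_grad  # first layer is frozen
    assert model[1].weight.requires_grad      # second layer still trains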