Commit e3bf415

Add files via upload
1 parent 598bc8b commit e3bf415

97 files changed: +1344463 −0 lines

README.md (+24)

@@ -0,0 +1,24 @@

## StreamPrompt: Learnable Prompt-guided Data Selection for Efficient Stream Learning

## Setup

* Install miniconda
* `conda env create -f environment.yml`
* `conda activate sl`
* Install the fastmoe library: https://github.com/laekov/fastmoe/blob/master/doc/installation-guide.md

## Datasets

* Create a folder `data/`
* **Clear10**, **Clear100**: retrieve from https://clear-benchmark.github.io/
* **CORe50**: `sh core50.sh`

## Training

All commands should be run from the project root directory. **The scripts are set up for 1 GPU** but can be modified for your hardware.

```bash
sh experiments/clear10.sh
sh experiments/imagenet-r.sh
sh experiments/domainnet.sh
```

buffer/__init__.py

Whitespace-only changes.

buffer/aser_update.py (+115)
@@ -0,0 +1,115 @@
import torch
from collections import namedtuple
from buffer.reservoir_update import Reservoir_update
from buffer.buffer_utils import ClassBalancedRandomSampling, random_retrieve, n_classes, nonzero_indices, maybe_cuda
from buffer.aser_utils import compute_knn_sv, add_minority_class_input


class ASER_update(object):
    def __init__(self, config, **kwargs):
        super().__init__()
        Args = namedtuple('Args', ['mem_size', 'num_tasks', 'data'])
        params = Args(mem_size=config.mem_size, num_tasks=None, data=config.dataset)
        # params = Args(mem_size=config.mem_size, num_tasks=config.num_tasks, data=config.dataset)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.k = 3
        self.mem_size = params.mem_size
        self.num_tasks = params.num_tasks
        self.out_dim = n_classes[params.data]
        self.n_smp_cls = int(1.5)  # samples drawn per class for the evaluation set; int(1.5) == 1
        self.n_total_smp = int(1.5 * self.out_dim)  # candidate samples drawn from the buffer
        self.reservoir_update = Reservoir_update(params)
        ClassBalancedRandomSampling.class_index_cache = None

    def update(self, buffer, x, y, **kwargs):
        model = buffer.model

        place_left = self.mem_size - buffer.current_index

        # If buffer is not filled, use available space to store whole or part of batch
        if place_left:
            x_fit = x[:place_left]
            y_fit = y[:place_left]

            ind = torch.arange(start=buffer.current_index, end=buffer.current_index + x_fit.size(0), device=self.device)
            ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim,
                                                     new_y=y_fit, ind=ind, device=self.device)
            self.reservoir_update.update(buffer, x_fit, y_fit)

        # If buffer is filled, update buffer by SV
        if buffer.current_index == self.mem_size:
            # the part of the batch that was not stored by the reservoir step above
            cur_x, cur_y = x[place_left:], y[place_left:]
            self._update_by_knn_sv(model, buffer, cur_x, cur_y)

    def _update_by_knn_sv(self, model, buffer, cur_x, cur_y):
        """
        Update the buffer by KNN-SV: buffered instances with the smallest SV
        are overwritten in place by current inputs with higher SV.
        Args:
            model (object): neural network.
            buffer (object): buffer object.
            cur_x (tensor): current input data tensor.
            cur_y (tensor): current input label tensor.
        """
        cur_x = maybe_cuda(cur_x)
        cur_y = maybe_cuda(cur_y)

        # Find minority class samples from current input batch
        minority_batch_x, minority_batch_y = add_minority_class_input(cur_x, cur_y, self.mem_size, self.out_dim)

        # Evaluation set
        eval_x, eval_y, eval_indices = \
            ClassBalancedRandomSampling.sample(buffer.buffer_img, buffer.buffer_label, self.n_smp_cls,
                                               device=self.device)

        # Concatenate minority class samples from current input batch to evaluation set
        eval_x = torch.cat((eval_x, minority_batch_x))
        eval_y = torch.cat((eval_y, minority_batch_y))

        # Candidate set, excluding indices already used for the evaluation set
        cand_excl_indices = set(eval_indices.tolist())
        cand_x, cand_y, cand_ind = random_retrieve(buffer, self.n_total_smp, cand_excl_indices, return_indices=True)

        # Concatenate current input batch to candidate set
        cand_x = torch.cat((cand_x, cur_x))
        cand_y = torch.cat((cand_y, cur_y))

        sv_matrix = compute_knn_sv(model, eval_x, eval_y, cand_x, cand_y, self.k, device=self.device)
        sv = sv_matrix.sum(0)

        n_cur = cur_x.size(0)
        n_cand = cand_x.size(0)

        # Number of previously buffered instances in candidate set
        n_cand_buf = n_cand - n_cur

        sv_arg_sort = sv.argsort(descending=True)

        # Divide SV array into two segments
        # - large: candidate args to be retained; small: candidate args to be discarded
        sv_arg_large = sv_arg_sort[:n_cand_buf]
        sv_arg_small = sv_arg_sort[n_cand_buf:]

        # Extract args relevant to the replacement operation:
        # current data instances in the 'large' segment are added to the buffer,
        # while buffered instances in the 'small' segment are discarded from it.
        # Replacement happens between these two sets; by construction they have
        # equal size, so the overwrite below is one-to-one.
        # Retrieve original indices from candidate args
        ind_cur = sv_arg_large[nonzero_indices(sv_arg_large >= n_cand_buf)] - n_cand_buf
        arg_buffer = sv_arg_small[nonzero_indices(sv_arg_small < n_cand_buf)]
        ind_buffer = cand_ind[arg_buffer]

        buffer.n_seen_so_far += n_cur

        # Perform the overwrite
        y_upt = cur_y[ind_cur]
        x_upt = cur_x[ind_cur]
        ClassBalancedRandomSampling.update_cache(buffer.buffer_label, self.out_dim,
                                                 new_y=y_upt, ind=ind_buffer, device=self.device)
        buffer.buffer_img[ind_buffer] = x_upt
        buffer.buffer_label[ind_buffer] = y_upt
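
To make the replacement bookkeeping in `_update_by_knn_sv` concrete, here is a toy-scale sketch with made-up SV scores (not from the original file), ignoring the `cand_ind` indirection of `random_retrieve`: four buffered candidates followed by two current inputs, so `n_cand_buf = 4`.

```python
import torch

# Hypothetical summed SVs: positions 0-3 are buffered candidates, 4-5 are current inputs.
sv = torch.tensor([0.3, -0.1, 0.5, 0.0, 0.4, -0.2])
n_cand_buf = 4
arg_sort = sv.argsort(descending=True)               # tensor([2, 4, 0, 3, 1, 5])
large, small = arg_sort[:n_cand_buf], arg_sort[n_cand_buf:]

# Current inputs that ranked in the 'large' segment enter the buffer...
ind_cur = large[large >= n_cand_buf] - n_cand_buf    # tensor([0]): current item 0
# ...overwriting buffered items that fell into the 'small' segment.
arg_buffer = small[small < n_cand_buf]               # tensor([1]): buffer slot 1
```

The two index sets always have the same size: the `large` segment holds exactly `n_cand_buf` entries, so every current input that ranks inside it displaces exactly one buffered index into `small`.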

buffer/aser_utils.py (+156)
@@ -0,0 +1,156 @@
import torch
from buffer.buffer_utils import maybe_cuda, mini_batch_deep_features, euclidean_distance, nonzero_indices, ohe_label
from buffer.buffer_utils import ClassBalancedRandomSampling


def compute_knn_sv(model, eval_x, eval_y, cand_x, cand_y, k, device="cpu"):
    """
    Compute KNN SV of candidate data w.r.t. evaluation data.
    Args:
        model (object): neural network.
        eval_x (tensor): evaluation data tensor.
        eval_y (tensor): evaluation label tensor.
        cand_x (tensor): candidate data tensor.
        cand_y (tensor): candidate label tensor.
        k (int): number of nearest neighbours.
        device (str): device for tensor allocation.
    Returns:
        sv_matrix (tensor): KNN Shapley value matrix of candidate data w.r.t. evaluation data.
    """
    # Compute KNN SV score for candidate samples w.r.t. evaluation samples
    n_eval = eval_x.size(0)
    n_cand = cand_x.size(0)
    # Initialize SV matrix to zeros
    sv_matrix = torch.zeros((n_eval, n_cand), device=device)
    # Get deep features
    eval_df, cand_df = deep_features(model, eval_x, n_eval, cand_x, n_cand)
    # Sort candidate indices by distance in deep feature space
    sorted_ind_mat = sorted_cand_ind(eval_df, cand_df, n_eval, n_cand)

    # Evaluation set labels
    el = eval_y
    el_vec = el.repeat([n_cand, 1]).T
    # Candidate set labels, sorted by distance
    cl = cand_y[sorted_ind_mat]

    # Indicator function matrix: 1 where the sorted candidate label matches the evaluation label
    indicator = (el_vec == cl).float()
    indicator_next = torch.zeros_like(indicator, device=device)
    indicator_next[:, 0:n_cand - 1] = indicator[:, 1:]
    indicator_diff = indicator - indicator_next

    cand_ind = torch.arange(n_cand, dtype=torch.float, device=device) + 1
    denom_factor = cand_ind.clone()
    denom_factor[:n_cand - 1] = denom_factor[:n_cand - 1] * k
    numer_factor = cand_ind.clone()
    numer_factor[k:n_cand - 1] = k
    numer_factor[n_cand - 1] = 1
    factor = numer_factor / denom_factor

    indicator_factor = indicator_diff * factor
    indicator_factor_cumsum = indicator_factor.flip(1).cumsum(1).flip(1)

    # Row indices
    row_ind = torch.arange(n_eval, device=device)
    row_mat = torch.repeat_interleave(row_ind, n_cand).reshape([n_eval, n_cand])

    # Compute SV recursively (the right-to-left cumulative sum implements the recursion)
    sv_matrix[row_mat, sorted_ind_mat] = indicator_factor_cumsum

    return sv_matrix
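

# For intuition: the flip-cumsum above evaluates the KNN Shapley value
# recursion of Jia et al. (2019). Below is a naive per-evaluation-point
# version, usable as a sanity check against one row of sv_matrix. This is
# a sketch for illustration; knn_sv_reference is not part of the original file.
def knn_sv_reference(eval_label, sorted_cand_y, k):
    """sorted_cand_y holds candidate labels sorted by distance to the evaluation point."""
    n = sorted_cand_y.numel()
    sv = torch.zeros(n)
    # Base case, farthest candidate: sv = 1[label match] / n
    sv[n - 1] = float(sorted_cand_y[n - 1] == eval_label) / n
    # Recursion, walking inward from the farthest neighbour:
    #   sv_m = sv_{m+1} + (1[match_m] - 1[match_{m+1}]) * min(k, m) / (k * m)
    for m in range(n - 1, 0, -1):  # m is the 1-based rank of array position m - 1
        diff = float(sorted_cand_y[m - 1] == eval_label) - float(sorted_cand_y[m] == eval_label)
        sv[m - 1] = sv[m] + diff * min(k, m) / (k * m)
    return sv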


def deep_features(model, eval_x, n_eval, cand_x, n_cand):
    """
    Compute deep features of evaluation and candidate data.
    Args:
        model (object): neural network.
        eval_x (tensor): evaluation data tensor.
        n_eval (int): number of evaluation data.
        cand_x (tensor): candidate data tensor.
        n_cand (int): number of candidate data.
    Returns:
        eval_df (tensor): deep features of evaluation data.
        cand_df (tensor): deep features of candidate data.
    """
    # Get deep features
    if cand_x is None:
        num = n_eval
        total_x = eval_x
    else:
        num = n_eval + n_cand
        total_x = torch.cat((eval_x, cand_x), 0)

    # Compute deep features with mini-batches
    total_x = maybe_cuda(total_x)
    deep_features_ = mini_batch_deep_features(model, total_x, num)

    eval_df = deep_features_[0:n_eval]
    cand_df = deep_features_[n_eval:]
    return eval_df, cand_df


def sorted_cand_ind(eval_df, cand_df, n_eval, n_cand):
    """
    Sort indices of candidate data according to
    their Euclidean distance to each evaluation data point in deep feature space.
    Args:
        eval_df (tensor): deep features of evaluation data.
        cand_df (tensor): deep features of candidate data.
        n_eval (int): number of evaluation data.
        n_cand (int): number of candidate data.
    Returns:
        sorted_cand_ind (tensor): sorted indices of candidate set w.r.t. each evaluation data point.
    """
    # Sort indices of candidate set according to distance w.r.t. evaluation set in deep feature space
    # Preprocess feature vectors to facilitate vector-wise distance computation
    eval_df_repeat = eval_df.repeat([1, n_cand]).reshape([n_eval * n_cand, eval_df.shape[1]])
    cand_df_tile = cand_df.repeat([n_eval, 1])
    # Compute distance between evaluation and candidate feature vectors
    distance_vector = euclidean_distance(eval_df_repeat, cand_df_tile)
    # Turn distance vector into distance matrix
    distance_matrix = distance_vector.reshape((n_eval, n_cand))
    # Sort candidate set indices based on distance
    sorted_cand_ind_ = distance_matrix.argsort(1)
    return sorted_cand_ind_
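

# Note: the repeat/reshape preprocessing above materializes n_eval * n_cand
# feature rows before calling euclidean_distance. Assuming euclidean_distance
# computes plain row-wise L2 distance (an assumption about buffer_utils),
# torch.cdist would yield the same matrix directly and more memory-efficiently:
#
#     distance_matrix = torch.cdist(eval_df, cand_df)  # shape (n_eval, n_cand)
#     sorted_cand_ind_ = distance_matrix.argsort(1)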


def add_minority_class_input(cur_x, cur_y, mem_size, num_class):
    """
    Find input instances from minority classes so they can be concatenated to the evaluation data/label tensors.
    This facilitates the inclusion of minority class samples into memory when ASER's update method is used in the online class-incremental setting.

    More details:

    The evaluation set may not contain any samples from minority classes (i.e., classes with very few corresponding samples stored in the memory).
    This happens after task changes in the online class-incremental setting.
    Minority class samples can then get very low or negative KNN-SVs, making it difficult to store any of them in the memory.

    By identifying minority class samples in the current input batch and concatenating them to the evaluation set,
    the KNN-SVs of the minority class samples can be artificially boosted (i.e., positive values with larger magnitude).
    This allows new class samples to be accommodated in the memory quickly, right after task changes.

    The threshold for being a minority class is a hyper-parameter related to the class proportion.
    In this implementation, it is drawn uniformly at random between 0 and 1 / (number of all classes) for each current input batch.

    Args:
        cur_x (tensor): current input data tensor.
        cur_y (tensor): current input label tensor.
        mem_size (int): memory size.
        num_class (int): number of classes in dataset.
    Returns:
        minority_batch_x (tensor): subset of current input data from minority classes.
        minority_batch_y (tensor): subset of current input labels from minority classes.
    """
    # Select input instances from minority classes that will be concatenated to pre-selected data
    threshold = torch.tensor(1).float().uniform_(0, 1 / num_class).item()

    # If the proportion of buffered samples from a class is lower than the random threshold,
    # that class is a minority class
    cls_proportion = ClassBalancedRandomSampling.class_num_cache.float() / mem_size
    minority_ind = nonzero_indices(cls_proportion[cur_y] < threshold)

    minority_batch_x = cur_x[minority_ind]
    minority_batch_y = cur_y[minority_ind]
    return minority_batch_x, minority_batch_y
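
For a feel of the minority-class test above, here is a small self-contained sketch with made-up buffer statistics (the per-class counts normally come from `ClassBalancedRandomSampling.class_num_cache`; all numbers below are hypothetical):

```python
import torch

# Hypothetical buffer of 100 slots over 10 classes; classes 3 and 4 are scarce.
mem_size, num_class = 100, 10
class_counts = torch.tensor([30, 30, 28, 4, 2, 6, 0, 0, 0, 0])
cls_proportion = class_counts.float() / mem_size      # e.g. class 4 -> 0.02

# Random threshold in (0, 1/num_class) = (0, 0.1), redrawn for every input batch.
threshold = torch.tensor(1).float().uniform_(0, 1 / num_class).item()

cur_y = torch.tensor([0, 3, 4, 2])                    # labels of the current batch
minority_mask = cls_proportion[cur_y] < threshold
print(cur_y[minority_mask])                           # e.g. tensor([3, 4]) when threshold > 0.04
```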
