Skip to content

Commit e7cdfc2

Browse files
committed
updating core and captup config
1 parent 14fb310 commit e7cdfc2

File tree

4 files changed

+811
-15
lines changed

4 files changed

+811
-15
lines changed

captum/insights/attr_vis/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS
2121

22-
with open('data/response.json') as f:
22+
with open('data/response_temp.json') as f:
2323
response_json = json.load(f)
2424

2525

core/databunch.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -27,41 +27,41 @@ def split_databunch(response, src):
2727

2828
if response_split['validation']['method'] == 'subsets':
2929
args = {
30-
'train_size': response_split['validation']['subsets']['train_size'],
31-
'valid_size': response_split['validation']['subsets']['valid_size'],
32-
'seed': response_split['validation']['subsets']['seed']
30+
'train_size': response_split['validation']['by_subsets']['train_size'],
31+
'valid_size': response_split['validation']['by_subsets']['valid_size'],
32+
'seed': response_split['validation']['by_subsets']['seed']
3333
}
3434

3535
if response_split['validation']['method'] == 'by_files': # TODO: test it out
36-
args = {'valid_name': response_split['validation']['files']['valid_names']}
36+
args = {'valid_name': response_split['validation']['by_files']['valid_names']}
3737

3838
if response_split['validation']['method'] == 'by_fname_file':
3939
args = {
40-
'fname': response_split['validation']['fname_files']['fname'],
41-
'path': response_split['validation']['fname_files']['path']
40+
'fname': response_split['validation']['by_fname_files']['fname'],
41+
'path': response_split['validation']['by_fname_files']['path']
4242
}
4343

4444
if response_split['validation']['method'] == 'by_folder':
4545
args = {
46-
'train': response_split['validation']['folder']['train'],
47-
'valid': response_split['validation']['folder']['valid']
46+
'train': response_split['validation']['by_folder']['train'],
47+
'valid': response_split['validation']['by_folder']['valid']
4848
}
4949
# For tabular, same csv; for vision, csv with labels
5050
if response_split['validation']['method'] == 'by_idx':
5151
df = pd.open_csv(response_split['validation']['csv_name'])
52-
valid_idx = range(len(df) - response_split['validation']['idx']['valid_idx'], len(df))
52+
valid_idx = range(len(df) - response_split['validation']['by_idx']['valid_idx'], len(df))
5353
args = {'valid_idx': valid_idx}
5454

5555
if response_split['validation']['method'] == 'by_idxs':
5656
args = {
57-
'train_idx': response_split['validation']['idxs']['train_idx'],
58-
'valid_idx': response_split['validation']['idxs']['valid_idx']
57+
'train_idx': response_split['validation']['by_idxs']['train_idx'],
58+
'valid_idx': response_split['validation']['by_idxs']['valid_idx']
5959
}
6060

6161
if response_split['validation']['method'] == 'by_list':
6262
args = {
63-
'train': response_split['validation']['list']['train'],
64-
'valid': response_split['validation']['list']['valid']
63+
'train': response_split['validation']['by_list']['train'],
64+
'valid': response_split['validation']['by_list']['valid']
6565
}
6666

6767
if response_split['validation']['method'] == 'by_valid_func':

0 commit comments

Comments
 (0)