Commit 294e33f

ink-pad and inkpad authored
Fix for issue #2
* fix paths to models; fix imports
* update readme

Co-authored-by: inkpad <[email protected]>
1 parent 12cbf32 commit 294e33f

5 files changed: +10 -4 lines changed

README.md

+5
@@ -33,6 +33,11 @@ You would need git-lfs to access the data.
 
 ---
 
+### PRSA Dataset
+For PRSA dataset, one have to download the PRSA dataset from [Kaggle](https://www.kaggle.com/sid321axn/beijing-multisite-airquality-data-set) and place them in [./data/card](/data/card/) directory.
+
+---
+
 ### Tabular BERT
 To train a tabular BERT model on credit card transaction or PRSA dataset run :
 ```

args.py

+2 -2
@@ -28,10 +28,10 @@ def define_main_parser(parser=None):
                         default="card", choices=['card', 'prsa'],
                         help='root directory for files')
     parser.add_argument("--data_root", type=str,
-                        default="./data/",
+                        default="./data/credit_card/",
                         help='root directory for files')
     parser.add_argument("--data_fname", type=str,
-                        default="sd190_trans",
+                        default="card_transaction.v1",
                         help='file name of transaction')
     parser.add_argument("--data_extension", type=str,
                         default="",
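
The effect of the new defaults can be illustrated with a minimal, self-contained sketch. It reproduces only the three options visible in this hunk. The `resolve_data_path` helper and the `<data_root>/<data_fname><data_extension>.csv` layout are assumptions made for illustration and are not confirmed by the diff.

```python
import argparse
import os


def define_main_parser(parser=None):
    """Stand-in parser holding only the options shown in this hunk."""
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str,
                        default="./data/credit_card/",
                        help='root directory for files')
    parser.add_argument("--data_fname", type=str,
                        default="card_transaction.v1",
                        help='file name of transaction')
    parser.add_argument("--data_extension", type=str,
                        default="")
    return parser


def resolve_data_path(args):
    # Hypothetical helper: assumes the data file is stored as
    # <data_root>/<data_fname><data_extension>.csv
    filename = f"{args.data_fname}{args.data_extension}.csv"
    return os.path.join(args.data_root, filename)


# Parse an empty argument list so that only the defaults apply.
args = define_main_parser().parse_args([])
data_path = resolve_data_path(args)
print(data_path)
```

Under these assumptions, running the training script without any `--data_*` flags would look for the transaction file inside `./data/credit_card/` rather than directly under `./data/`.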

main.py

+1 -1
@@ -11,7 +11,7 @@
 from dataset.prsa import PRSADataset
 from dataset.card import TransactionDataset
 from models.modules import TabFormerBertLM, TabFormerGPT2
-from dataset.dataset import random_split_dataset
+from misc.utils import random_split_dataset
 from dataset.datacollator import TransDataCollatorForLanguageModeling
 
 

models/tabformer_bert.py

+1
@@ -3,6 +3,7 @@
 from torch.nn import CrossEntropyLoss
 
 from transformers.modeling_bert import ACT2FN, BertLayerNorm
+from transformers.modeling_bert import BertForMaskedLM
 from transformers.configuration_bert import BertConfig
 from models.custom_criterion import CustomAdaptiveLogSoftmax
 

requirements.txt

+1 -1
@@ -1,5 +1,5 @@
 torch
 torchvision
-transformers
+transformers==3.2.0
 scikit-learn
 pandas
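
The pin to `transformers==3.2.0` presumably guards the `transformers.modeling_bert` import path used in `models/tabformer_bert.py`, which moved to `transformers.models.bert.modeling_bert` in the 4.x series. A small sketch of how such a pin could be checked at startup follows. The `parse_pin` and `check_pin` helpers are hypothetical, are not part of this commit, and handle only plain `name==version` lines.

```python
from importlib import metadata


def parse_pin(requirement):
    """Split a 'name==version' line into (name, version); version is None if unpinned."""
    name, sep, version = requirement.strip().partition("==")
    return name.strip(), (version.strip() if sep else None)


def check_pin(requirement, installed_version=None):
    """Return True if the installed version equals the pin, or if there is no pin."""
    name, pinned = parse_pin(requirement)
    if pinned is None:
        return True
    if installed_version is None:
        try:
            # Look up the version of the package installed in this environment.
            installed_version = metadata.version(name)
        except metadata.PackageNotFoundError:
            return False
    return installed_version == pinned


# Passing installed_version explicitly avoids depending on the local environment.
print(check_pin("transformers==3.2.0", installed_version="3.2.0"))
```

Calling `check_pin("transformers==3.2.0")` without the second argument compares the pin against whatever version is installed in the current environment.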
