Skip to content

Commit 231c59b

Browse files
committed
update
1 parent b624d12 commit 231c59b

File tree

7 files changed

+9
-9
lines changed

7 files changed

+9
-9
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# CleverMath_LLM
1+
# ClevrMath_LLM
22

33
Keep preprocess in the config "True" if running it for the first time.
44
Once you have run it, there is no need to re-process the dataset.
@@ -8,6 +8,9 @@ re-do these steps. This can avoided by setting the preprocess param to False in
88

99
### Requirements
1010
```
11+
conda create -n clevrmath python=3.10 -y
12+
source activate clevrmath
13+
1114
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu118
1215
pip install -r requirements.txt
1316
```

config/config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dataset:
1919
training:
2020
general:
2121
clip: 1
22-
batch_size: 32
22+
batch_size: 4
2323
epochs: 1
2424
dropout: 0.1
2525
learning_rate: 0.0001
@@ -30,4 +30,4 @@ training:
3030
scheduler_gamma: 0.5
3131

3232
unet_encoder:
33-
input_channels: 3
33+
input_channels: 4
Binary file not shown.

preprocessing/create_dataloaders.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,8 @@ def data_loaders():
7272
t = open("data/templates.lst").readlines()
7373

7474
assert len(q) == len(l) == len(t)
75-
76-
max_len = max([len(i.split()) for i in q])
77-
78-
image_num = range(0, len(q))
75+
76+
image_num = range(0, 20)#len(q))
7977

8078
# split the image_num into train, test, validate
8179
train_val_images, test_images = train_test_split(

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ wandb==0.16.0
33
torchtext==0.6.0
44
pyyaml==6.0.1
55
Pillow==10.1.0
6-
numpy==1.19.3
6+
numpy==1.21.0
77
pandas==1.4.4
88
opencv-python==4.8.1.78
99
opencv-contrib-python==4.8.1.78
-42 Bytes
Binary file not shown.

src/training.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def train(
3232
_imgs = list()
3333
for im in imgs:
3434
tnsr = torch.load(f"data/image_tensors/{int(im.item())}.pt")
35-
print("imgs shape: ", tnsr.shape)
3635
_imgs.append(tnsr)
3736

3837

0 commit comments

Comments
 (0)