Commit 622a9f6

updating
1 parent 2a5d422 commit 622a9f6

6 files changed (+64, -13 lines)

Diff for: config/config.yaml (+6)

@@ -22,6 +22,8 @@ dataset:
 
 training:
   model_type:
+    clip_enc: True
+    lmm_model: True
     encoder: "clip"
     decoder1: "roberta"
     decoder2: "llama2"
@@ -45,3 +47,7 @@ training:
   cnn_encoder:
     input_channels: 4
     hid_dim: 256
+
+  adaptor:
+    in_dim: 768
+    features: [512,256,128,64]

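run.py reads these keys as attribute-style config values (cfg.training.adaptor.in_dim, cfg.training.model_type.clip_enc). A minimal sketch of loading them, assuming an OmegaConf-style loader; the repo's actual config plumbing may differ:

from omegaconf import OmegaConf

# Load the YAML file extended by this commit (path as in the repo).
cfg = OmegaConf.load("config/config.yaml")

# New model_type flags; clip_enc gates the image-loading branch in src/training.py.
print(cfg.training.model_type.clip_enc)   # True
print(cfg.training.model_type.lmm_model)  # True

# New adaptor block, consumed by models.adaptor.Adaptor in run.py.
print(cfg.training.adaptor.in_dim)        # 768
print(cfg.training.adaptor.features)      # [512, 256, 128, 64]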
Diff for: models/adaptor.py (+1 -1)

@@ -24,4 +24,4 @@ def forward(self, x_clip, x_roberta):
         x = torch.cat((xc,xr), dim=0)
         x = torch.flatten(x, start_dim=-2, end_dim=-1)
 
-        return x # (B, features[-1])
+        return x  # (B, features[-1])

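The only functional lines visible in this hunk are the concatenation and the flatten. A toy shape walk-through of just those two calls (the tensor sizes and the names xc, xr are illustrative assumptions, not values from the repo):

import torch

B, L, F = 2, 4, 64          # toy batch size, sequence length, feature width
xc = torch.randn(B, L, F)   # stands in for the projected CLIP features
xr = torch.randn(B, L, F)   # stands in for the projected RoBERTa features

x = torch.cat((xc, xr), dim=0)                  # (2B, L, F)  -> torch.Size([4, 4, 64])
x = torch.flatten(x, start_dim=-2, end_dim=-1)  # (2B, L*F)   -> torch.Size([4, 256])
print(x.shape)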
Diff for: models/lmm_model.py (+34 -6)

@@ -1,7 +1,35 @@
+import torch
+import torch.nn as nn
 
-class Lmm_model(nn.Module):
-    def __init__(self,) -> None:
-        super(Lmm_model, self).__init__()
-
-    def forward(self,):
-        pass
+class LLM_model(nn.Module):
+
+    def __init__(self,
+                 encoder,
+                 decoder1,
+                 decoder2,
+                 adaptor,
+                 dim,
+                 image_length,
+                 max_len,
+                 num_classes,
+                 ):
+        super(LLM_model, self).__init__()
+        self.enc = encoder
+        self.adaptor = adaptor
+        self.dec1 = decoder1
+        self.dec2 = decoder2
+        self.proj1 = nn.Linear(dim, 768)
+        self.proj2 = nn.Linear(image_length, max_len)
+        self.clf1 = nn.Linear(64, num_classes)
+        self.clf2 = nn.Linear(max_len, num_classes)
+
+    def forward(self, imgs, ids, attns):
+        encoded_imgs = self.enc(imgs)  # (B, L=w*h, dim)
+        last_hidden_roberta = self.dec1(ids, attns)  # (B, max_len, 768)
+        output = self.adaptor(encoded_imgs,
+                              last_hidden_roberta)  # (B, features[-1])
+        # classifier
+        output = self.clf1(output)
+        output = self.clf2(output.permute(0,2,1)).permute(0,2,1)  # (B, num_classes, num_classes)
+
+        return output

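A construction-only sketch of the new class; the sub-modules are stubbed with nn.Identity() and every size below is an assumed example value, not taken from the repo's config:

import torch.nn as nn
from models.lmm_model import LLM_model

model = LLM_model(
    encoder=nn.Identity(),    # real code passes the CLIP encoder
    decoder1=nn.Identity(),   # RoBERTa wrapper
    decoder2=nn.Identity(),   # LLaMA-2 wrapper
    adaptor=nn.Identity(),    # models.adaptor.Adaptor instance
    dim=512,                  # assumed encoder width
    image_length=49,          # assumed number of image patches
    max_len=100,              # assumed max token length
    num_classes=11,           # assumed label count
)
print(sum(p.numel() for p in model.parameters()))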
Diff for: run.py (+7)

@@ -18,6 +18,7 @@
 from models.roberta import RobertaEncoder
 from models.llama2 import Llama2Decoder
 from models.model import ClevrMath_model
+from models.adaptor import Adaptor
 from src.training import train
 from src.testing import evaluate
 
@@ -86,9 +87,13 @@ def define_model(max_len):
     if decoder2 == "llama2":
         DEC2 = Llama2Decoder()
 
+    ADA = Adaptor(cfg.training.adaptor.in_dim,
+                  cfg.training.adaptor.features)
+
     model = ClevrMath_model(ENC,
                             DEC1,
                             DEC2,
+                            ADA,
                             dim,
                             image_length,
                             max_len,
@@ -191,6 +196,7 @@ def train_model(rank=None):
         criterion,
         cfg.training.general.clip,
         device,
+        clip_enc=cfg.training.model_type.clip_enc,
         ddp=cfg.general.ddp,
         rank=rank,
     )
@@ -201,6 +207,7 @@ def train_model(rank=None):
         val_dataloader,
         criterion,
         device,
+        clip_enc=cfg.training.model_type.clip_enc,
     )
 
     end_time = time.time()

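Note how the new arguments are threaded through: both clip_enc here and lmm_enc in src/testing.py are keyword arguments with False defaults, so call sites that predate this commit keep their old behaviour while run.py opts in from config. A toy illustration of that pattern (the function body below is illustrative, not the repo's code):

def train_step(batch, clip_enc=False):
    # Pre-existing callers never pass clip_enc and fall through to the
    # original tensor-loading behaviour; run.py now opts in explicitly.
    source = "raw PNG paths" if clip_enc else "pre-saved .pt tensors"
    return f"images read from {source}"

print(train_step(batch=None))                  # legacy call site
print(train_step(batch=None, clip_enc=True))   # new call site, as in run.py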
Diff for: src/testing.py (+1)

@@ -8,6 +8,7 @@ def evaluate(
     test_dataloader,
     criterion,
     device,
+    lmm_enc=False,
     is_test=False,
 ):
     model.eval()

Diff for: src/training.py (+15 -6)

@@ -10,6 +10,7 @@ def train(
     criterion,
     clip,
     device,
+    clip_enc=False,
     ddp=False,
     rank=None,
 ):
@@ -29,14 +30,22 @@ def train(
         ids = ids.to(device)
         attns = attns.to(device)
         labels = labels.to(device, dtype=torch.long)
-
-        _imgs = list()
-        for im in imgs:
-            tnsr = torch.load(f"{data_path}/image_tensors/{int(im.item())}.pt")
-            _imgs.append(tnsr)
+
+        if not clip_enc:
+            _imgs = list()
+            for im in imgs:
+                tnsr = torch.load(f"{data_path}/image_tensors/{int(im.item())}.pt")
+                _imgs.append(tnsr)
 
+            imgs = torch.stack(_imgs).to(device)
 
-        imgs = torch.stack(_imgs).to(device)
+        else:
+            _imgs = list()
+            for im in imgs:
+                _i = f"{data_path}/images/{int(im.item())}.png"
+                _imgs.append(_i)
+
+            imgs = torch.stack(_imgs).to(device)
 
         # setting gradients to zero
         optimizer.zero_grad()

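The new branch can be read as a small standalone helper: with clip_enc off, pre-computed image tensors are loaded and stacked as before; with it on, raw PNG paths are collected instead. How those paths are consumed downstream is not shown in this commit, so the helper below returns them unstacked; that detail is an assumption:

import torch

def gather_images(imgs, data_path, clip_enc=False, device="cpu"):
    """Sketch of the image-selection branch added to src/training.py.

    imgs is the batch of integer image ids coming out of the dataloader.
    """
    if not clip_enc:
        # Original behaviour: load pre-saved tensors and stack into a batch.
        tensors = [
            torch.load(f"{data_path}/image_tensors/{int(im.item())}.pt")
            for im in imgs
        ]
        return torch.stack(tensors).to(device)

    # clip_enc=True: hand back the PNG file paths; the CLIP preprocessor is
    # assumed to load them later, so nothing is stacked here.
    return [f"{data_path}/images/{int(im.item())}.png" for im in imgs]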