Skip to content

Commit 2b98497

Browse files
committed
assertation added
1 parent b29b271 commit 2b98497

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

experiments/CamemBERT/model/CamemBERT.py

+2
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ def forward(self, input):
4646
print("Parsing sentence tokens.")
4747
example_input = prepare_sentence_tokens(model_name, sentence)
4848
print("example_input shape: ", example_input.shape)
49+
assert example_input.shape == (1, 7, 768), f"Expected shape (1,7,768), but got {example_input.shape}"
4950

5051
# The original example_input shape is [1, 7, 768], now we reshape it into [1, 7*768]
5152
example_input = example_input.reshape(1, 7*768)
5253
print("example_input shape after reshaping: ", example_input.shape)
5354

5455
print("Instantiating model.")
5556
model = OnlyLogitsHuggingFaceModel(model_name)
57+
assert example_input.shape == (1, 7*768), f"Expected the reshaped (1,7*768), but got {example_input.shape}"
5658
print(model(example_input).shape)
5759

5860
linalg_on_tensors_mlir = torch_mlir.compile(

experiments/VIT/model/vit.py

+3
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,16 @@ def forward(self,x):
3838
model = prepare().eval()
3939
example_input = model(inputs)
4040
print(example_input.shape)
41+
assert example_input.shape == (1, 197, 768), f"Expected shape (1,197,768), but got {example_input.shape}"
42+
4143

4244
# The original example_input shape is [1, 197, 768], now we reshape it into [1, 197*768]
4345
example_input = example_input.reshape(1, 197*768)
4446
print("example_input shape after reshaping: ", example_input.shape)
4547

4648

4749
vit_model = vit().eval()
50+
assert example_input.shape == (1, 197*768), f"Expected the reshaped (1,197*768), but got {example_input.shape}"
4851
output = vit_model(example_input)
4952
print(output.shape)
5053

0 commit comments

Comments
 (0)