Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,19 @@ outputs = model.generate(input_ids, max_length=32, do_sample=True, top_p=0.4, t
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```

# Running Tests

To run the tests, ensure you have `pytest` installed:

```sh
pip install pytest
```

Run the tests using:

```sh
pytest /workspaces/matmulfreellm/tests
```

# Citation
If you use this repo in your work, please cite our preprint:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_fusedbitnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
from mmfreelm.ops.fusedbitnet import activation_quant, weight_quant

def test_activation_quant():
x = torch.tensor([1.0, 2.0, 3.0])
result = activation_quant(x)
assert result is not None # Add more specific assertions based on expected behavior

def test_weight_quant():
w = torch.tensor([1.0, 2.0, 3.0])
result = weight_quant(w)
assert result is not None # Add more specific assertions based on expected behavior
7 changes: 7 additions & 0 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from transformers import AutoTokenizer

def test_generate_script():
tokenizer = AutoTokenizer.from_pretrained("gpt2") # Replace with a valid model name
input_prompt = "Test prompt"
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
assert input_ids is not None
16 changes: 16 additions & 0 deletions tests/test_layernorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
from mmfreelm.modules.layernorm import layer_norm_ref, rms_norm_ref

def test_layer_norm_ref():
x = torch.randn(10, 10)
weight = torch.ones(10)
bias = torch.zeros(10)
result = layer_norm_ref(x, weight, bias)
assert result is not None # Add more specific assertions based on expected behavior

def test_rms_norm_ref():
x = torch.randn(10, 10)
weight = torch.ones(10)
bias = torch.zeros(10)
result = rms_norm_ref(x, weight, bias)
assert result is not None # Add more specific assertions based on expected behavior
11 changes: 11 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch
from mmfreelm.utils import contiguous

def test_contiguous():
@contiguous
def dummy_fn(ctx, x):
return x

x = torch.randn(10, 10).t() # Non-contiguous tensor
result = dummy_fn(None, x)
assert result.is_contiguous()