Skip to content

Commit

Permalink
Improve tests
Browse files Browse the repository at this point in the history
Signed-off-by: Dashiell Stander <[email protected]>
  • Loading branch information
dashstander committed Sep 25, 2024
1 parent 154172e commit 0360516
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 34 deletions.
3 changes: 1 addition & 2 deletions src/algebraist/irreps.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.


from copy import deepcopy
from functools import cached_property, reduce
from itertools import combinations, pairwise
import numpy as np
Expand All @@ -26,7 +25,7 @@
from algebraist.utils import adj_trans_decomp, cycle_to_one_line, trans_to_one_line


def contiguous_cycle(n: int, i: int):
def contiguous_cycle(n: int, i: int) -> tuple[int, ...]:
""" Generates a permutation (in cycle notation) of the form (i, i+1, ..., n)
"""
if i == n - 1:
Expand Down
23 changes: 7 additions & 16 deletions tests/test_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import algebraist
from algebraist.fourier import (
slow_sn_ft, slow_sn_ift, slow_sn_fourier_decomposition, sn_fft, sn_ifft, sn_fourier_decomposition, calc_power
slow_sn_ft, slow_sn_ift, sn_fft, sn_ifft, sn_fourier_decomposition, calc_power
)
from algebraist.permutations import Permutation
from algebraist.irreps import SnIrrep
Expand Down Expand Up @@ -48,9 +48,10 @@ def generate_random_fourier_transform(n, batch_size=None):
for irrep in SnIrrep.generate_all_irreps(n):
ft[irrep.partition] = torch.randn(batch_size, irrep.dim, irrep.dim)
if not has_batch:
ft[irrep.partition] =ft[irrep.partition].squeeze()
ft[irrep.partition] = ft[irrep.partition].squeeze()
return ft


@pytest.mark.parametrize("n", [3, 4, 5])
@pytest.mark.parametrize("batch_size", [None, 1, 5])
def test_fourier_transform_invertibility(n, batch_size):
Expand All @@ -66,15 +67,9 @@ def test_fourier_transform_invertibility(n, batch_size):
def test_fourier_decomposition(n, batch_size):
f = generate_random_function(n, batch_size)
ft = sn_fft(f, n)
decomp = slow_sn_fourier_decomposition(ft, n)
if batch_size is not None and batch_size > 1:
assert decomp.shape == (batch_size, len(ft), math.factorial(n))
else:
assert decomp.shape == (len(ft), math.factorial(n))
reconstructed = decomp.sum(dim=-2)
if batch_size is None:
f = f.unsqueeze(0)
assert torch.allclose(f, reconstructed, atol=1e-5), f"Fourier decomposition failed for n={n}, batch_size={batch_size}"
decomp = sn_fourier_decomposition(ft, n)

assert torch.allclose(f, sum(decomp.values()), atol=1e-5), f"Fourier decomposition failed for n={n}, batch_size={batch_size}"


@pytest.mark.parametrize("n", [3, 4, 5])
Expand Down Expand Up @@ -109,13 +104,11 @@ def test_convolution_theorem(n):
ft_conv_freq[shape] = ft_f[shape] * ft_g[shape]
else:
ft_conv_freq[shape] = ft_f[shape] @ ft_g[shape]
#ft_conv = {shape: torch.matmul(ft_f[shape], ft_g[shape]) for shape in ft_f.keys()}
for shape in ft_f.keys():
assert torch.allclose(ft_conv_time[shape], ft_conv_freq[shape], atol=1.e-4),\
f"Convolution theorem failed for n={n}, partition={shape}, max diff = {(ft_conv_time[shape] - ft_conv_freq[shape]).abs().max()}"



@pytest.mark.parametrize("n", [3, 4, 5])
def test_permutation_action(n):
f = generate_random_function(n, None)
Expand All @@ -142,8 +135,7 @@ def test_permutation_action(n):

for shape in ft_action.keys():
assert torch.allclose(ft_perm[shape], ft_action[shape], atol=1e-4), \
f"Permutation action failed for n={n}, shape={shape}"

f"Permutation action failed for n={n}, shape={shape}"


@pytest.mark.parametrize("n", [3, 4, 5])
Expand All @@ -167,6 +159,5 @@ def test_sn_ifft(n):
assert torch.allclose(slow_ift, fast_ift)



if __name__ == '__main__':
pytest.main(['-v', '-s'])
2 changes: 1 addition & 1 deletion tests/test_irreps.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_all_partitions(n):
assert matrix.shape == (irrep.dim, irrep.dim)


@given( sn_with_permutations())
@given(sn_with_permutations())
def test_representation_homomorphism(n_and_permutations):
n, permutations = n_and_permutations
partitions = generate_partitions(n)
Expand Down
37 changes: 24 additions & 13 deletions tests/test_permutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from algebraist.permutations import Permutation


# Helper strategy to generate valid permutations
@st.composite
def permutation_strategy(draw, max_n=10):
n = draw(st.integers(min_value=1, max_value=max_n))
n = draw(st.integers(min_value=3, max_value=max_n))
return Permutation(draw(st.permutations(range(n))))


Expand All @@ -17,17 +18,17 @@ def test_init():
assert p.n == 3


def test_full_group():
group = Permutation.full_group(3)
assert len(group) == 6
assert Permutation([0, 1, 2]) in group
assert Permutation([2, 1, 0]) in group
@given(n=st.integers(min_value=3, max_value=6))
def test_full_group(n):
group = Permutation.full_group(n)
assert len(group) == math.factorial(n)


def test_identity():
id3 = Permutation.identity(3)
assert id3.sigma == (0, 1, 2)
assert id3.is_identity()
@given(n=st.integers(min_value=3, max_value=6))
def test_identity(n):
ident = Permutation.identity(n)
assert ident == tuple(range(n))
assert ident.is_identity()


def test_transposition():
Expand All @@ -41,6 +42,7 @@ def test_multiplication():
p3 = p1 * p2
assert p3.sigma == (0, 1, 2)


@pytest.mark.parametrize("perm, power, expected", [
(Permutation([1, 2, 0]), 2, (2, 0, 1)),
(Permutation([1, 2, 0]), 3, (0, 1, 2)),
Expand Down Expand Up @@ -79,7 +81,6 @@ def test_transposition_decomposition():
assert p.transposition_decomposition() == [(0, 2), (1, 2)]



@pytest.mark.parametrize("n", [3, 4, 5])
def test_permutation_index(n):
indices = [p.permutation_index() for p in Permutation.full_group(n)]
Expand Down Expand Up @@ -114,32 +115,42 @@ def test_parity_product_property(perm1, perm2):
def test_double_inverse_property(perm):
assert perm == perm.inverse.inverse


# Property: order of permutation divides group order (n!)
@given(perm=permutation_strategy())
def test_order_divides_group_order(perm):
from math import factorial
assert factorial(perm.n) % perm.order == 0
assert math.factorial(perm.n) % perm.order == 0


# Property: conjugacy class partition sums to n
@given(perm=permutation_strategy())
def test_conjugacy_class_sum(perm):
assert sum(perm.conjugacy_class) == perm.n


# Property: multiplication is associative
@given(perm1=permutation_strategy(), perm2=permutation_strategy(), perm3=permutation_strategy())
def test_multiplication_associativity(perm1, perm2, perm3):
if perm1.n != perm2.n or perm2.n != perm3.n:
return # Skip if permutations are of different sizes
assert (perm1 * perm2) * perm3 == perm1 * (perm2 * perm3)


# Property: identity permutation is neutral element
@given(perm=permutation_strategy())
def test_identity_neutral(perm):
identity = Permutation.identity(perm.n)
assert perm * identity == perm
assert identity * perm == perm


# Property: permutation to the power of its order is identity
@given(perm=permutation_strategy())
def test_power_order_is_identity(perm):
assert (perm ** perm.order).is_identity()


@given(perm=permutation_strategy())
def test_inverse_gives_identity(perm):
ident = Permutation.identity(perm.n)
assert perm * perm.inverse == ident
15 changes: 13 additions & 2 deletions tests/test_tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,19 @@ def test_hook_length_known_values():


def test_generate_partitions():
assert set(generate_partitions(4)) == {(4,), (3, 1), (2, 2), (2, 1, 1), (1, 1, 1, 1)}
assert set(generate_partitions(5)) == {(5,), (4, 1), (3, 2), (3, 1, 1), (2, 2, 1), (2, 1, 1, 1), (1, 1, 1, 1, 1)}
partsof9 = {
(9,), (8, 1), (7, 2), (7, 1, 1), (6, 3), (6, 2, 1), (6, 1, 1, 1), (5, 4), (5, 3, 1),
(5, 2, 2), (5, 2, 1, 1), (5, 1, 1, 1, 1), (4, 4, 1), (4, 3, 2), (4, 3, 1, 1), (4, 2, 2, 1), (4, 2, 1, 1, 1),
(4, 1, 1, 1, 1, 1), (3, 3, 3), (3, 3, 2, 1), (3, 3, 1, 1, 1), (3, 2, 2, 2), (3, 2, 2, 1, 1), (3, 2, 1, 1, 1, 1), (3, 1, 1, 1, 1, 1, 1),
(2, 2, 2, 2, 1), (2, 2, 2, 1, 1, 1), (2, 2, 1, 1, 1, 1, 1), (2, 1, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 1, 1, 1, 1)
}
partsof8 = {
(8,), (7, 1), (6, 2), (6, 1, 1), (5, 3), (5, 2, 1), (5, 1, 1, 1), (4, 4), (4, 3, 1), (4, 2, 2), (4, 2, 1, 1), (4, 1, 1, 1, 1),
(3, 3, 2), (3, 3, 1, 1), (3, 2, 2, 1), (3, 2, 1, 1, 1), (3, 1, 1, 1, 1, 1), (2, 2, 2, 2), (2, 2, 2, 1, 1), (2, 2, 1, 1, 1, 1),
(2, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 1, 1, 1)
}
assert set(generate_partitions(8)) == partsof8
assert set(generate_partitions(9)) == partsof9


def test_enumerate_standard_tableau():
Expand Down

0 comments on commit 0360516

Please sign in to comment.