Skip to content
2 changes: 2 additions & 0 deletions src/bloqade/geometry/dialects/grid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@
Scale as Scale,
Shape as Shape,
Shift as Shift,
ShiftSubgridX as ShiftSubgridX,
ShiftSubgridY as ShiftSubgridY,
)
from .types import Grid as Grid, GridType as GridType
34 changes: 34 additions & 0 deletions src/bloqade/geometry/dialects/grid/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
Scale,
Shape,
Shift,
ShiftSubgridX,
ShiftSubgridY,
)
from .types import Grid

Expand Down Expand Up @@ -231,6 +233,38 @@ def shift(grid: Grid[Nx, Ny], x_shift: float, y_shift: float) -> Grid[Nx, Ny]:
...


@_wraps(ShiftSubgridX)
def shift_subgrid_x(
grid: Grid[Nx, Ny], x_indices: ilist.IList[int, typing.Any], x_shift: float
) -> Grid[Nx, Ny]:
"""Shift a sub grid of grid in the x directions.

Args:
grid (Grid): a grid object
x_indices (ilist.IList[int, typing.Any]): a list/ilist of x indices to shift
x_shift (float): shift in the x direction
Returns:
Grid: a new grid object that has been shifted
"""
...


@_wraps(ShiftSubgridY)
def shift_subgrid_y(
grid: Grid[Nx, Ny], y_indices: ilist.IList[int, typing.Any], y_shift: float
) -> Grid[Nx, Ny]:
"""Shift a sub grid of grid in the y directions.

Args:
grid (Grid): a grid object
y_indices (ilist.IList[int, typing.Any]): a list/ilist of y indices to shift
y_shift (float): shift in the y direction
Returns:
Grid: a new grid object that has been shifted
"""
...


@_wraps(Shape)
def shape(grid: Grid) -> tuple[int, int]:
"""Get the shape of a grid.
Expand Down
26 changes: 26 additions & 0 deletions src/bloqade/geometry/dialects/grid/concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,32 @@ def shift(

return (grid.shift(x_shift, y_shift),)

@impl(stmts.ShiftSubgridX)
def shift_subgrid_x(
self,
interp: Interpreter,
frame: Frame,
stmt: stmts.ShiftSubgridX,
):
grid = frame.get_casted(stmt.zone, Grid)
x_indices = frame.get_casted(stmt.x_indices, ilist.IList)
x_shift = frame.get_casted(stmt.x_shift, float)

return (grid.shift_subgrid_x(x_indices, x_shift),)

@impl(stmts.ShiftSubgridY)
def shift_subgrid_y(
self,
interp: Interpreter,
frame: Frame,
stmt: stmts.ShiftSubgridY,
):
grid = frame.get_casted(stmt.zone, Grid)
y_indices = frame.get_casted(stmt.y_indices, ilist.IList)
y_shift = frame.get_casted(stmt.y_shift, float)

return (grid.shift_subgrid_y(y_indices, y_shift),)

@impl(stmts.Scale)
def scale(
self,
Expand Down
28 changes: 28 additions & 0 deletions src/bloqade/geometry/dialects/grid/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,34 @@ class Shift(ir.Statement):
result: ir.ResultValue = info.result(GridType[NumX, NumY])


@statement(dialect=dialect)
class ShiftSubgridX(ir.Statement):
name = "shift_subgrid_x"
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
zone: ir.SSAValue = info.argument(
type=GridType[NumX := types.TypeVar("NumX"), NumY := types.TypeVar("NumY")]
)
x_indices: ir.SSAValue = info.argument(
ilist.IListType[types.Int, types.TypeVar("SubNumX")]
)
x_shift: ir.SSAValue = info.argument(types.Float)
result: ir.ResultValue = info.result(GridType[NumX, NumY])


@statement(dialect=dialect)
class ShiftSubgridY(ir.Statement):
name = "shift_subgrid_y"
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
zone: ir.SSAValue = info.argument(
type=GridType[NumX := types.TypeVar("NumX"), NumY := types.TypeVar("NumY")]
)
y_indices: ir.SSAValue = info.argument(
ilist.IListType[types.Int, types.TypeVar("SubNumY")]
)
y_shift: ir.SSAValue = info.argument(types.Float)
result: ir.ResultValue = info.result(GridType[NumX, NumY])


@statement(dialect=dialect)
class Scale(ir.Statement):
name = "scale_grid"
Expand Down
78 changes: 78 additions & 0 deletions src/bloqade/geometry/dialects/grid/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,84 @@ def shift(self, x_shift: float, y_shift: float) -> "Grid[NumX, NumY]":
y_init=self.y_init + y_shift if self.y_init is not None else None,
)

def shift_subgrid_x(
self, x_indices: ilist.IList[int, Nx] | slice, x_shift: float
) -> "Grid[NumX, NumY]":
"""Shift a sub grid of grid in the x directions.

Args:
grid (Grid): a grid object
x_indices (float): a list/ilist of x indices to shift
x_shift (float): shift in the x direction
Returns:
Grid: a new grid object that has been shifted
"""
indices = get_indices(len(self.x_spacing) + 1, x_indices)

def shift_x(index):
new_spacing = self.x_spacing[index]
if index in indices and (index + 1) not in indices:
new_spacing -= x_shift
elif index not in indices and (index + 1) in indices:
new_spacing += x_shift
return new_spacing

new_spacing = tuple(shift_x(i) for i in range(len(self.x_spacing)))

assert all(
x >= 0 for x in new_spacing
), "Invalid shift: column order changes after shift."
Comment on lines +404 to +406
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a user input we should make this an exception instead of an assertion. Generally Assertions are used to enforce internal logic inside a code to catch potential bugs while exceptions are indicating invalid inputs from a user.


x_init = self.x_init
if x_init is not None and 0 in indices:
x_init += x_shift

return Grid(
x_spacing=new_spacing,
y_spacing=self.y_spacing,
x_init=x_init,
y_init=self.y_init,
)

def shift_subgrid_y(
self, y_indices: ilist.IList[int, Ny] | slice, y_shift: float
) -> "Grid[NumX, NumY]":
"""Shift a sub grid of grid in the y directions.

Args:
grid (Grid): a grid object
y_indices (float): a list/ilist of y indices to shift
y_shift (float): shift in the y direction
Returns:
Grid: a new grid object that has been shifted
"""
indices = get_indices(len(self.y_spacing) + 1, y_indices)

def shift_y(index):
new_spacing = self.y_spacing[index]
if index in indices and (index + 1) not in indices:
new_spacing -= y_shift
elif index not in indices and (index + 1) in indices:
new_spacing += y_shift
return new_spacing

new_spacing = tuple(shift_y(i) for i in range(len(self.y_spacing)))

assert all(
y >= 0 for y in new_spacing
), "Invalid shift: row order changes after shift."

y_init = self.y_init
if y_init is not None and 0 in indices:
y_init += y_shift

return Grid(
x_spacing=self.x_spacing,
y_spacing=new_spacing,
x_init=self.x_init,
y_init=y_init,
)

def repeat(
self, x_times: int, y_times: int, x_gap: float, y_gap: float
) -> "Grid[NumX, NumY]":
Expand Down
2 changes: 2 additions & 0 deletions test/grid/test_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def test_from_ranges(self):
(grid.GetYPos, "y_positions", ()),
(grid.Get, "get", ((1, 0),)),
(grid.Shift, "shift", (1.0, 2.0)),
(grid.ShiftSubgridX, "shift_subgrid_x", (ilist.IList([0]), -1)),
(grid.ShiftSubgridY, "shift_subgrid_y", (ilist.IList([0]), -1)),
(grid.Scale, "scale", (1.0, 2.0)),
(grid.Repeat, "repeat", (1, 2, 0.5, 1.0)),
(grid.GetSubGrid, "get_view", (ilist.IList((0,)), ilist.IList((1,)))),
Expand Down
131 changes: 131 additions & 0 deletions test/grid/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,137 @@ def test_shift(self):
)
assert shifted_grid.is_equal(expected_grid)

@pytest.mark.parametrize(
"x_indices, x_shift, expected_grid",
[
(
ilist.IList([]),
0,
Grid(
x_spacing=(1, 2, 3),
y_spacing=(4, 5),
x_init=1,
y_init=2,
),
),
(
ilist.IList([0, 1]),
1,
Grid(
x_spacing=(1, 1, 3),
y_spacing=(4, 5),
x_init=2,
y_init=2,
),
),
(
ilist.IList([1]),
1,
Grid(
x_spacing=(2, 1, 3),
y_spacing=(4, 5),
x_init=1,
y_init=2,
),
),
(
ilist.IList([1, 2, 3]),
1,
Grid(
x_spacing=(2, 2, 3),
y_spacing=(4, 5),
x_init=1,
y_init=2,
),
),
(
slice(1, 4, 1),
1,
Grid(
x_spacing=(2, 2, 3),
y_spacing=(4, 5),
x_init=1,
y_init=2,
),
),
(ilist.IList([1]), 3, None),
],
)
def test_shift_subgrid_x(self, x_indices, x_shift, expected_grid):
if expected_grid is None:
with pytest.raises(AssertionError):
shifted_grid = self.grid_obj.shift_subgrid_x(x_indices, x_shift)
return

shifted_grid = self.grid_obj.shift_subgrid_x(x_indices, x_shift)
assert shifted_grid.is_equal(expected_grid)

@pytest.mark.parametrize(
"y_indices, y_shift, expected_grid",
[
(
ilist.IList([]),
0,
Grid(
x_spacing=(1, 2, 3),
y_spacing=(4, 5),
x_init=1,
y_init=2,
),
),
(
ilist.IList([0]),
-1,
Grid(
x_spacing=(1, 2, 3),
y_spacing=(5, 5),
x_init=1,
y_init=1,
),
),
(
ilist.IList([1]),
1,
Grid(
x_spacing=(1, 2, 3),
y_spacing=(5, 4),
x_init=1,
y_init=2,
),
),
(
ilist.IList([0, 2]),
1,
Grid(
x_spacing=(1, 2, 3),
y_spacing=(3, 6),
x_init=1,
y_init=3,
),
),
(
slice(0, 1, 1),
-1,
Grid(
x_spacing=(1, 2, 3),
y_spacing=(5, 5),
x_init=1,
y_init=1,
),
),
(ilist.IList([0]), 5, None),
],
)
def test_shift_subgrid_y(self, y_indices, y_shift, expected_grid):

if expected_grid is None:
with pytest.raises(AssertionError):
shifted_grid = self.grid_obj.shift_subgrid_y(y_indices, y_shift)
return

shifted_grid = self.grid_obj.shift_subgrid_y(y_indices, y_shift)
assert shifted_grid.is_equal(expected_grid)

def test_scale(self):
scaled_grid = self.grid_obj.scale(2, 3)
expected_grid = Grid(
Expand Down
Loading