Skip to content

Commit 13b024c

Browse files
feat: add selection-sort
1 parent fc75883 commit 13b024c

File tree

5 files changed

+703
-0
lines changed

5 files changed

+703
-0
lines changed

pydatastructs/linear_data_structures/_backend/cpp/algorithms/algorithms.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ static PyMethodDef algorithms_PyMethodDef[] = {
1010
METH_VARARGS | METH_KEYWORDS, ""},
1111
{"bubble_sort_llvm", (PyCFunction)bubble_sort_llvm,
1212
METH_VARARGS | METH_KEYWORDS, ""},
13+
{"selection_sort_llvm", (PyCFunction)selection_sort_llvm,
14+
METH_VARARGS | METH_KEYWORDS, ""},
1315
{"selection_sort", (PyCFunction) selection_sort,
1416
METH_VARARGS | METH_KEYWORDS, ""},
1517
{"insertion_sort", (PyCFunction) insertion_sort,

pydatastructs/linear_data_structures/_backend/cpp/algorithms/llvm_algorithms.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,159 @@ def get_bubble_sort_ptr(dtype: str) -> int:
4141

4242
return _materialize(dtype)
4343

44+
45+
def get_selection_sort_ptr(dtype: str) -> int:
46+
dtype = dtype.lower().strip()
47+
if dtype not in _SUPPORTED:
48+
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
49+
50+
return _materialize_selection(dtype)
51+
52+
53+
def _build_selection_sort_ir(dtype: str) -> str:
54+
if dtype not in _SUPPORTED:
55+
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
56+
57+
T, _ = _SUPPORTED[dtype]
58+
i32 = ir.IntType(32)
59+
i64 = ir.IntType(64)
60+
61+
mod = ir.Module(name=f"selection_sort_{dtype}_module")
62+
fn_name = f"selection_sort_{dtype}"
63+
64+
fn_ty = ir.FunctionType(ir.VoidType(), [T.as_pointer(), i32])
65+
fn = ir.Function(mod, fn_ty, name=fn_name)
66+
67+
arr, n = fn.args
68+
arr.name, n.name = "arr", "n"
69+
70+
b_entry = fn.append_basic_block("entry")
71+
b_outer = fn.append_basic_block("outer")
72+
b_find_min = fn.append_basic_block("find_min")
73+
b_inner = fn.append_basic_block("inner")
74+
b_update_min = fn.append_basic_block("update_min")
75+
b_inc_j = fn.append_basic_block("inc_j")
76+
b_check_swap = fn.append_basic_block("check_swap")
77+
b_do_swap = fn.append_basic_block("do_swap")
78+
b_outer_latch = fn.append_basic_block("outer.latch")
79+
b_exit = fn.append_basic_block("exit")
80+
81+
b = ir.IRBuilder(b_entry)
82+
83+
cond_trivial = b.icmp_signed("<=", n, ir.Constant(i32, 1))
84+
b.cbranch(cond_trivial, b_exit, b_outer)
85+
86+
b.position_at_end(b_outer)
87+
i_phi = b.phi(i32, name="i")
88+
i_phi.add_incoming(ir.Constant(i32, 0), b_entry)
89+
90+
n_minus_1 = b.sub(n, ir.Constant(i32, 1), name="n_minus_1")
91+
cond_outer = b.icmp_signed("<", i_phi, n_minus_1)
92+
b.cbranch(cond_outer, b_find_min, b_exit)
93+
94+
b.position_at_end(b_find_min)
95+
min_idx = b.alloca(i32, name="min_idx")
96+
b.store(i_phi, min_idx)
97+
j_init = b.add(i_phi, ir.Constant(i32, 1), name="j_init")
98+
b.branch(b_inner)
99+
100+
b.position_at_end(b_inner)
101+
j_phi = b.phi(i32, name="j")
102+
j_phi.add_incoming(j_init, b_find_min)
103+
104+
cond_inner = b.icmp_signed("<", j_phi, n)
105+
b.cbranch(cond_inner, b_update_min, b_check_swap)
106+
107+
b.position_at_end(b_update_min)
108+
j64 = b.sext(j_phi, i64)
109+
min_idx_val = b.load(min_idx)
110+
min64 = b.sext(min_idx_val, i64)
111+
112+
ptr_j = b.gep(arr, [j64], inbounds=True)
113+
ptr_min = b.gep(arr, [min64], inbounds=True)
114+
val_j = b.load(ptr_j)
115+
val_min = b.load(ptr_min)
116+
117+
if isinstance(T, ir.IntType):
118+
cmp_less = b.icmp_signed("<", val_j, val_min)
119+
else:
120+
cmp_less = b.fcmp_ordered("<", val_j, val_min, fastmath=True)
121+
122+
b.cbranch(cmp_less, b_inc_j, b_inc_j)
123+
124+
b.position_at_end(b_inc_j)
125+
cur_min = b.load(min_idx)
126+
new_min = b.select(cmp_less, j_phi, cur_min)
127+
b.store(new_min, min_idx)
128+
j_next = b.add(j_phi, ir.Constant(i32, 1), name="j_next")
129+
j_phi.add_incoming(j_next, b_inc_j)
130+
b.branch(b_inner)
131+
132+
b.position_at_end(b_check_swap)
133+
final_min = b.load(min_idx)
134+
need_swap = b.icmp_signed("!=", final_min, i_phi)
135+
b.cbranch(need_swap, b_do_swap, b_outer_latch)
136+
137+
b.position_at_end(b_do_swap)
138+
i64_idx = b.sext(i_phi, i64)
139+
min64_idx = b.sext(final_min, i64)
140+
ptr_i = b.gep(arr, [i64_idx], inbounds=True)
141+
ptr_min2 = b.gep(arr, [min64_idx], inbounds=True)
142+
val_i = b.load(ptr_i)
143+
val_min2 = b.load(ptr_min2)
144+
b.store(val_min2, ptr_i)
145+
b.store(val_i, ptr_min2)
146+
b.branch(b_outer_latch)
147+
148+
b.position_at_end(b_outer_latch)
149+
i_next = b.add(i_phi, ir.Constant(i32, 1), name="i_next")
150+
i_phi.add_incoming(i_next, b_outer_latch)
151+
b.branch(b_outer)
152+
153+
b.position_at_end(b_exit)
154+
b.ret_void()
155+
156+
return str(mod)
157+
158+
159+
def _materialize_selection(dtype: str) -> int:
160+
_ensure_target_machine()
161+
162+
key = f"selection_sort_{dtype}"
163+
if key in _fn_ptr_cache:
164+
return _fn_ptr_cache[key]
165+
166+
try:
167+
llvm_ir = _build_selection_sort_ir(dtype)
168+
mod = binding.parse_assembly(llvm_ir)
169+
mod.verify()
170+
171+
try:
172+
pm = binding.ModulePassManager()
173+
pm.add_instruction_combining_pass()
174+
pm.add_reassociate_pass()
175+
pm.add_gvn_pass()
176+
pm.add_cfg_simplification_pass()
177+
pm.run(mod)
178+
except AttributeError:
179+
pass
180+
181+
engine = binding.create_mcjit_compiler(mod, _target_machine)
182+
engine.finalize_object()
183+
engine.run_static_constructors()
184+
185+
addr = engine.get_function_address(f"selection_sort_{dtype}")
186+
if not addr:
187+
raise RuntimeError(f"Failed to get address for selection_sort_{dtype}")
188+
189+
_fn_ptr_cache[key] = addr
190+
_engines[key] = engine
191+
192+
return addr
193+
194+
except Exception as e:
195+
raise RuntimeError(f"Failed to materialize selection function for dtype {dtype}: {e}")
196+
44197
def _build_bubble_sort_ir(dtype: str) -> str:
45198
if dtype not in _SUPPORTED:
46199
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

0 commit comments

Comments
 (0)