@@ -41,6 +41,159 @@ def get_bubble_sort_ptr(dtype: str) -> int:
4141
4242 return _materialize (dtype )
4343
44+
45+ def get_selection_sort_ptr (dtype : str ) -> int :
46+ dtype = dtype .lower ().strip ()
47+ if dtype not in _SUPPORTED :
48+ raise ValueError (f"Unsupported dtype '{ dtype } '. Supported: { list (_SUPPORTED )} " )
49+
50+ return _materialize_selection (dtype )
51+
52+
53+ def _build_selection_sort_ir (dtype : str ) -> str :
54+ if dtype not in _SUPPORTED :
55+ raise ValueError (f"Unsupported dtype '{ dtype } '. Supported: { list (_SUPPORTED )} " )
56+
57+ T , _ = _SUPPORTED [dtype ]
58+ i32 = ir .IntType (32 )
59+ i64 = ir .IntType (64 )
60+
61+ mod = ir .Module (name = f"selection_sort_{ dtype } _module" )
62+ fn_name = f"selection_sort_{ dtype } "
63+
64+ fn_ty = ir .FunctionType (ir .VoidType (), [T .as_pointer (), i32 ])
65+ fn = ir .Function (mod , fn_ty , name = fn_name )
66+
67+ arr , n = fn .args
68+ arr .name , n .name = "arr" , "n"
69+
70+ b_entry = fn .append_basic_block ("entry" )
71+ b_outer = fn .append_basic_block ("outer" )
72+ b_find_min = fn .append_basic_block ("find_min" )
73+ b_inner = fn .append_basic_block ("inner" )
74+ b_update_min = fn .append_basic_block ("update_min" )
75+ b_inc_j = fn .append_basic_block ("inc_j" )
76+ b_check_swap = fn .append_basic_block ("check_swap" )
77+ b_do_swap = fn .append_basic_block ("do_swap" )
78+ b_outer_latch = fn .append_basic_block ("outer.latch" )
79+ b_exit = fn .append_basic_block ("exit" )
80+
81+ b = ir .IRBuilder (b_entry )
82+
83+ cond_trivial = b .icmp_signed ("<=" , n , ir .Constant (i32 , 1 ))
84+ b .cbranch (cond_trivial , b_exit , b_outer )
85+
86+ b .position_at_end (b_outer )
87+ i_phi = b .phi (i32 , name = "i" )
88+ i_phi .add_incoming (ir .Constant (i32 , 0 ), b_entry )
89+
90+ n_minus_1 = b .sub (n , ir .Constant (i32 , 1 ), name = "n_minus_1" )
91+ cond_outer = b .icmp_signed ("<" , i_phi , n_minus_1 )
92+ b .cbranch (cond_outer , b_find_min , b_exit )
93+
94+ b .position_at_end (b_find_min )
95+ min_idx = b .alloca (i32 , name = "min_idx" )
96+ b .store (i_phi , min_idx )
97+ j_init = b .add (i_phi , ir .Constant (i32 , 1 ), name = "j_init" )
98+ b .branch (b_inner )
99+
100+ b .position_at_end (b_inner )
101+ j_phi = b .phi (i32 , name = "j" )
102+ j_phi .add_incoming (j_init , b_find_min )
103+
104+ cond_inner = b .icmp_signed ("<" , j_phi , n )
105+ b .cbranch (cond_inner , b_update_min , b_check_swap )
106+
107+ b .position_at_end (b_update_min )
108+ j64 = b .sext (j_phi , i64 )
109+ min_idx_val = b .load (min_idx )
110+ min64 = b .sext (min_idx_val , i64 )
111+
112+ ptr_j = b .gep (arr , [j64 ], inbounds = True )
113+ ptr_min = b .gep (arr , [min64 ], inbounds = True )
114+ val_j = b .load (ptr_j )
115+ val_min = b .load (ptr_min )
116+
117+ if isinstance (T , ir .IntType ):
118+ cmp_less = b .icmp_signed ("<" , val_j , val_min )
119+ else :
120+ cmp_less = b .fcmp_ordered ("<" , val_j , val_min , fastmath = True )
121+
122+ b .cbranch (cmp_less , b_inc_j , b_inc_j )
123+
124+ b .position_at_end (b_inc_j )
125+ cur_min = b .load (min_idx )
126+ new_min = b .select (cmp_less , j_phi , cur_min )
127+ b .store (new_min , min_idx )
128+ j_next = b .add (j_phi , ir .Constant (i32 , 1 ), name = "j_next" )
129+ j_phi .add_incoming (j_next , b_inc_j )
130+ b .branch (b_inner )
131+
132+ b .position_at_end (b_check_swap )
133+ final_min = b .load (min_idx )
134+ need_swap = b .icmp_signed ("!=" , final_min , i_phi )
135+ b .cbranch (need_swap , b_do_swap , b_outer_latch )
136+
137+ b .position_at_end (b_do_swap )
138+ i64_idx = b .sext (i_phi , i64 )
139+ min64_idx = b .sext (final_min , i64 )
140+ ptr_i = b .gep (arr , [i64_idx ], inbounds = True )
141+ ptr_min2 = b .gep (arr , [min64_idx ], inbounds = True )
142+ val_i = b .load (ptr_i )
143+ val_min2 = b .load (ptr_min2 )
144+ b .store (val_min2 , ptr_i )
145+ b .store (val_i , ptr_min2 )
146+ b .branch (b_outer_latch )
147+
148+ b .position_at_end (b_outer_latch )
149+ i_next = b .add (i_phi , ir .Constant (i32 , 1 ), name = "i_next" )
150+ i_phi .add_incoming (i_next , b_outer_latch )
151+ b .branch (b_outer )
152+
153+ b .position_at_end (b_exit )
154+ b .ret_void ()
155+
156+ return str (mod )
157+
158+
159+ def _materialize_selection (dtype : str ) -> int :
160+ _ensure_target_machine ()
161+
162+ key = f"selection_sort_{ dtype } "
163+ if key in _fn_ptr_cache :
164+ return _fn_ptr_cache [key ]
165+
166+ try :
167+ llvm_ir = _build_selection_sort_ir (dtype )
168+ mod = binding .parse_assembly (llvm_ir )
169+ mod .verify ()
170+
171+ try :
172+ pm = binding .ModulePassManager ()
173+ pm .add_instruction_combining_pass ()
174+ pm .add_reassociate_pass ()
175+ pm .add_gvn_pass ()
176+ pm .add_cfg_simplification_pass ()
177+ pm .run (mod )
178+ except AttributeError :
179+ pass
180+
181+ engine = binding .create_mcjit_compiler (mod , _target_machine )
182+ engine .finalize_object ()
183+ engine .run_static_constructors ()
184+
185+ addr = engine .get_function_address (f"selection_sort_{ dtype } " )
186+ if not addr :
187+ raise RuntimeError (f"Failed to get address for selection_sort_{ dtype } " )
188+
189+ _fn_ptr_cache [key ] = addr
190+ _engines [key ] = engine
191+
192+ return addr
193+
194+ except Exception as e :
195+ raise RuntimeError (f"Failed to materialize selection function for dtype { dtype } : { e } " )
196+
44197def _build_bubble_sort_ir (dtype : str ) -> str :
45198 if dtype not in _SUPPORTED :
46199 raise ValueError (f"Unsupported dtype '{ dtype } '. Supported: { list (_SUPPORTED )} " )
0 commit comments