Skip to content
35 changes: 32 additions & 3 deletions src/irx/builders/llvmliteir.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ def __init__(self) -> None:
codemodel="small"
)

self._llvm.module.triple = self.target_machine.triple
self._llvm.module.data_layout = str(self.target_machine.target_data)

if self._llvm.SIZE_T_TYPE is None:
self._llvm.SIZE_T_TYPE = self._get_size_t_type_from_triple()

self._add_builtins()

def translate(self, node: astx.AST) -> str:
Expand All @@ -183,7 +189,30 @@ def translate(self, node: astx.AST) -> str:
def _init_native_size_types(self) -> None:
"""Initialize pointer/size_t types from host."""
self._llvm.POINTER_BITS = ctypes.sizeof(ctypes.c_void_p) * 8
self._llvm.SIZE_T_TYPE = ir.IntType(ctypes.sizeof(ctypes.c_size_t) * 8)
self._llvm.SIZE_T_TYPE = None

def _get_size_t_type_from_triple(self) -> ir.IntType:
"""Determine size_t type from target triple using LLVM API."""
triple = self.target_machine.triple.lower()

if any(
arch in triple
for arch in [
"x86_64",
"amd64",
"aarch64",
"arm64",
"ppc64",
"mips64",
]
):
return ir.IntType(64)
elif any(arch in triple for arch in ["i386", "i686", "arm", "mips"]):
if "64" in triple:
return ir.IntType(64)
return ir.IntType(32)

return ir.IntType(ctypes.sizeof(ctypes.c_size_t) * 8)

def initialize(self) -> None:
"""Initialize self."""
Expand Down Expand Up @@ -240,8 +269,8 @@ def initialize(self) -> None:
self._llvm.INT32_TYPE,
]
)
# Platform-sized unsigned integer (assume 64-bit for CI targets)
self._llvm.SIZE_T_TYPE = ir.IntType(64)
# SIZE_T_TYPE already initialized based on host; do not override with a
# fixed width here to avoid mismatches on non-64-bit targets.

def _add_builtins(self) -> None:
# The C++ tutorial adds putchard() simply by defining it in the host
Expand Down
61 changes: 61 additions & 0 deletions tests/test_llvmlite_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from typing import Any, cast
from unittest.mock import Mock

from irx.builders.llvmliteir import (
LLVMLiteIRVisitor,
Expand Down Expand Up @@ -95,3 +96,63 @@ def test_emit_int_div_signed_and_unsigned() -> None:

assert getattr(signed, "opname", "") == "sdiv"
assert getattr(unsigned, "opname", "") == "udiv"


def test_get_size_t_type_from_triple_32bit() -> None:
"""Test _get_size_t_type_from_triple for 32-bit architectures."""
visitor = LLVMLiteIRVisitor()

mock_tm = Mock()
mock_tm.triple = "i386-unknown-linux-gnu"
visitor.target_machine = mock_tm

size_t_ty = visitor._get_size_t_type_from_triple()
assert size_t_ty.width == 32 # noqa: PLR2004


def test_get_size_t_type_from_triple_fallback() -> None:
"""Test _get_size_t_type_from_triple fallback for unknown architectures."""
visitor = LLVMLiteIRVisitor()

mock_tm = Mock()
mock_tm.triple = "unknown-arch-unknown-os"
visitor.target_machine = mock_tm

size_t_ty = visitor._get_size_t_type_from_triple()
assert isinstance(size_t_ty, ir.IntType)
assert size_t_ty.width in (32, 64)


def test_scalar_vector_float_conversion_fptrunc() -> None:
"""Test scalar-vector promotion with float truncation."""
visitor = LLVMLiteIRVisitor()
_prime_builder(visitor)

double_ty = visitor._llvm.DOUBLE_TYPE
float_ty = visitor._llvm.FLOAT_TYPE
vec_ty = ir.VectorType(float_ty, 2)

scalar = ir.Constant(double_ty, 3.14)
converted = visitor._llvm.ir_builder.fptrunc(scalar, float_ty, "test")
result = splat_scalar(visitor._llvm.ir_builder, converted, vec_ty)

assert isinstance(result.type, ir.VectorType)
assert result.type.element == float_ty


def test_scalar_vector_float_conversion_fpext() -> None:
"""Test scalar-vector promotion with float extension."""
visitor = LLVMLiteIRVisitor()
_prime_builder(visitor)

float_ty = visitor._llvm.FLOAT_TYPE
double_ty = visitor._llvm.DOUBLE_TYPE
vec_ty = ir.VectorType(double_ty, 2)

scalar = ir.Constant(float_ty, 3.14)

converted = visitor._llvm.ir_builder.fpext(scalar, double_ty, "test")
result = splat_scalar(visitor._llvm.ir_builder, converted, vec_ty)

assert isinstance(result.type, ir.VectorType)
assert result.type.element == double_ty
Loading