Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 161 additions & 62 deletions kmir/src/kmir/kast.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from itertools import count
from typing import TYPE_CHECKING, NamedTuple

from pyk.kast.inner import KApply, KSort, KVariable, Subst, build_cons
Expand All @@ -11,10 +12,23 @@
from pyk.kast.prelude.utils import token

from .ty import ArrayT, BoolT, EnumT, IntT, PtrT, RefT, StructT, TupleT, Ty, UintT, UnionT
from .value import BoolValue, IntValue
from .value import (
NO_SIZE,
AggregateValue,
BoolValue,
DynamicSize,
IntValue,
Local,
Metadata,
Place,
PtrLocalValue,
RangeValue,
RefValue,
StaticSize,
)

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from random import Random
from typing import Any, Final

Expand All @@ -29,6 +43,9 @@
_LOGGER: Final = logging.getLogger(__name__)


RANDOM_MAX_ARRAY_LEN: Final = 32


LOCAL_0: Final = KApply('newLocal', KApply('ty', token(0)), KApply('Mutability::Not'))


Expand Down Expand Up @@ -474,25 +491,7 @@ def to_kast(self) -> KInner:


def _random_locals(random: Random, args: Sequence[_Local], types: Mapping[Ty, TypeMetadata]) -> list[KInner]:
res: list[KInner] = [LOCAL_0]
pointees: list[KInner] = []

next_ref = len(args) + 1
for arg in args:
rvres = _random_value(
random=random,
local=arg,
types=types,
next_ref=next_ref,
)
res.append(rvres.value.to_kast())
match rvres:
case PointerRes(pointee=pointee):
pointees.append(pointee.to_kast())
next_ref += 1

res += pointees
return res
return _RandomArgGen(random=random, args=args, types=types).run()


class SimpleRes(NamedTuple):
Expand All @@ -501,55 +500,155 @@ class SimpleRes(NamedTuple):

class ArrayRes(NamedTuple):
value: TypedValue
metadata: MetadataSize
metadata_size: MetadataSize


class PointerRes(NamedTuple):
value: TypedValue
pointee: TypedValue
RandomValueRes = SimpleRes | ArrayRes


RandomValueRes = SimpleRes | ArrayRes | PointerRes
class _RandomArgGen:
_random: Random
_args: Sequence[_Local]
_types: Mapping[Ty, TypeMetadata]
_pointees: list[TypedValue]
_ref: Iterator[int]

def __init__(self, *, random: Random, args: Sequence[_Local], types: Mapping[Ty, TypeMetadata]):
self._random = random
self._args = args
self._types = types
self._pointees = []
self._ref = count(len(args) + 1)

def _random_value(
*,
random: Random,
local: _Local,
types: Mapping[Ty, TypeMetadata],
next_ref: int,
) -> RandomValueRes:
try:
type_info = types[local.ty]
except KeyError as err:
raise ValueError(f'Unknown type: {local.ty}') from err

match type_info:
case BoolT():
return SimpleRes(
TypedValue.from_local(
value=_random_bool_value(random),
local=local,
def run(self) -> list[KInner]:
res: list[KInner] = [LOCAL_0]
res.extend(self._random_value(arg).value.to_kast() for arg in self._args)
res.extend(pointee.to_kast() for pointee in self._pointees)
return res

def _random_value(self, local: _Local) -> RandomValueRes:
try:
type_info = self._types[local.ty]
except KeyError as err:
raise ValueError(f'Unknown type: {local.ty}') from err

match type_info:
case BoolT():
return SimpleRes(
TypedValue.from_local(
value=self._random_bool_value(),
local=local,
)
)
)
case IntT() | UintT():
return SimpleRes(
TypedValue.from_local(
value=_random_int_value(random, type_info),
local=local,
),
)
case _:
raise ValueError(f'Type unsupported for random value generator: {type_info}')
case IntT() | UintT():
return SimpleRes(
TypedValue.from_local(
value=self._random_int_value(type_info),
local=local,
),
)
case EnumT(discriminants=discriminants, fields=fields):
return SimpleRes(
TypedValue.from_local(
value=self._random_enum_value(mut=local.mut, discriminants=discriminants, fields=fields),
local=local,
),
)
case StructT(fields=tys) | TupleT(components=tys):
return SimpleRes(
TypedValue.from_local(
value=self._random_struct_or_tuple_value(mut=local.mut, tys=tys),
local=local,
),
)
case ArrayT(element_type=elem_ty, length=length):
value, metadata_size = self._random_array_value(mut=local.mut, elem_ty=elem_ty, length=length)
return ArrayRes(
value=TypedValue.from_local(
value=value,
local=local,
),
metadata_size=metadata_size,
)
case PtrT() | RefT():
return SimpleRes(
value=TypedValue.from_local(
value=self._random_ptr_value(mut=local.mut, type_info=type_info),
local=local,
),
)
case _:
raise ValueError(f'Type unsupported for random value generator: {type_info}')

def _random_bool_value(self) -> BoolValue:
return BoolValue(bool(self._random.getrandbits(1)))

def _random_bool_value(random: Random) -> BoolValue:
return BoolValue(bool(random.getrandbits(1)))
def _random_int_value(self, type_info: IntT | UintT) -> IntValue:
return IntValue(
value=self._random.randint(type_info.min, type_info.max),
nbits=type_info.nbits,
signed=isinstance(type_info, IntT),
)

def _random_enum_value(
self,
*,
mut: bool,
discriminants: list[int],
fields: list[list[Ty]],
) -> AggregateValue:
variant_idx = self._random.randrange(len(discriminants))
values = self._random_fields(tys=fields[variant_idx], mut=mut)
return AggregateValue(variant_idx, values)

def _random_struct_or_tuple_value(self, *, mut: bool, tys: list[Ty]) -> AggregateValue:
return AggregateValue(0, fields=self._random_fields(tys=tys, mut=mut))

def _random_fields(self, *, tys: list[Ty], mut: bool) -> tuple[Value, ...]:
return tuple(self._random_value(local=_Local(ty=ty, mut=mut)).value.value for ty in tys)

def _random_array_value(self, *, mut: bool, elem_ty: Ty, length: int | None) -> tuple[RangeValue, MetadataSize]:
metadata_size: MetadataSize
if length is None:
length = self._random.randint(0, RANDOM_MAX_ARRAY_LEN)
metadata_size = DynamicSize(length)
else:
metadata_size = StaticSize(length)

elems = tuple(self._random_value(local=_Local(ty=elem_ty, mut=mut)).value.value for _ in range(length))
value = RangeValue(elems)
return value, metadata_size

def _random_ptr_value(self, mut: bool, type_info: PtrT | RefT) -> PtrLocalValue | RefValue:
pointee_local = _Local(ty=type_info.pointee_type, mut=mut)
pointee_res = self._random_value(pointee_local)
self._pointees.append(pointee_res.value)

metadata_size: MetadataSize
match pointee_res:
case ArrayRes(metadata_size=metadata_size):
pass
case _:
metadata_size = NO_SIZE

def _random_int_value(random: Random, type_info: IntT | UintT) -> IntValue:
return IntValue(
value=random.randint(type_info.min, type_info.max),
nbits=type_info.nbits,
signed=isinstance(type_info, IntT),
)
metadata = Metadata(size=metadata_size, pointer_offset=0, origin_size=metadata_size)

ref = next(self._ref)

match type_info:
case PtrT():
return PtrLocalValue(
stack_depth=0,
place=Place(local=Local(ref)),
mut=mut,
metadata=metadata,
)
case RefT():
return RefValue(
stack_depth=0,
place=Place(local=Local(ref)),
mut=mut,
metadata=metadata,
)
case _:
raise AssertionError()
52 changes: 51 additions & 1 deletion kmir/src/kmir/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, NewType

from pyk.kast.inner import KApply
from pyk.kast.prelude.collections import list_of
Expand All @@ -19,6 +19,9 @@
from .alloc import AllocId


Local = NewType('Local', int)


class Value(ABC):
@abstractmethod
def to_kast(self) -> KInner: ...
Expand Down Expand Up @@ -86,6 +89,40 @@ def to_kast(self) -> KInner:
)


@dataclass
class RefValue(Value):
stack_depth: int
place: Place
mut: bool
metadata: Metadata

def to_kast(self) -> KInner:
return KApply(
'Value::Reference',
intToken(self.stack_depth),
self.place.to_kast(),
KApply('Mutability::Mut') if self.mut else KApply('Mutability::Not'),
self.metadata.to_kast(),
)


@dataclass
class PtrLocalValue(Value):
stack_depth: int
place: Place
mut: bool
metadata: Metadata

def to_kast(self) -> KInner:
return KApply(
'Value::PtrLocal',
intToken(self.stack_depth),
self.place.to_kast(),
KApply('Mutability::Mut') if self.mut else KApply('Mutability::Not'),
self.metadata.to_kast(),
)


@dataclass
class AllocRefValue(Value):
alloc_id: AllocId
Expand All @@ -101,6 +138,19 @@ def to_kast(self) -> KInner:
)


@dataclass
class Place:
local: Local
# projection_elems: tuple[ProjectionElem, ...]

def to_kast(self) -> KInner:
return KApply(
'place',
KApply('local', intToken(self.local)),
KApply('ProjectionElems::empty'), # TODO
)


@dataclass
class Metadata:
size: MetadataSize
Expand Down
Loading