From 83b2221d047e56620561db8a2f7b99cebd7497b6 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Tue, 22 Jul 2025 15:30:27 +0200 Subject: [PATCH 01/16] [experimental] support for vectorized code generation (using jdk.incubator.vector) --- build.gradle.kts | 7 +- gradle.properties | 2 +- .../primitives-annotations/build.gradle.kts | 3 + .../primitives/annotations/GenerateVector.kt | 8 + .../kinference/primitives/types/DataType.kt | 4 +- .../primitives/vector/OperationNode.kt | 66 +++++++++ .../primitives-plugin/build.gradle.kts | 3 + .../generator/PrimitiveGenerator.kt | 25 ++++ .../generator/processor/RemovalProcessor.kt | 1 + .../processor/ReplacementProcessor.kt | 93 +++++++++++- .../processor/VectorReplacementProcessor.kt | 137 ++++++++++++++++++ 11 files changed, 336 insertions(+), 13 deletions(-) create mode 100644 plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/annotations/GenerateVector.kt create mode 100644 plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt create mode 100644 plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt diff --git a/build.gradle.kts b/build.gradle.kts index df09ca6..f84e7b0 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -12,6 +12,7 @@ plugins { allprojects { repositories { + mavenLocal() mavenCentral() gradlePluginPortal() } @@ -19,14 +20,14 @@ allprojects { subprojects { tasks.withType(JavaCompile::class.java).all { - sourceCompatibility = JavaVersion.VERSION_1_8.toString() - targetCompatibility = JavaVersion.VERSION_1_8.toString() + sourceCompatibility = JavaVersion.VERSION_21.toString() + targetCompatibility = JavaVersion.VERSION_21.toString() } tasks.withType(KotlinCompilationTask::class.java).all { compilerOptions { if (this is KotlinJvmCompilerOptions) { - jvmTarget.set(JvmTarget.JVM_1_8) + jvmTarget.set(JvmTarget.JVM_21) } apiVersion.set(KotlinVersion.KOTLIN_2_0) diff --git a/gradle.properties b/gradle.properties index 571db53..138a42b 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,2 +1,2 @@ GROUP=io.kinference.primitives -VERSION=2.0.0-1 +VERSION=2.1.0-dev diff --git a/plugin-build/primitives-annotations/build.gradle.kts b/plugin-build/primitives-annotations/build.gradle.kts index 5a73798..1b7287f 100644 --- a/plugin-build/primitives-annotations/build.gradle.kts +++ b/plugin-build/primitives-annotations/build.gradle.kts @@ -13,3 +13,6 @@ kotlin { } } +repositories { + mavenCentral() +} diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/annotations/GenerateVector.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/annotations/GenerateVector.kt new file mode 100644 index 0000000..fd013d2 --- /dev/null +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/annotations/GenerateVector.kt @@ -0,0 +1,8 @@ +package io.kinference.primitives.annotations + +import io.kinference.primitives.types.* +import io.kinference.primitives.vector.* + +/** Specify that any usage of [OperationNode] is subject for replacement in this file */ +@Target(AnnotationTarget.FILE) +annotation class GenerateVector(vararg val types: DataType) diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt index dc4292c..738d399 100644 --- a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt @@ -31,7 +31,8 @@ enum class DataType { BOOLEAN, ALL, - NUMBER; + NUMBER, + VECTORIZABLE; /** * Resolve DataType into actual primitives -- would flatten groups into collection of primitives. @@ -41,6 +42,7 @@ enum class DataType { return when(this) { ALL -> setOf(BYTE, SHORT, INT, LONG, UBYTE, USHORT, UINT, ULONG, FLOAT, DOUBLE, BOOLEAN) NUMBER -> setOf(BYTE, SHORT, INT, LONG, UBYTE, USHORT, UINT, ULONG, FLOAT, DOUBLE) + VECTORIZABLE -> setOf(BYTE, SHORT, INT, LONG, FLOAT, DOUBLE) else -> setOf(this) } } diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt new file mode 100644 index 0000000..f8411d4 --- /dev/null +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt @@ -0,0 +1,66 @@ +@file:Suppress("Unused", "UnusedReceiverParameter") + +package io.kinference.primitives.vector + +import io.kinference.primitives.types.PrimitiveArray +import io.kinference.primitives.types.PrimitiveType + +sealed class OpNode() { + public fun into(dest: PrimitiveArray, offset: Int, len: Int) { + throw UnsupportedOperationException() + } + public fun reduce(operation: AssociativeWrapper, len: Int): PrimitiveType { + throw UnsupportedOperationException() + } + + internal abstract val isValue: Boolean + //internal abstract fun linReplace(): String + //internal abstract fun vecReplace(): String +} + +class PrimitiveSlice(val src: PrimitiveArray, val offset: Int = 0) : OpNode() { + override val isValue: Boolean = false +} + +class UnaryOp(val arg: OpNode, val operation: UnaryWrapper) : OpNode() { + override val isValue: Boolean = arg.isValue +} + +class BinaryOp(val left: OpNode, val right: OpNode, val operation: BinaryWrapper) : OpNode() { + override val isValue: Boolean = left.isValue && right.isValue +} + +class Value(val value: PrimitiveType) : OpNode() { + override val isValue: Boolean = true +} + +sealed class UnaryWrapper() { +} + +sealed class BinaryWrapper() { +} + +sealed class AssociativeWrapper() : BinaryWrapper() { +} + +object Abs : UnaryWrapper() {} + +object Exp : UnaryWrapper() {} + +object Log : UnaryWrapper() {} + +object Neg : UnaryWrapper() {} + +object Add : AssociativeWrapper() {} + +object Sub : BinaryWrapper() {} + +object Mul : AssociativeWrapper() {} + +object Div : BinaryWrapper() {} + +object Pow : BinaryWrapper() {} + +object Max : AssociativeWrapper() {} + +object Min : BinaryWrapper() {} diff --git a/plugin-build/primitives-plugin/build.gradle.kts b/plugin-build/primitives-plugin/build.gradle.kts index b00aa4b..3fc2c11 100644 --- a/plugin-build/primitives-plugin/build.gradle.kts +++ b/plugin-build/primitives-plugin/build.gradle.kts @@ -24,3 +24,6 @@ gradlePlugin { } } } +repositories { + mavenCentral() +} diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt index a55494e..1e293ea 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt @@ -4,6 +4,7 @@ import io.kinference.primitives.annotations.* import io.kinference.primitives.generator.errors.require import io.kinference.primitives.generator.processor.RemovalProcessor import io.kinference.primitives.generator.processor.ReplacementProcessor +import io.kinference.primitives.types.DataType import io.kinference.primitives.utils.crossProduct import io.kinference.primitives.utils.psi.* import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity @@ -15,6 +16,7 @@ import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.psi.psiUtil.visibilityModifier import org.jetbrains.kotlin.resolve.BindingContext import java.io.File +import kotlin.io.path.Path internal class PrimitiveGenerator( private val file: KtFile, private val context: BindingContext, private val output: File, @@ -60,6 +62,12 @@ internal class PrimitiveGenerator( primitiveContext = tmp } + override fun visitImportList(importList: KtImportList) { + if (file.isAnnotatedWith(context) && primitive.dataType in DataType.VECTORIZABLE.resolve()) + builder.appendLine("import jdk.incubator.vector.*") + super.visitImportList(importList) + } + override fun visitModifierList(list: KtModifierList) { if (replacementProcessor.shouldChangeVisibilityModifier(list)) { replacementProcessor.prepareReplaceText(list.visibilityModifier(), "public") @@ -141,12 +149,15 @@ internal class PrimitiveGenerator( typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type1) { super.visitTypeReference(typeReference) } + typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type2) { super.visitTypeReference(typeReference) } + typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type3) { super.visitTypeReference(typeReference) } + else -> super.visitTypeReference(typeReference) } } @@ -157,12 +168,15 @@ internal class PrimitiveGenerator( expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type1) { super.visitAnnotatedExpression(expression) } + expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type2) { super.visitAnnotatedExpression(expression) } + expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type3) { super.visitAnnotatedExpression(expression) } + else -> super.visitAnnotatedExpression(expression) } } @@ -189,6 +203,17 @@ internal class PrimitiveGenerator( val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) builder.append(replacement ?: expression.text) } + + override fun visitDotQualifiedExpression(expression: KtDotQualifiedExpression) { + if (!file.isAnnotatedWith(context)) { + super.visitDotQualifiedExpression(expression) + return + } + val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) + if (replacement.isNullOrEmpty()) { + super.visitDotQualifiedExpression(expression); return + } else builder.append(replacement) + } }) if (builder.isNotBlank()) { diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt index 2b9dbb4..6c75ccc 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt @@ -4,6 +4,7 @@ import io.kinference.primitives.annotations.* import io.kinference.primitives.generator.isPluginAnnotation import io.kinference.primitives.types.PrimitiveArray import io.kinference.primitives.types.PrimitiveType +import io.kinference.primitives.vector.* import org.jetbrains.kotlin.com.intellij.openapi.util.Key import org.jetbrains.kotlin.com.intellij.psi.PsiElement import org.jetbrains.kotlin.com.intellij.psi.PsiWhiteSpace diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 9d7ec89..dff2c85 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -1,6 +1,7 @@ package io.kinference.primitives.generator.processor import io.kinference.primitives.annotations.GenerateNameFromPrimitives +import io.kinference.primitives.annotations.GenerateVector import io.kinference.primitives.annotations.MakePublic import io.kinference.primitives.generator.* import io.kinference.primitives.generator.errors.require @@ -20,7 +21,7 @@ import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe internal class ReplacementProcessor(private val context: BindingContext, private val collector: MessageCollector) { companion object { - private fun toType(primitive: Primitive<*, *>): String { + internal fun toType(primitive: Primitive<*, *>): String { return when (primitive.dataType) { DataType.BYTE -> "toInt().toByte" DataType.SHORT -> "toInt().toShort" @@ -44,17 +45,23 @@ internal class ReplacementProcessor(private val context: BindingContext, private (PrimitiveArray::class.qualifiedName!! + ".") to { it.arrayTypeName }, (PrimitiveArray::class.qualifiedName!! + ".Companion") to { it.arrayTypeName } ) + + private val nodeTypes = setOf("BinaryOp", "UnaryOp", "PrimitiveSlice", "Value") } fun getReplacement(klass: KtClassOrObject, primitive: Primitive<*, *>): String? { - if (klass.isAnnotatedWith(context)) { - return klass.specialize(primitive, collector) - } - collector.require(CompilerMessageSeverity.WARNING, klass, !klass.isTopLevel()) { - "Class is not annotated with ${GenerateNameFromPrimitives::class.simpleName}, so its name would not be specialized. It may lead to redeclaration compile error." + if (!klass.isAnnotatedWith(context)) { + collector.require(CompilerMessageSeverity.WARNING, klass, !klass.isTopLevel()) { + "Class is not annotated with ${GenerateNameFromPrimitives::class.simpleName}, so its name would not be specialized. It may lead to redeclaration compile error." + } + return null } - return null + if (klass.isAnnotatedWith(context)) + return "" + + return klass.specialize(primitive, collector) + } fun getReplacement(function: KtNamedFunction, primitive: Primitive<*, *>): String? { @@ -78,17 +85,87 @@ internal class ReplacementProcessor(private val context: BindingContext, private } (target.isKtClassOrObject() && target.containingDeclaration!!.isAnnotatedWith()) || - (target.isNamedFunction() || target.isKtClassOrObject()) && target.isAnnotatedWith() -> { + (target.isNamedFunction() || target.isKtClassOrObject()) && target.isAnnotatedWith() -> { expression.text.specialize(primitive) } (target.isCompanion() || target.isConstructor()) && target.containingDeclaration!!.isAnnotatedWith() -> { name.specialize(primitive) } + else -> null } } + fun getReplacement(expr: KtDotQualifiedExpression, primitive: Primitive<*, *>): String? { + val receiver = expr.receiverExpression + val sel = expr.selectorExpression ?: return null + if (sel !is KtCallExpression) return null + val args = sel.valueArguments + val callName = sel.calleeExpression?.text ?: return null + return if (callName == "into" && args.size == 3) { + val dest = args[0].text + val destOffset = args[1].text + val len = args[2].text + + val vecProcessor = VectorReplacementProcessor(primitive) + val (vecReplacement, linearReplacement, _) = vecProcessor.process(receiver, collector) ?: return "" + if (primitive.dataType in DataType.VECTORIZABLE.resolve()) + """ + val vectorSpecies = ${vecProcessor.vecSpecies} + val vectorLen = vectorSpecies.length() + val vecEnd = $len - ($len % vectorLen) + for (_vec_internal_idx in 0 until vecEnd step vectorLen) { + $vecReplacement.intoArray($dest, $destOffset + _vec_internal_idx) + } + for(_vec_internal_idx in vecEnd until $len) { + $dest[$destOffset + _vec_internal_idx] = $linearReplacement + } + """.trimIndent() + else + """ + for(_vec_internal_idx in vecEnd until $len) { + $dest[$destOffset + _vec_internal_idx] = $linearReplacement + }""".trimIndent() + } else if (callName == "reduce" && args.size == 2) { + val vecProcessor = VectorReplacementProcessor(primitive) + val opName = args[0].text + val len = args[1].text + val handle = VectorReplacementProcessor.vectorHandles[opName] ?: return "" + val neutral = vecProcessor.neutralElement[handle] ?: return "" + val (vecReplacement, linReplacement, _) = vecProcessor.process(receiver, collector) ?: return "" + val linearOp = VectorReplacementProcessor.binaryLinearReplacements[opName] ?: return "" + val linAccumulate = linearOp("ret", linReplacement) + if (primitive.dataType in DataType.VECTORIZABLE.resolve()) { + """{ + val vectorSpecies = ${vecProcessor.vecSpecies} + val vectorLen = vectorSpecies.length() + val vecEnd = $len - ($len % vectorLen) + var accumulator = ${vecProcessor.vecName}.broadcast(vectorSpecies, $neutral) + for (_vec_internal_idx in 0 until vecEnd step vectorLen) { + accumulator = accumulator.lanewise(VectorOperators.$handle, $vecReplacement) + } + var ret = accumulator.reduceLanes(VectorOperators.$handle) + for(_vec_internal_idx in vecEnd until $len) { + ret = $linAccumulate + } + ret + }.invoke() + """.trimIndent() + } else { + """{ + var ret = $neutral + for(_vec_internal_idx in vecEnd until $len) { + ret = $linAccumulate + } + ret + }.invoke() + """.trimIndent() + } + } else null + } + + fun shouldChangeVisibilityModifier(list: KtModifierList): Boolean { val owner = list.owner val visibilityModifier = list.visibilityModifier() diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt new file mode 100644 index 0000000..d36ba13 --- /dev/null +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -0,0 +1,137 @@ +package io.kinference.primitives.generator.processor + +import io.kinference.primitives.generator.Primitive +import org.jetbrains.kotlin.cli.common.messages.MessageCollector +import org.jetbrains.kotlin.psi.KtCallExpression +import org.jetbrains.kotlin.psi.KtExpression + +// TODO: merge linear and vector replacement functions, this is quadratic in op tree size + +internal class VectorReplacementProcessor(val primitive: Primitive<*, *>) { + val vecName = "${primitive.typeName}Vector" + val vecSpecies = "$vecName.SPECIES_PREFERRED" + val vecLen = "$vecName.length()" + + companion object { + val unaryLinearReplacements = mapOf( + "Exp" to { x: String -> "exp($x)" }, + "Abs" to { x: String -> "abs($x)" }, + "Neg" to { x: String -> "(-$x)" }, + "Log" to { x: String -> "ln($x)" }, + ).withDefault { null } + + val binaryLinearReplacements = mapOf( + "Add" to { x: String, y: String -> "($x + $y)" }, + "Sub" to { x: String, y: String -> "($x - $y)" }, + "Mul" to { x: String, y: String -> "($x * $y)" }, + "Div" to { x: String, y: String -> "($x / $y)" }, + "Max" to { x: String, y: String -> "max($x, $y)" }, + "Min" to { x: String, y: String -> "min($x, $y)" }, + "Pow" to { x: String, y: String -> "($x).pow($y)" }, + ).withDefault { null } + + val isAssoc = mapOf( + "ADD" to true, + "MUL" to true, + "MIN" to true, + "MAX" to true, + ).withDefault { false } + + val isCommutative = mapOf( + "ADD" to true, + "MUL" to true, + "MIN" to true, + "MAX" to true, + ).withDefault { false } + + val vectorHandles = mapOf( + "Add" to "ADD", + "Sub" to "SUB", + "Mul" to "MUL", + "Div" to "DIV", + "Exp" to "EXP", + "Max" to "MAX", + "Min" to "MIN", + "Abs" to "ABS", + "Log" to "LOG", + "Neg" to "NEG", + "Pow" to "POW" + ).withDefault { null } + + } + + val neutralElement = mapOf( + "ADD" to "0.${ReplacementProcessor.toType(primitive)}()", + "MUL" to "1.${ReplacementProcessor.toType(primitive)}()", + "MIN" to "${primitive.typeName}.MAX_VALUE", + "MAX" to "${primitive.typeName}.MIN_VALUE" + ).withDefault { null } + + fun process(expr: KtExpression?, collector: MessageCollector): Triple? { + if (expr == null) return null + if (expr !is KtCallExpression) return null + val args = expr.valueArguments + val name = expr.calleeExpression?.text ?: return null + return when (name) { + "UnaryOp" -> { + if (args.size != 2) return null + val childExpr = args[0].getArgumentExpression() + val (childVector, childLinear, isValue) = process(childExpr, collector) ?: return null + val handle = vectorHandles[args[1].text] ?: return null + val linReplace = unaryLinearReplacements[args[1].text] ?: return null + Triple( + """$childVector + .lanewise(VectorOperators.$handle)""".trimIndent(), + linReplace(childLinear), + isValue + ) + } + + "BinaryOp" -> { + if (args.size != 3) return null + val leftExpr = args[0].getArgumentExpression() ?: return null + val rightExpr = args[1].getArgumentExpression() ?: return null + val (leftVector, leftLinear, leftValue) = process(leftExpr, collector) ?: return null + val (rightVector, rightLinear, rightValue) = process(rightExpr, collector) ?: return null + val isValue = leftValue && rightValue + val handle = vectorHandles[args[2].text] ?: return null + val linear = binaryLinearReplacements[args[2].text]?.invoke(leftLinear, rightLinear) ?: return null + Triple( + if (isValue) + linear + else if (rightValue) + """$leftVector. + lanewise(VectorOperators.$handle, $rightLinear)""".trimIndent() + else if (leftValue && isCommutative[handle] == true) + """$rightVector.lanewise(VectorOperators.$handle, $leftLinear)""" + else if (leftValue) + """$vecName.broadcast($vecSpecies, $leftLinear).lanewise(VectorOperators.$handle, $rightVector)""" + else + """$leftVector + .lanewise(VectorOperators.$handle, $rightVector)""".trimIndent(), + linear, + isValue + ) + } + + "Value" -> { + if (args.size != 1) return null + val replacement = "${args[0].text}" + Triple(replacement, replacement, true) + } + + "PrimitiveSlice" -> { + if (args.size != 2) return null + val src = args[0].text + val offset = args[1].text + Triple( + "${vecName}.fromArray($vecSpecies, $src, $offset + _vec_internal_idx)", + "$src[$offset + _vec_internal_idx]", + false + ) + } + + else -> null + } + } +} From a332e378a38745c87aa2a6ae9d5917701f57dc69 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Tue, 22 Jul 2025 17:34:31 +0200 Subject: [PATCH 02/16] bugfix --- .../generator/processor/VectorReplacementProcessor.kt | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index d36ba13..0013ffb 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -121,9 +121,12 @@ internal class VectorReplacementProcessor(val primitive: Primitive<*, *>) { } "PrimitiveSlice" -> { - if (args.size != 2) return null + if (args.size != 2 && args.size != 1) return null val src = args[0].text - val offset = args[1].text + val offset = when (args.size) { + 2 -> args[1].text + else -> "0" + } Triple( "${vecName}.fromArray($vecSpecies, $src, $offset + _vec_internal_idx)", "$src[$offset + _vec_internal_idx]", From 0cf4e21332ac53bfb72dd193e3a133eaad739fa0 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Wed, 23 Jul 2025 13:49:44 +0200 Subject: [PATCH 03/16] changed vectorization interface + more type checking --- .../primitives/vector/OperationNode.kt | 68 +++++++------------ .../generator/PrimitiveGenerator.kt | 2 +- .../processor/ReplacementProcessor.kt | 56 ++++++++------- .../processor/VectorReplacementProcessor.kt | 61 ++++++++++------- 4 files changed, 95 insertions(+), 92 deletions(-) diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt index f8411d4..1db0936 100644 --- a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt @@ -4,63 +4,43 @@ package io.kinference.primitives.vector import io.kinference.primitives.types.PrimitiveArray import io.kinference.primitives.types.PrimitiveType +import io.kinference.primitives.types.toPrimitive sealed class OpNode() { - public fun into(dest: PrimitiveArray, offset: Int, len: Int) { + public fun into(dest: PrimitiveArray, offset: Int, len: Int): Nothing = throw UnsupportedOperationException() - } - public fun reduce(operation: AssociativeWrapper, len: Int): PrimitiveType { - throw UnsupportedOperationException() - } - - internal abstract val isValue: Boolean - //internal abstract fun linReplace(): String - //internal abstract fun vecReplace(): String -} - -class PrimitiveSlice(val src: PrimitiveArray, val offset: Int = 0) : OpNode() { - override val isValue: Boolean = false -} - -class UnaryOp(val arg: OpNode, val operation: UnaryWrapper) : OpNode() { - override val isValue: Boolean = arg.isValue -} - -class BinaryOp(val left: OpNode, val right: OpNode, val operation: BinaryWrapper) : OpNode() { - override val isValue: Boolean = left.isValue && right.isValue -} -class Value(val value: PrimitiveType) : OpNode() { - override val isValue: Boolean = true -} - -sealed class UnaryWrapper() { -} - -sealed class BinaryWrapper() { -} + public fun reduce(operation: AssociativeWrapper, len: Int): PrimitiveType = + throw UnsupportedOperationException() -sealed class AssociativeWrapper() : BinaryWrapper() { } -object Abs : UnaryWrapper() {} +final class PrimitiveSlice(val src: PrimitiveArray, val offset: Int = 0) : OpNode() {} -object Exp : UnaryWrapper() {} +final class Value(val value: PrimitiveType) : OpNode() {} -object Log : UnaryWrapper() {} +sealed class UnaryOp(val arg: OpNode) : OpNode() {} -object Neg : UnaryWrapper() {} +sealed class BinaryOp(val left: OpNode, val right: OpNode) : OpNode() {} -object Add : AssociativeWrapper() {} +sealed class AssociativeWrapper(){} -object Sub : BinaryWrapper() {} +class Exp(arg: OpNode): UnaryOp(arg){} -object Mul : AssociativeWrapper() {} +class Add(left: OpNode, right: OpNode): BinaryOp(left, right){} +class Neg(arg: OpNode): UnaryOp(arg){} +class Log(arg: OpNode): UnaryOp(arg){} -object Div : BinaryWrapper() {} +class Sub(left: OpNode, right: OpNode): BinaryOp(left, right){} +class Mul(left: OpNode, right: OpNode): BinaryOp(left, right){} +class Div(left: OpNode, right: OpNode): BinaryOp(left, right){} +class Pow(left: OpNode, right: OpNode): BinaryOp(left, right){} +class Max(left: OpNode, right: OpNode): BinaryOp(left, right){} +class Min(left: OpNode, right: OpNode): BinaryOp(left, right){} -object Pow : BinaryWrapper() {} +object ADD: AssociativeWrapper(){} +object MUL: AssociativeWrapper(){} +object MAX: AssociativeWrapper(){} -object Max : AssociativeWrapper() {} - -object Min : BinaryWrapper() {} +fun main(){ +} diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt index 1e293ea..216dbbd 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt @@ -210,7 +210,7 @@ internal class PrimitiveGenerator( return } val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) - if (replacement.isNullOrEmpty()) { + if (replacement == null) { super.visitDotQualifiedExpression(expression); return } else builder.append(replacement) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index dff2c85..000ad19 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -13,11 +13,13 @@ import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.com.intellij.openapi.util.Key import org.jetbrains.kotlin.com.intellij.psi.PsiElement import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.LeafPsiElement +import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.psi.psiUtil.isExtensionDeclaration import org.jetbrains.kotlin.psi.psiUtil.visibilityModifier import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe +import org.jetbrains.kotlin.types.typeUtil.supertypes internal class ReplacementProcessor(private val context: BindingContext, private val collector: MessageCollector) { companion object { @@ -46,7 +48,6 @@ internal class ReplacementProcessor(private val context: BindingContext, private (PrimitiveArray::class.qualifiedName!! + ".Companion") to { it.arrayTypeName } ) - private val nodeTypes = setOf("BinaryOp", "UnaryOp", "PrimitiveSlice", "Value") } @@ -99,19 +100,26 @@ internal class ReplacementProcessor(private val context: BindingContext, private fun getReplacement(expr: KtDotQualifiedExpression, primitive: Primitive<*, *>): String? { val receiver = expr.receiverExpression - val sel = expr.selectorExpression ?: return null - if (sel !is KtCallExpression) return null - val args = sel.valueArguments - val callName = sel.calleeExpression?.text ?: return null - return if (callName == "into" && args.size == 3) { + val selector = expr.selectorExpression ?: return null + if (selector !is KtCallExpression) return null + + val args = selector.valueArguments + val callName = selector.calleeExpression?.text ?: return null + + val receiverType = context.getType(receiver) ?: return null + val receiverSuperTypes = receiverType.supertypes().map { it.getKotlinTypeFqName(false) } + if(VectorReplacementProcessor.opNodeTypename !in receiverSuperTypes) return null + + val vecProcessor = VectorReplacementProcessor(context, primitive) + val (vecReplacement, linReplacement, isValue) = vecProcessor.process(receiver, collector)?: return "" + + if (callName == "into" && args.size == 3) { val dest = args[0].text val destOffset = args[1].text val len = args[2].text - val vecProcessor = VectorReplacementProcessor(primitive) - val (vecReplacement, linearReplacement, _) = vecProcessor.process(receiver, collector) ?: return "" - if (primitive.dataType in DataType.VECTORIZABLE.resolve()) - """ + if (primitive.dataType in DataType.VECTORIZABLE.resolve() && !isValue) + return """ val vectorSpecies = ${vecProcessor.vecSpecies} val vectorLen = vectorSpecies.length() val vecEnd = $len - ($len % vectorLen) @@ -119,25 +127,26 @@ internal class ReplacementProcessor(private val context: BindingContext, private $vecReplacement.intoArray($dest, $destOffset + _vec_internal_idx) } for(_vec_internal_idx in vecEnd until $len) { - $dest[$destOffset + _vec_internal_idx] = $linearReplacement + $dest[$destOffset + _vec_internal_idx] = $linReplacement } """.trimIndent() else - """ - for(_vec_internal_idx in vecEnd until $len) { - $dest[$destOffset + _vec_internal_idx] = $linearReplacement + return """ + for(_vec_internal_idx in 0 until $len) { + $dest[$destOffset + _vec_internal_idx] = $linReplacement }""".trimIndent() } else if (callName == "reduce" && args.size == 2) { - val vecProcessor = VectorReplacementProcessor(primitive) - val opName = args[0].text + val handle = args[0].text val len = args[1].text - val handle = VectorReplacementProcessor.vectorHandles[opName] ?: return "" + + if(VectorReplacementProcessor.isAssoc[handle] != true) return "" val neutral = vecProcessor.neutralElement[handle] ?: return "" - val (vecReplacement, linReplacement, _) = vecProcessor.process(receiver, collector) ?: return "" - val linearOp = VectorReplacementProcessor.binaryLinearReplacements[opName] ?: return "" + + val linearOp = VectorReplacementProcessor.binaryLinearReplacements[handle] ?: return "" val linAccumulate = linearOp("ret", linReplacement) + if (primitive.dataType in DataType.VECTORIZABLE.resolve()) { - """{ + return """{ val vectorSpecies = ${vecProcessor.vecSpecies} val vectorLen = vectorSpecies.length() val vecEnd = $len - ($len % vectorLen) @@ -153,16 +162,17 @@ internal class ReplacementProcessor(private val context: BindingContext, private }.invoke() """.trimIndent() } else { - """{ + return """{ var ret = $neutral - for(_vec_internal_idx in vecEnd until $len) { + for(_vec_internal_idx in 0 until $len) { ret = $linAccumulate } ret }.invoke() """.trimIndent() } - } else null + + } else return "" } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index 0013ffb..68231ad 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -1,33 +1,38 @@ package io.kinference.primitives.generator.processor import io.kinference.primitives.generator.Primitive +import io.kinference.primitives.vector.* import org.jetbrains.kotlin.cli.common.messages.MessageCollector +import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtExpression +import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.types.getAbbreviatedType +import org.jetbrains.kotlin.types.typeUtil.supertypes // TODO: merge linear and vector replacement functions, this is quadratic in op tree size -internal class VectorReplacementProcessor(val primitive: Primitive<*, *>) { +internal class VectorReplacementProcessor(private val context: BindingContext, val primitive: Primitive<*, *>) { val vecName = "${primitive.typeName}Vector" val vecSpecies = "$vecName.SPECIES_PREFERRED" val vecLen = "$vecName.length()" companion object { val unaryLinearReplacements = mapOf( - "Exp" to { x: String -> "exp($x)" }, - "Abs" to { x: String -> "abs($x)" }, - "Neg" to { x: String -> "(-$x)" }, - "Log" to { x: String -> "ln($x)" }, + "EXP" to { x: String -> "exp($x)" }, + "ABS" to { x: String -> "abs($x)" }, + "NEG" to { x: String -> "(-$x)" }, + "LOG" to { x: String -> "ln($x)" }, ).withDefault { null } val binaryLinearReplacements = mapOf( - "Add" to { x: String, y: String -> "($x + $y)" }, - "Sub" to { x: String, y: String -> "($x - $y)" }, - "Mul" to { x: String, y: String -> "($x * $y)" }, - "Div" to { x: String, y: String -> "($x / $y)" }, - "Max" to { x: String, y: String -> "max($x, $y)" }, - "Min" to { x: String, y: String -> "min($x, $y)" }, - "Pow" to { x: String, y: String -> "($x).pow($y)" }, + "ADD" to { x: String, y: String -> "($x + $y)" }, + "SUB" to { x: String, y: String -> "($x - $y)" }, + "MUL" to { x: String, y: String -> "($x * $y)" }, + "DIV" to { x: String, y: String -> "($x / $y)" }, + "MAX" to { x: String, y: String -> "max($x, $y)" }, + "MIN" to { x: String, y: String -> "min($x, $y)" }, + "POW" to { x: String, y: String -> "($x).pow($y)" }, ).withDefault { null } val isAssoc = mapOf( @@ -58,6 +63,12 @@ internal class VectorReplacementProcessor(val primitive: Primitive<*, *>) { "Pow" to "POW" ).withDefault { null } + val opNodeTypename = OpNode::class.qualifiedName + val unaryOpNames = UnaryOp::class.sealedSubclasses.map { it.qualifiedName } + val binaryOpNames = BinaryOp::class.sealedSubclasses.map { it.qualifiedName } + val valueType = Value::class.qualifiedName + val primitiveSliceType = PrimitiveSlice::class.qualifiedName + val associativeWrapperType = AssociativeWrapper::class.qualifiedName } val neutralElement = mapOf( @@ -69,16 +80,18 @@ internal class VectorReplacementProcessor(val primitive: Primitive<*, *>) { fun process(expr: KtExpression?, collector: MessageCollector): Triple? { if (expr == null) return null + val exprType = context.getType(expr) ?: return null + val exprTypename = exprType.getKotlinTypeFqName(false) + val shortName = exprTypename.substringAfterLast('.') if (expr !is KtCallExpression) return null val args = expr.valueArguments - val name = expr.calleeExpression?.text ?: return null - return when (name) { - "UnaryOp" -> { - if (args.size != 2) return null + return when { + exprTypename in unaryOpNames -> { + if (args.size != 1) return null val childExpr = args[0].getArgumentExpression() val (childVector, childLinear, isValue) = process(childExpr, collector) ?: return null - val handle = vectorHandles[args[1].text] ?: return null - val linReplace = unaryLinearReplacements[args[1].text] ?: return null + val handle = vectorHandles[shortName] ?: return null + val linReplace = unaryLinearReplacements[handle] ?: return null Triple( """$childVector .lanewise(VectorOperators.$handle)""".trimIndent(), @@ -87,15 +100,15 @@ internal class VectorReplacementProcessor(val primitive: Primitive<*, *>) { ) } - "BinaryOp" -> { - if (args.size != 3) return null + exprTypename in binaryOpNames -> { + if (args.size != 2) return null val leftExpr = args[0].getArgumentExpression() ?: return null val rightExpr = args[1].getArgumentExpression() ?: return null val (leftVector, leftLinear, leftValue) = process(leftExpr, collector) ?: return null val (rightVector, rightLinear, rightValue) = process(rightExpr, collector) ?: return null val isValue = leftValue && rightValue - val handle = vectorHandles[args[2].text] ?: return null - val linear = binaryLinearReplacements[args[2].text]?.invoke(leftLinear, rightLinear) ?: return null + val handle = vectorHandles[shortName] ?: return null + val linear = binaryLinearReplacements[handle]?.invoke(leftLinear, rightLinear) ?: return null Triple( if (isValue) linear @@ -114,13 +127,13 @@ internal class VectorReplacementProcessor(val primitive: Primitive<*, *>) { ) } - "Value" -> { + exprTypename == valueType -> { if (args.size != 1) return null val replacement = "${args[0].text}" Triple(replacement, replacement, true) } - "PrimitiveSlice" -> { + exprTypename == primitiveSliceType -> { if (args.size != 2 && args.size != 1) return null val src = args[0].text val offset = when (args.size) { From 4c59c36aac3807739c8679fc9d914c9ef19e6d61 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Thu, 24 Jul 2025 14:42:34 +0200 Subject: [PATCH 04/16] Added conditional expressions and masks --- .../kinference/primitives/types/DataType.kt | 2 +- .../primitives/vector/OperationNode.kt | 78 ++++++-- .../processor/ReplacementProcessor.kt | 9 +- .../processor/VectorReplacementProcessor.kt | 189 ++++++++++++++---- 4 files changed, 220 insertions(+), 58 deletions(-) diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt index 738d399..f469292 100644 --- a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt @@ -39,7 +39,7 @@ enum class DataType { * Primitives would remain the same */ fun resolve(): Set { - return when(this) { + return when (this) { ALL -> setOf(BYTE, SHORT, INT, LONG, UBYTE, USHORT, UINT, ULONG, FLOAT, DOUBLE, BOOLEAN) NUMBER -> setOf(BYTE, SHORT, INT, LONG, UBYTE, USHORT, UINT, ULONG, FLOAT, DOUBLE) VECTORIZABLE -> setOf(BYTE, SHORT, INT, LONG, FLOAT, DOUBLE) diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt index 1db0936..f19c5ab 100644 --- a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt @@ -5,6 +5,10 @@ package io.kinference.primitives.vector import io.kinference.primitives.types.PrimitiveArray import io.kinference.primitives.types.PrimitiveType import io.kinference.primitives.types.toPrimitive +import io.kinference.primitives.vector.Add +import io.kinference.primitives.vector.BinaryOp +import io.kinference.primitives.vector.Sub +import io.kinference.primitives.vector.UnaryOp sealed class OpNode() { public fun into(dest: PrimitiveArray, offset: Int, len: Int): Nothing = @@ -19,28 +23,72 @@ final class PrimitiveSlice(val src: PrimitiveArray, val offset: Int = 0) : OpNod final class Value(val value: PrimitiveType) : OpNode() {} -sealed class UnaryOp(val arg: OpNode) : OpNode() {} +sealed class UnaryOp(val arg: OpNode) : OpNode() { + constructor(arg: OpNode, mask: VecMask) : this(arg) +} -sealed class BinaryOp(val left: OpNode, val right: OpNode) : OpNode() {} +sealed class BinaryOp(val left: OpNode, val right: OpNode) : OpNode() { + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} -sealed class AssociativeWrapper(){} +class IfElse(val condition: VecMask, val left: OpNode, val right: OpNode) : OpNode() {} -class Exp(arg: OpNode): UnaryOp(arg){} +sealed class AssociativeWrapper(){} -class Add(left: OpNode, right: OpNode): BinaryOp(left, right){} -class Neg(arg: OpNode): UnaryOp(arg){} -class Log(arg: OpNode): UnaryOp(arg){} +class Exp(arg: OpNode): UnaryOp(arg){ + constructor(arg: OpNode, mask: VecMask) : this(arg) +} +class Abs(arg: OpNode): UnaryOp(arg){ + constructor(arg: OpNode, mask: VecMask) : this(arg) +} +class Neg(arg: OpNode): UnaryOp(arg){ + constructor(arg: OpNode, mask: VecMask) : this(arg) +} +class Log(arg: OpNode): UnaryOp(arg){ + constructor(arg: OpNode, mask: VecMask) : this(arg) +} -class Sub(left: OpNode, right: OpNode): BinaryOp(left, right){} -class Mul(left: OpNode, right: OpNode): BinaryOp(left, right){} -class Div(left: OpNode, right: OpNode): BinaryOp(left, right){} -class Pow(left: OpNode, right: OpNode): BinaryOp(left, right){} -class Max(left: OpNode, right: OpNode): BinaryOp(left, right){} -class Min(left: OpNode, right: OpNode): BinaryOp(left, right){} +class Add(left: OpNode, right: OpNode): BinaryOp(left, right){ + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} +class Sub(left: OpNode, right: OpNode): BinaryOp(left, right){ + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} +class Mul(left: OpNode, right: OpNode): BinaryOp(left, right){ + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} +class Div(left: OpNode, right: OpNode): BinaryOp(left, right){ + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} +class Pow(left: OpNode, right: OpNode): BinaryOp(left, right){ + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} +class Max(left: OpNode, right: OpNode): BinaryOp(left, right){ + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} +class Min(left: OpNode, right: OpNode): BinaryOp(left, right){ + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} object ADD: AssociativeWrapper(){} object MUL: AssociativeWrapper(){} object MAX: AssociativeWrapper(){} -fun main(){ -} +sealed class VecMask(){} + +sealed class Comparator(left: OpNode, right: OpNode): VecMask(){} +sealed class MaskBinaryOp(left: VecMask, right: VecMask): VecMask(){} +sealed class MaskUnaryOp(arg: VecMask): VecMask(){} + +class Not(arg: VecMask): VecMask(){} + +class Eq(left: OpNode, right: OpNode): Comparator(left, right){} +class Neq(left: OpNode, right: OpNode): Comparator(left, right){} +class LT(left: OpNode, right: OpNode): Comparator(left, right){} +class LE(left: OpNode, right: OpNode): Comparator(left, right){} +class GT(left: OpNode, right: OpNode): Comparator(left, right){} +class GE(left: OpNode, right: OpNode): Comparator(left, right){} + +class And(left: VecMask, right: VecMask): MaskBinaryOp(left, right){} +class Or(left: VecMask, right: VecMask): MaskBinaryOp(left, right){} +class Xor(left: VecMask, right: VecMask): MaskBinaryOp(left, right){} diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 000ad19..9b99819 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -6,6 +6,7 @@ import io.kinference.primitives.annotations.MakePublic import io.kinference.primitives.generator.* import io.kinference.primitives.generator.errors.require import io.kinference.primitives.types.* +import io.kinference.primitives.vector.* import io.kinference.primitives.utils.psi.forced import io.kinference.primitives.utils.psi.isAnnotatedWith import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity @@ -58,8 +59,6 @@ internal class ReplacementProcessor(private val context: BindingContext, private } return null } - if (klass.isAnnotatedWith(context)) - return "" return klass.specialize(primitive, collector) @@ -108,10 +107,10 @@ internal class ReplacementProcessor(private val context: BindingContext, private val receiverType = context.getType(receiver) ?: return null val receiverSuperTypes = receiverType.supertypes().map { it.getKotlinTypeFqName(false) } - if(VectorReplacementProcessor.opNodeTypename !in receiverSuperTypes) return null + if (VectorReplacementProcessor.opNodeTypename !in receiverSuperTypes) return null val vecProcessor = VectorReplacementProcessor(context, primitive) - val (vecReplacement, linReplacement, isValue) = vecProcessor.process(receiver, collector)?: return "" + val (vecReplacement, linReplacement, isValue) = vecProcessor.process(receiver, collector) ?: return "" if (callName == "into" && args.size == 3) { val dest = args[0].text @@ -139,7 +138,7 @@ internal class ReplacementProcessor(private val context: BindingContext, private val handle = args[0].text val len = args[1].text - if(VectorReplacementProcessor.isAssoc[handle] != true) return "" + if (VectorReplacementProcessor.isAssoc[handle] != true) return "" val neutral = vecProcessor.neutralElement[handle] ?: return "" val linearOp = VectorReplacementProcessor.binaryLinearReplacements[handle] ?: return "" diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index 68231ad..ecc7595 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -1,16 +1,15 @@ package io.kinference.primitives.generator.processor +import com.sun.org.apache.xpath.internal.operations.Bool import io.kinference.primitives.generator.Primitive import io.kinference.primitives.vector.* +import org.gradle.internal.impldep.org.h2.engine.Right import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtExpression import org.jetbrains.kotlin.resolve.BindingContext -import org.jetbrains.kotlin.types.getAbbreviatedType -import org.jetbrains.kotlin.types.typeUtil.supertypes -// TODO: merge linear and vector replacement functions, this is quadratic in op tree size internal class VectorReplacementProcessor(private val context: BindingContext, val primitive: Primitive<*, *>) { val vecName = "${primitive.typeName}Vector" @@ -42,13 +41,6 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v "MAX" to true, ).withDefault { false } - val isCommutative = mapOf( - "ADD" to true, - "MUL" to true, - "MIN" to true, - "MAX" to true, - ).withDefault { false } - val vectorHandles = mapOf( "Add" to "ADD", "Sub" to "SUB", @@ -60,15 +52,52 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v "Abs" to "ABS", "Log" to "LOG", "Neg" to "NEG", - "Pow" to "POW" + "Pow" to "POW", + ).withDefault { null } + + val maskHandles = mapOf( + "And" to "AND", + "Or" to "OR", + "Xor" to "XOR", + "Not" to "NOT", + "Eq" to "EQ", + "Neq" to "NEQ", + "LT" to "LT", + "LE" to "LE", + "GT" to "GT", + "GE" to "GE" + ).withDefault { null } + + val maskUnaryReplacement = mapOf( + "Not" to ({ x: String -> "$x.not()" }), ).withDefault { null } + val maskBinaryReplacement = mapOf( + "And" to ({ x: String, y: String -> "($x.and($y)" }), + "Or" to ({ x: String, y: String -> "$x.or($y)" }), + "Xor" to ({ x: String, y: String -> "$x.xor($y)" }), + ).withDefault { null } + + val comparatorReplacement = mapOf( + "Eq" to ({ x: String, y: String -> "($x == $y)" }), + "Neq" to ({ x: String, y: String -> "($x != $y)" }), + "LT" to ({ x: String, y: String -> "($x < $y)" }), + "LE" to ({ x: String, y: String -> "($x <= $y)" }), + "GT" to ({ x: String, y: String -> "($x > $y)" }), + "GE" to ({ x: String, y: String -> "($x >= $y)" }), + ) + val opNodeTypename = OpNode::class.qualifiedName val unaryOpNames = UnaryOp::class.sealedSubclasses.map { it.qualifiedName } val binaryOpNames = BinaryOp::class.sealedSubclasses.map { it.qualifiedName } val valueType = Value::class.qualifiedName val primitiveSliceType = PrimitiveSlice::class.qualifiedName val associativeWrapperType = AssociativeWrapper::class.qualifiedName + val maskTypes = VecMask::class.sealedSubclasses.map { it.qualifiedName } + val maskBinaryOpTypes = MaskBinaryOp::class.sealedSubclasses.map { it.qualifiedName } + val maskUnaryOpTypes = MaskUnaryOp::class.sealedSubclasses.map { it.qualifiedName } + val comparatorTypes = Comparator::class.sealedSubclasses.map { it.qualifiedName } + val ifElseType = IfElse::class.qualifiedName } val neutralElement = mapOf( @@ -87,41 +116,56 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v val args = expr.valueArguments return when { exprTypename in unaryOpNames -> { - if (args.size != 1) return null + if (args.size != 1 && args.size != 2) return null val childExpr = args[0].getArgumentExpression() - val (childVector, childLinear, isValue) = process(childExpr, collector) ?: return null + val masked = args.size == 2 + var (childVector, childLinear, isValue) = process(childExpr, collector) ?: return null val handle = vectorHandles[shortName] ?: return null val linReplace = unaryLinearReplacements[handle] ?: return null - Triple( - """$childVector - .lanewise(VectorOperators.$handle)""".trimIndent(), - linReplace(childLinear), - isValue - ) + var linear = linReplace(childLinear) + var vectorized = """$childVector + .lanewise(VectorOperators.$handle""".trimIndent() + if (masked) { + isValue = false + val maskExpr = args[1].getArgumentExpression() ?: return null + val (maskVector, maskLinear) = processMask(maskExpr, collector) ?: return null + linear = "(if($maskLinear) $linear else $childLinear)" + vectorized += ", $maskVector)" + } else { + vectorized += ")" + } + Triple(vectorized, linear, isValue) } exprTypename in binaryOpNames -> { - if (args.size != 2) return null + if (args.size != 2 && args.size != 3) return null + val masked = args.size == 3 val leftExpr = args[0].getArgumentExpression() ?: return null val rightExpr = args[1].getArgumentExpression() ?: return null + val handle = vectorHandles[shortName] ?: return null val (leftVector, leftLinear, leftValue) = process(leftExpr, collector) ?: return null val (rightVector, rightLinear, rightValue) = process(rightExpr, collector) ?: return null - val isValue = leftValue && rightValue - val handle = vectorHandles[shortName] ?: return null - val linear = binaryLinearReplacements[handle]?.invoke(leftLinear, rightLinear) ?: return null + var isValue = leftValue && rightValue + var linear = binaryLinearReplacements[handle]?.invoke(leftLinear, rightLinear) ?: return null + + var vectorized = if (rightValue) + """$leftVector. + lanewise(VectorOperators.$handle, $rightLinear""".trimIndent() + else + """$leftVector + .lanewise(VectorOperators.$handle, $rightVector""".trimIndent() + + if (masked) { + isValue = false + val maskExpr = args[2].getArgumentExpression() ?: return null + val (maskVector, maskLinear) = processMask(maskExpr, collector) ?: return null + linear = "(if($maskLinear) $linear else $leftLinear)" + vectorized += ", $maskVector)" + } else { + vectorized += ")" + } Triple( - if (isValue) - linear - else if (rightValue) - """$leftVector. - lanewise(VectorOperators.$handle, $rightLinear)""".trimIndent() - else if (leftValue && isCommutative[handle] == true) - """$rightVector.lanewise(VectorOperators.$handle, $leftLinear)""" - else if (leftValue) - """$vecName.broadcast($vecSpecies, $leftLinear).lanewise(VectorOperators.$handle, $rightVector)""" - else - """$leftVector - .lanewise(VectorOperators.$handle, $rightVector)""".trimIndent(), + vectorized, linear, isValue ) @@ -129,8 +173,9 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v exprTypename == valueType -> { if (args.size != 1) return null - val replacement = "${args[0].text}" - Triple(replacement, replacement, true) + val linear = "${args[0].text}" + val vectorized = "$vecName.broadcast($vecSpecies, $linear)" + Triple(vectorized, linear, true) } exprTypename == primitiveSliceType -> { @@ -147,6 +192,76 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v ) } + exprTypename == ifElseType -> { + if (args.size != 3) return null + val mask = args[0].getArgumentExpression() ?: return null + val left = args[1].getArgumentExpression() ?: return null + val right = args[2].getArgumentExpression() ?: return null + val (maskVector, maskLinear) = processMask(mask, collector) ?: return null + val (leftVector, leftLinear, leftValue) = process(left, collector) ?: return null + val (rightVector, rightLinear, rightValue) = process(right, collector) ?: return null + val isValue = leftValue && rightValue + val linear = "(if($maskLinear) $leftLinear else $rightLinear)" + val vectorized = """ + $rightVector.blend($leftVector, $maskVector)""".trimIndent() + Triple(vectorized, linear, isValue) + } + + else -> null + } + } + + fun processMask(expr: KtExpression?, collector: MessageCollector): Pair? { + if (expr == null) return null + val exprType = context.getType(expr) ?: return null + val exprTypename = exprType.getKotlinTypeFqName(false) + val shortName = exprTypename.substringAfterLast('.') + if (expr !is KtCallExpression) return null + val args = expr.valueArguments + return when { + exprTypename in maskUnaryOpTypes -> { + if (args.size != 1) return null + val child = args[0].getArgumentExpression() ?: return null + val (vecReplacement, linReplacement) = processMask(child, collector) ?: return null + val handle = maskHandles[shortName] ?: return null + val linReplacer = maskUnaryReplacement[handle] ?: return null + + Pair( + linReplacer(vecReplacement), + linReplacer(linReplacement) + ) + } + + exprTypename in maskBinaryOpTypes -> { + if (args.size != 2) return null + val left = args[0].getArgumentExpression() ?: return null + val (leftVecReplacement, leftLinReplacement) = processMask(left, collector) ?: return null + val right = args[0].getArgumentExpression() ?: return null + val (rightVecReplacement, rightLinReplacement) = processMask(right, collector) ?: return null + val handle = maskHandles[shortName] ?: return null + val linReplacer = maskBinaryReplacement[handle] ?: return null + Pair( + linReplacer(leftVecReplacement, rightVecReplacement), + linReplacer(leftLinReplacement, rightLinReplacement) + ) + } + + exprTypename in comparatorTypes -> { + if (args.size != 2) return null + val leftExpr = args[0].getArgumentExpression() ?: return null + val rightExpr = args[1].getArgumentExpression() ?: return null + val handle = maskHandles[shortName] ?: return null + val (leftVector, leftLinear, leftValue) = process(leftExpr, collector) ?: return null + val (rightVector, rightLinear, rightValue) = process(rightExpr, collector) ?: return null + val isValue = leftValue && rightValue + val linear = comparatorReplacement[handle]?.invoke(leftLinear, rightLinear) ?: return null + val vectorized = """ + $leftVector + .compare(VectorOperators.$handle, $rightVector) + """.trimIndent() + Pair(vectorized, linear) + } + else -> null } } From 5d551d75a0148003152099d2487e067a7aeeb123 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Thu, 24 Jul 2025 17:15:58 +0200 Subject: [PATCH 05/16] bugfix --- .../primitives/generator/processor/ReplacementProcessor.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 9b99819..7b69392 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -126,13 +126,13 @@ internal class ReplacementProcessor(private val context: BindingContext, private $vecReplacement.intoArray($dest, $destOffset + _vec_internal_idx) } for(_vec_internal_idx in vecEnd until $len) { - $dest[$destOffset + _vec_internal_idx] = $linReplacement + $dest[$destOffset + _vec_internal_idx] = $linReplacement.${toType(primitive)}() } """.trimIndent() else return """ for(_vec_internal_idx in 0 until $len) { - $dest[$destOffset + _vec_internal_idx] = $linReplacement + $dest[$destOffset + _vec_internal_idx] = $linReplacement.${toType(primitive)}() }""".trimIndent() } else if (callName == "reduce" && args.size == 2) { val handle = args[0].text From 452f5348e021f7fc389314e0c09b2101ad08c293 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Fri, 25 Jul 2025 12:22:14 +0200 Subject: [PATCH 06/16] Avoid redeclarations --- .../generator/PrimitiveGenerator.kt | 8 +++- .../processor/ReplacementProcessor.kt | 43 ++++++++++--------- .../processor/VectorReplacementProcessor.kt | 4 +- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt index 216dbbd..e6782a7 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt @@ -24,6 +24,7 @@ internal class PrimitiveGenerator( ) { private data class PrimitiveContext(val type1: Primitive<*, *>? = null, val type2: Primitive<*, *>? = null, val type3: Primitive<*, *>? = null) + private var vecCount = 0; fun generate(): Set { val results = HashSet() @@ -209,10 +210,13 @@ internal class PrimitiveGenerator( super.visitDotQualifiedExpression(expression) return } - val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) + val replacement = replacementProcessor.getReplacement(expression, currentPrimitive, vecCount) if (replacement == null) { super.visitDotQualifiedExpression(expression); return - } else builder.append(replacement) + } else { + vecCount += 1 + builder.append(replacement) + } } }) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 7b69392..31aeaea 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -97,7 +97,7 @@ internal class ReplacementProcessor(private val context: BindingContext, private } } - fun getReplacement(expr: KtDotQualifiedExpression, primitive: Primitive<*, *>): String? { + fun getReplacement(expr: KtDotQualifiedExpression, primitive: Primitive<*, *>, idx: Int): String? { val receiver = expr.receiverExpression val selector = expr.selectorExpression ?: return null if (selector !is KtCallExpression) return null @@ -112,6 +112,11 @@ internal class ReplacementProcessor(private val context: BindingContext, private val vecProcessor = VectorReplacementProcessor(context, primitive) val (vecReplacement, linReplacement, isValue) = vecProcessor.process(receiver, collector) ?: return "" + val toPrimitive = "${toType(primitive)}()" + val vecLen = "_vecLen_$idx" + val vecEnd = "_vecEnd_$idx" + val vecIdx = "_vec_internal_idx" + if (callName == "into" && args.size == 3) { val dest = args[0].text val destOffset = args[1].text @@ -119,20 +124,19 @@ internal class ReplacementProcessor(private val context: BindingContext, private if (primitive.dataType in DataType.VECTORIZABLE.resolve() && !isValue) return """ - val vectorSpecies = ${vecProcessor.vecSpecies} - val vectorLen = vectorSpecies.length() - val vecEnd = $len - ($len % vectorLen) - for (_vec_internal_idx in 0 until vecEnd step vectorLen) { + val $vecLen = ${vecProcessor.vecSpecies}.length() + val $vecEnd = $len - ($len % $vecLen) + for ($vecIdx in 0 until $vecEnd step $vecLen) { $vecReplacement.intoArray($dest, $destOffset + _vec_internal_idx) } - for(_vec_internal_idx in vecEnd until $len) { - $dest[$destOffset + _vec_internal_idx] = $linReplacement.${toType(primitive)}() + for($vecIdx in $vecEnd until $len) { + $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive } """.trimIndent() else return """ - for(_vec_internal_idx in 0 until $len) { - $dest[$destOffset + _vec_internal_idx] = $linReplacement.${toType(primitive)}() + for($vecIdx in 0 until $len) { + $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive }""".trimIndent() } else if (callName == "reduce" && args.size == 2) { val handle = args[0].text @@ -146,27 +150,26 @@ internal class ReplacementProcessor(private val context: BindingContext, private if (primitive.dataType in DataType.VECTORIZABLE.resolve()) { return """{ - val vectorSpecies = ${vecProcessor.vecSpecies} - val vectorLen = vectorSpecies.length() - val vecEnd = $len - ($len % vectorLen) - var accumulator = ${vecProcessor.vecName}.broadcast(vectorSpecies, $neutral) - for (_vec_internal_idx in 0 until vecEnd step vectorLen) { + val $vecLen = ${vecProcessor.vecSpecies}.length() + val $vecEnd = $len - ($len % $vecLen) + var accumulator = ${vecProcessor.vecName}.broadcast(${vecProcessor.vecSpecies}, $neutral) + for ($vecIdx in 0 until $vecEnd step $vecLen) { accumulator = accumulator.lanewise(VectorOperators.$handle, $vecReplacement) } var ret = accumulator.reduceLanes(VectorOperators.$handle) - for(_vec_internal_idx in vecEnd until $len) { - ret = $linAccumulate + for($vecIdx in $vecEnd until $len) { + ret = $linAccumulate.$toPrimitive } - ret + ret.$toPrimitive }.invoke() """.trimIndent() } else { return """{ var ret = $neutral - for(_vec_internal_idx in 0 until $len) { - ret = $linAccumulate + for($vecIdx in 0 until $len) { + ret = $linAccumulate.$toPrimitive } - ret + ret.$toPrimitive }.invoke() """.trimIndent() } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index ecc7595..55165d4 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -29,8 +29,8 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v "SUB" to { x: String, y: String -> "($x - $y)" }, "MUL" to { x: String, y: String -> "($x * $y)" }, "DIV" to { x: String, y: String -> "($x / $y)" }, - "MAX" to { x: String, y: String -> "max($x, $y)" }, - "MIN" to { x: String, y: String -> "min($x, $y)" }, + "MAX" to { x: String, y: String -> "maxOf($x, $y)" }, + "MIN" to { x: String, y: String -> "minOf($x, $y)" }, "POW" to { x: String, y: String -> "($x).pow($y)" }, ).withDefault { null } From 1fb8850f01f93b229706db8d2332786d5dc5dd35 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Fri, 25 Jul 2025 12:22:53 +0200 Subject: [PATCH 07/16] added min as associative operator --- .../kotlin/io/kinference/primitives/vector/OperationNode.kt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt index f19c5ab..0c81ac0 100644 --- a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt @@ -11,7 +11,7 @@ import io.kinference.primitives.vector.Sub import io.kinference.primitives.vector.UnaryOp sealed class OpNode() { - public fun into(dest: PrimitiveArray, offset: Int, len: Int): Nothing = + public fun into(dest: PrimitiveArray, offset: Int, len: Int): Unit = throw UnsupportedOperationException() public fun reduce(operation: AssociativeWrapper, len: Int): PrimitiveType = @@ -73,6 +73,7 @@ class Min(left: OpNode, right: OpNode): BinaryOp(left, right){ object ADD: AssociativeWrapper(){} object MUL: AssociativeWrapper(){} object MAX: AssociativeWrapper(){} +object MIN: AssociativeWrapper(){} sealed class VecMask(){} From 3146dbfcc71b7d3d57088ed66b9f1bcb3e259041 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Fri, 25 Jul 2025 18:36:49 +0200 Subject: [PATCH 08/16] experimental local variables support --- .../generator/PrimitiveGenerator.kt | 19 +++++++ .../processor/ReplacementProcessor.kt | 3 ++ .../processor/VectorReplacementProcessor.kt | 49 ++++++++++++++++++- 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt index e6782a7..3df3ac5 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt @@ -4,6 +4,7 @@ import io.kinference.primitives.annotations.* import io.kinference.primitives.generator.errors.require import io.kinference.primitives.generator.processor.RemovalProcessor import io.kinference.primitives.generator.processor.ReplacementProcessor +import io.kinference.primitives.generator.processor.VectorReplacementProcessor import io.kinference.primitives.types.DataType import io.kinference.primitives.utils.crossProduct import io.kinference.primitives.utils.psi.* @@ -11,11 +12,14 @@ import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.com.intellij.psi.PsiWhiteSpace import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.LeafPsiElement +import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.psi.psiUtil.visibilityModifier import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.types.typeUtil.supertypes import java.io.File +import java.util.Vector import kotlin.io.path.Path internal class PrimitiveGenerator( @@ -24,6 +28,7 @@ internal class PrimitiveGenerator( ) { private data class PrimitiveContext(val type1: Primitive<*, *>? = null, val type2: Primitive<*, *>? = null, val type3: Primitive<*, *>? = null) + private var vecCount = 0; fun generate(): Set { @@ -200,6 +205,20 @@ internal class PrimitiveGenerator( } } + override fun visitDeclaration(dcl: KtDeclaration) { + if (file.isAnnotatedWith(context) && dcl is KtProperty) { + val init = dcl.initializer + if (init != null) { + val type = context.getType(init) + if (type != null) { + val supertypes = type.supertypes().map { it.getKotlinTypeFqName(false) }.toSet() + if (VectorReplacementProcessor.opNodeTypename in supertypes) return + } + } + } + super.visitDeclaration(dcl) + } + override fun visitSimpleNameExpression(expression: KtSimpleNameExpression) { val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) builder.append(replacement ?: expression.text) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 31aeaea..77835ad 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -126,10 +126,13 @@ internal class ReplacementProcessor(private val context: BindingContext, private return """ val $vecLen = ${vecProcessor.vecSpecies}.length() val $vecEnd = $len - ($len % $vecLen) + ${vecProcessor.valueDeclarations} for ($vecIdx in 0 until $vecEnd step $vecLen) { + ${vecProcessor.vecDeclarations} $vecReplacement.intoArray($dest, $destOffset + _vec_internal_idx) } for($vecIdx in $vecEnd until $len) { + ${vecProcessor.linDeclarations} $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive } """.trimIndent() diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index 55165d4..3856491 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -3,12 +3,26 @@ package io.kinference.primitives.generator.processor import com.sun.org.apache.xpath.internal.operations.Bool import io.kinference.primitives.generator.Primitive import io.kinference.primitives.vector.* +import org.gradle.internal.impldep.com.esotericsoftware.kryo.serializers.FieldSerializer import org.gradle.internal.impldep.org.h2.engine.Right +import org.jetbrains.kotlin.cfg.getElementParentDeclaration import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtExpression +import org.jetbrains.kotlin.psi.KtSimpleNameExpression +import org.jetbrains.kotlin.psi.psiUtil.isContextualDeclaration import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.resolve.calls.util.getType +import org.jetbrains.kotlin.com.intellij.psi.PsiElement +import org.jetbrains.kotlin.descriptors.impl.referencedProperty +import org.jetbrains.kotlin.gradle.utils.loadPropertyFromResources +import org.jetbrains.kotlin.load.kotlin.toSourceElement +import org.jetbrains.kotlin.psi.KtDeclaration +import org.jetbrains.kotlin.psi.KtProperty +import org.jetbrains.kotlin.psi.declarationRecursiveVisitor +import org.jetbrains.kotlin.resolve.source.getPsi +import javax.naming.Binding internal class VectorReplacementProcessor(private val context: BindingContext, val primitive: Primitive<*, *>) { @@ -88,6 +102,7 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v ) val opNodeTypename = OpNode::class.qualifiedName + val opNodeTypes = OpNode::class.sealedSubclasses.map { it.qualifiedName } val unaryOpNames = UnaryOp::class.sealedSubclasses.map { it.qualifiedName } val binaryOpNames = BinaryOp::class.sealedSubclasses.map { it.qualifiedName } val valueType = Value::class.qualifiedName @@ -107,12 +122,42 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v "MAX" to "${primitive.typeName}.MIN_VALUE" ).withDefault { null } + var valueDeclarations: String = "" + var vecDeclarations: String = "" + var linDeclarations: String = "" + var localVariables: Set = emptySet() + + fun processDeclaration(expr: KtExpression, collector: MessageCollector): Triple? { + expr as? KtSimpleNameExpression ?: return null // Triple("NOT SIMPLE: ${expr.text}", "a", false) + val varName = expr.text + val descriptor = context.get(BindingContext.REFERENCE_TARGET, expr) ?: return null //Triple("NOT_DECL: ${expr.text}", "NOT_DECL", false) + val declaration = descriptor.toSourceElement.getPsi() ?: return null //Triple("NOT_DECL_PSI: ${expr.text}", "NOT_DECL", false) + declaration as? KtDeclaration ?: return null //Triple("NOT_DECL_EXPR: ${expr.text}", "NOT_DECL", false) + declaration as? KtProperty ?: return null // Triple("NOT_DECL_PROPERTY: ${expr.text}", "NOT_DECL", false) + val actualBody = declaration.initializer ?: return null //Triple("NOT_DECL_BODY: ${expr.text}", "NOT_DECL", false) + //return Triple("BODY: ${actualBody.text}", "", false) + val (vecReplacement, linReplacement, value) = process(actualBody, collector) ?: return null + if (varName !in localVariables) { + localVariables = localVariables + varName + if (value) { + valueDeclarations += "val ${varName}_vec = $vecReplacement\n" + valueDeclarations += "val ${varName}_lin = $linReplacement\n" + } else { + vecDeclarations += "val ${varName}_vec = $vecReplacement\n" + linDeclarations += "val ${varName}_lin = $linReplacement\n" + } + } + return Triple("${varName}_vec", "${varName}_lin", value) + } + fun process(expr: KtExpression?, collector: MessageCollector): Triple? { if (expr == null) return null - val exprType = context.getType(expr) ?: return null + if (expr !is KtCallExpression) { + return processDeclaration(expr, collector) + } + val exprType = context.getType(expr) ?: return Triple("NOT_TYPED", "NOT_TYPED", false) val exprTypename = exprType.getKotlinTypeFqName(false) val shortName = exprTypename.substringAfterLast('.') - if (expr !is KtCallExpression) return null val args = expr.valueArguments return when { exprTypename in unaryOpNames -> { From 2b8ad2fb00271b9a26f21b5fc6c18638c5389d50 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Wed, 30 Jul 2025 09:47:14 +0200 Subject: [PATCH 09/16] option to turn off vector API --- .../primitives/PrimitivesExtension.kt | 5 ++ .../primitives/PrimitivesGradlePlugin.kt | 3 +- .../kinference/primitives/PrimitivesTask.kt | 6 +- .../generator/PrimitiveGenerator.kt | 14 ++-- .../processor/ReplacementProcessor.kt | 75 ++++++++++++------- .../processor/VectorReplacementProcessor.kt | 54 ++++--------- 6 files changed, 85 insertions(+), 72 deletions(-) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt index c537ca6..13e86ba 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt @@ -1,7 +1,10 @@ package io.kinference.primitives +import com.sun.org.apache.xpath.internal.operations.Bool import org.gradle.api.Project import org.gradle.api.file.DirectoryProperty +import org.gradle.api.provider.Property +import org.jetbrains.kotlin.ir.declarations.DescriptorMetadataSource import javax.inject.Inject open class PrimitivesExtension @Inject constructor( @@ -9,6 +12,8 @@ open class PrimitivesExtension @Inject constructor( ) { private val objects = project.objects + var vectorize: Boolean = false + val generationPath: DirectoryProperty = objects.directoryProperty().convention( project.layout.buildDirectory.dir("generated/primitives") ) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt index 243f94f..b61afec 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt @@ -12,7 +12,6 @@ class PrimitivesGradlePlugin : Plugin { val primitivesExt = project.extensions.create(extensionName, PrimitivesExtension::class.java) - val primitivesCache = project.gradle.sharedServices.registerIfAbsent("${project.path}_${primitivesCacheName}", PrimitivesCache::class.java) { it.maxParallelUsages.set(1) } @@ -21,6 +20,7 @@ class PrimitivesGradlePlugin : Plugin { it.group = "generate" } + kotlinExt.sourceSets.all { sourceSet -> sourceSet.kotlin.srcDir(primitivesExt.generationPath.dir(sourceSet.name)) primitivesCache.get().sourceSetToResolved[sourceSet.name] = false @@ -40,6 +40,7 @@ class PrimitivesGradlePlugin : Plugin { primitiveTask.inputFiles.from(compileTask.sources) primitiveTask.libraries.from(compileTask.libraries) primitiveTask.compilation.set(compilation) + primitiveTask.vectorize.set(primitivesExt.vectorize) } compileTask.dependsOn(primitivesTask) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt index bc3b02c..b2ea59d 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt @@ -19,6 +19,9 @@ import org.jetbrains.kotlin.gradle.plugin.* import java.io.File abstract class PrimitivesTask : DefaultTask() { + @get:Input + abstract val vectorize: Property + @get:Internal abstract val generationPath: DirectoryProperty @@ -41,6 +44,7 @@ abstract class PrimitivesTask : DefaultTask() { init { group = "generate" description = "Generates primitives from sources" + vectorize.convention(false) } @TaskAction @@ -89,7 +93,7 @@ abstract class PrimitivesTask : DefaultTask() { val sourceSet = findSourceSetName(ktFile.virtualFilePath) val outputDir = generationPath.dir(sourceSet).get().asFile - PrimitiveGenerator(ktFile, result.bindingContext, outputDir, MessageCollector.NONE).generate() + PrimitiveGenerator(ktFile, result.bindingContext, outputDir, MessageCollector.NONE, vectorize.get()).generate() primitivesCache.get().resolvedPaths.add(ktFile.virtualFilePath) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt index 3df3ac5..cc29ee2 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt @@ -19,12 +19,13 @@ import org.jetbrains.kotlin.psi.psiUtil.visibilityModifier import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.types.typeUtil.supertypes import java.io.File -import java.util.Vector -import kotlin.io.path.Path internal class PrimitiveGenerator( - private val file: KtFile, private val context: BindingContext, private val output: File, - private val collector: MessageCollector + private val file: KtFile, + private val context: BindingContext, + private val output: File, + private val collector: MessageCollector, + private val vectorize: Boolean = false ) { private data class PrimitiveContext(val type1: Primitive<*, *>? = null, val type2: Primitive<*, *>? = null, val type3: Primitive<*, *>? = null) @@ -43,7 +44,7 @@ internal class PrimitiveGenerator( val builder = StringBuilder() val removalProcessor = RemovalProcessor(context) - val replacementProcessor = ReplacementProcessor(context, collector) + val replacementProcessor = ReplacementProcessor(context, collector, vectorize) file.accept(object : KtDefaultVisitor() { private var currentPrimitive = primitive @@ -69,7 +70,8 @@ internal class PrimitiveGenerator( } override fun visitImportList(importList: KtImportList) { - if (file.isAnnotatedWith(context) && primitive.dataType in DataType.VECTORIZABLE.resolve()) + if (file.isAnnotatedWith(context) && primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) + builder.appendLine("import io.kinference.ndarray.VecUtils.isModuleLoaded") builder.appendLine("import jdk.incubator.vector.*") super.visitImportList(importList) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 77835ad..7a0deaf 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -22,7 +22,11 @@ import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe import org.jetbrains.kotlin.types.typeUtil.supertypes -internal class ReplacementProcessor(private val context: BindingContext, private val collector: MessageCollector) { +internal class ReplacementProcessor( + private val context: BindingContext, + private val collector: MessageCollector, + private val vectorize: Boolean = false +) { companion object { internal fun toType(primitive: Primitive<*, *>): String { return when (primitive.dataType) { @@ -116,29 +120,38 @@ internal class ReplacementProcessor(private val context: BindingContext, private val vecLen = "_vecLen_$idx" val vecEnd = "_vecEnd_$idx" val vecIdx = "_vec_internal_idx" + val vecEnabled = "isModuleLoaded" if (callName == "into" && args.size == 3) { val dest = args[0].text val destOffset = args[1].text val len = args[2].text - if (primitive.dataType in DataType.VECTORIZABLE.resolve() && !isValue) + if (primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) return """ - val $vecLen = ${vecProcessor.vecSpecies}.length() - val $vecEnd = $len - ($len % $vecLen) - ${vecProcessor.valueDeclarations} - for ($vecIdx in 0 until $vecEnd step $vecLen) { - ${vecProcessor.vecDeclarations} - $vecReplacement.intoArray($dest, $destOffset + _vec_internal_idx) - } - for($vecIdx in $vecEnd until $len) { - ${vecProcessor.linDeclarations} - $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive - } + if($vecEnabled) { + val $vecLen = ${vecProcessor.vecSpecies}.length() + val $vecEnd = $len - ($len % $vecLen) + ${vecProcessor.valueDeclarations} + for ($vecIdx in 0 until $vecEnd step $vecLen) { + ${vecProcessor.vecDeclarations} + $vecReplacement.intoArray($dest, $destOffset + _vec_internal_idx) + } + for($vecIdx in $vecEnd until $len) { + ${vecProcessor.linDeclarations} + $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive + } + }else{ + for($vecIdx in 0 until $len) { + ${vecProcessor.linDeclarations} + $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive + } + } """.trimIndent() else return """ for($vecIdx in 0 until $len) { + ${vecProcessor.linDeclarations} $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive }""".trimIndent() } else if (callName == "reduce" && args.size == 2) { @@ -151,25 +164,37 @@ internal class ReplacementProcessor(private val context: BindingContext, private val linearOp = VectorReplacementProcessor.binaryLinearReplacements[handle] ?: return "" val linAccumulate = linearOp("ret", linReplacement) - if (primitive.dataType in DataType.VECTORIZABLE.resolve()) { + if (primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) { return """{ - val $vecLen = ${vecProcessor.vecSpecies}.length() - val $vecEnd = $len - ($len % $vecLen) - var accumulator = ${vecProcessor.vecName}.broadcast(${vecProcessor.vecSpecies}, $neutral) - for ($vecIdx in 0 until $vecEnd step $vecLen) { - accumulator = accumulator.lanewise(VectorOperators.$handle, $vecReplacement) - } - var ret = accumulator.reduceLanes(VectorOperators.$handle) - for($vecIdx in $vecEnd until $len) { - ret = $linAccumulate.$toPrimitive - } - ret.$toPrimitive + if($vecEnabled) { + val $vecLen = ${vecProcessor.vecSpecies}.length() + val $vecEnd = $len - ($len % $vecLen) + var accumulator = ${vecProcessor.vecName}.broadcast(${vecProcessor.vecSpecies}, $neutral) + ${vecProcessor.valueDeclarations} + for ($vecIdx in 0 until $vecEnd step $vecLen) { + ${vecProcessor.vecDeclarations} + accumulator = accumulator.lanewise(VectorOperators.$handle, $vecReplacement) + } + var ret = accumulator.reduceLanes(VectorOperators.$handle) + for($vecIdx in $vecEnd until $len) { + ${vecProcessor.linDeclarations} + ret = $linAccumulate.$toPrimitive + } + ret.$toPrimitive + }else{ + var ret = $neutral + for($vecIdx in 0 until $len) { + ${vecProcessor.linDeclarations} + ret = $linAccumulate.$toPrimitive + } + ret.$toPrimitive} }.invoke() """.trimIndent() } else { return """{ var ret = $neutral for($vecIdx in 0 until $len) { + ${vecProcessor.linDeclarations} ret = $linAccumulate.$toPrimitive } ret.$toPrimitive diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index 3856491..d6b3f8b 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -1,28 +1,20 @@ package io.kinference.primitives.generator.processor -import com.sun.org.apache.xpath.internal.operations.Bool import io.kinference.primitives.generator.Primitive import io.kinference.primitives.vector.* -import org.gradle.internal.impldep.com.esotericsoftware.kryo.serializers.FieldSerializer -import org.gradle.internal.impldep.org.h2.engine.Right -import org.jetbrains.kotlin.cfg.getElementParentDeclaration import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtExpression import org.jetbrains.kotlin.psi.KtSimpleNameExpression -import org.jetbrains.kotlin.psi.psiUtil.isContextualDeclaration import org.jetbrains.kotlin.resolve.BindingContext -import org.jetbrains.kotlin.resolve.calls.util.getType -import org.jetbrains.kotlin.com.intellij.psi.PsiElement -import org.jetbrains.kotlin.descriptors.impl.referencedProperty -import org.jetbrains.kotlin.gradle.utils.loadPropertyFromResources +import org.jetbrains.kotlin.idea.references.KtReference +import org.jetbrains.kotlin.idea.references.KtSimpleNameReference import org.jetbrains.kotlin.load.kotlin.toSourceElement import org.jetbrains.kotlin.psi.KtDeclaration import org.jetbrains.kotlin.psi.KtProperty -import org.jetbrains.kotlin.psi.declarationRecursiveVisitor +import org.jetbrains.kotlin.psi.KtVariableDeclaration import org.jetbrains.kotlin.resolve.source.getPsi -import javax.naming.Binding internal class VectorReplacementProcessor(private val context: BindingContext, val primitive: Primitive<*, *>) { @@ -32,7 +24,7 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v companion object { val unaryLinearReplacements = mapOf( - "EXP" to { x: String -> "exp($x)" }, + "EXP" to { x: String -> "FastMath.exp($x)" }, "ABS" to { x: String -> "abs($x)" }, "NEG" to { x: String -> "(-$x)" }, "LOG" to { x: String -> "ln($x)" }, @@ -70,16 +62,7 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v ).withDefault { null } val maskHandles = mapOf( - "And" to "AND", - "Or" to "OR", - "Xor" to "XOR", - "Not" to "NOT", - "Eq" to "EQ", - "Neq" to "NEQ", - "LT" to "LT", - "LE" to "LE", - "GT" to "GT", - "GE" to "GE" + "And" to "AND", "Or" to "OR", "Xor" to "XOR", "Not" to "NOT", "Eq" to "EQ", "Neq" to "NEQ", "LT" to "LT", "LE" to "LE", "GT" to "GT", "GE" to "GE" ).withDefault { null } val maskUnaryReplacement = mapOf( @@ -128,12 +111,13 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v var localVariables: Set = emptySet() fun processDeclaration(expr: KtExpression, collector: MessageCollector): Triple? { - expr as? KtSimpleNameExpression ?: return null // Triple("NOT SIMPLE: ${expr.text}", "a", false) + if (expr !is KtSimpleNameExpression) return null val varName = expr.text + val descriptor = context.get(BindingContext.REFERENCE_TARGET, expr) ?: return null //Triple("NOT_DECL: ${expr.text}", "NOT_DECL", false) val declaration = descriptor.toSourceElement.getPsi() ?: return null //Triple("NOT_DECL_PSI: ${expr.text}", "NOT_DECL", false) - declaration as? KtDeclaration ?: return null //Triple("NOT_DECL_EXPR: ${expr.text}", "NOT_DECL", false) - declaration as? KtProperty ?: return null // Triple("NOT_DECL_PROPERTY: ${expr.text}", "NOT_DECL", false) + + if (declaration !is KtVariableDeclaration) return null val actualBody = declaration.initializer ?: return null //Triple("NOT_DECL_BODY: ${expr.text}", "NOT_DECL", false) //return Triple("BODY: ${actualBody.text}", "", false) val (vecReplacement, linReplacement, value) = process(actualBody, collector) ?: return null @@ -193,11 +177,9 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v var isValue = leftValue && rightValue var linear = binaryLinearReplacements[handle]?.invoke(leftLinear, rightLinear) ?: return null - var vectorized = if (rightValue) - """$leftVector. + var vectorized = if (rightValue) """$leftVector. lanewise(VectorOperators.$handle, $rightLinear""".trimIndent() - else - """$leftVector + else """$leftVector .lanewise(VectorOperators.$handle, $rightVector""".trimIndent() if (masked) { @@ -210,9 +192,7 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v vectorized += ")" } Triple( - vectorized, - linear, - isValue + vectorized, linear, isValue ) } @@ -231,9 +211,7 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v else -> "0" } Triple( - "${vecName}.fromArray($vecSpecies, $src, $offset + _vec_internal_idx)", - "$src[$offset + _vec_internal_idx]", - false + "${vecName}.fromArray($vecSpecies, $src, $offset + _vec_internal_idx)", "$src[$offset + _vec_internal_idx]", false ) } @@ -272,8 +250,7 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v val linReplacer = maskUnaryReplacement[handle] ?: return null Pair( - linReplacer(vecReplacement), - linReplacer(linReplacement) + linReplacer(vecReplacement), linReplacer(linReplacement) ) } @@ -286,8 +263,7 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v val handle = maskHandles[shortName] ?: return null val linReplacer = maskBinaryReplacement[handle] ?: return null Pair( - linReplacer(leftVecReplacement, rightVecReplacement), - linReplacer(leftLinReplacement, rightLinReplacement) + linReplacer(leftVecReplacement, rightVecReplacement), linReplacer(leftLinReplacement, rightLinReplacement) ) } From 8529d306cddfada2cba5d18e04ffaecfb1425ab6 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Wed, 30 Jul 2025 17:54:46 +0200 Subject: [PATCH 10/16] better replacement, worse code --- .../generator/PrimitiveGenerator.kt | 10 +- .../kinference/primitives/generator/Utils.kt | 15 +- .../generator/processor/RemovalProcessor.kt | 1 + .../processor/ReplacementProcessor.kt | 12 +- .../processor/VectorReplacementProcessor.kt | 187 +++++++++++++----- 5 files changed, 155 insertions(+), 70 deletions(-) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt index cc29ee2..b55357a 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt @@ -72,7 +72,7 @@ internal class PrimitiveGenerator( override fun visitImportList(importList: KtImportList) { if (file.isAnnotatedWith(context) && primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) builder.appendLine("import io.kinference.ndarray.VecUtils.isModuleLoaded") - builder.appendLine("import jdk.incubator.vector.*") + builder.appendLine("import jdk.incubator.vector.*") super.visitImportList(importList) } @@ -210,13 +210,7 @@ internal class PrimitiveGenerator( override fun visitDeclaration(dcl: KtDeclaration) { if (file.isAnnotatedWith(context) && dcl is KtProperty) { val init = dcl.initializer - if (init != null) { - val type = context.getType(init) - if (type != null) { - val supertypes = type.supertypes().map { it.getKotlinTypeFqName(false) }.toSet() - if (VectorReplacementProcessor.opNodeTypename in supertypes) return - } - } + if (isVectorClass(init, context)) return } super.visitDeclaration(dcl) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt index 45602a3..7c2f4eb 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt @@ -2,16 +2,19 @@ package io.kinference.primitives.generator import io.kinference.primitives.annotations.* import io.kinference.primitives.generator.errors.require +import io.kinference.primitives.generator.processor.VectorReplacementProcessor import io.kinference.primitives.types.DataType import io.kinference.primitives.utils.psi.* import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor import org.jetbrains.kotlin.descriptors.DeclarationDescriptor +import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.js.resolve.diagnostics.findPsi import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.isValueClass +import org.jetbrains.kotlin.types.typeUtil.supertypes import kotlin.reflect.KProperty @@ -45,13 +48,15 @@ internal fun KtAnnotationEntry.isPluginAnnotation(context: BindingContext): Bool isAnnotation(context) || isAnnotation(context) || isAnnotation(context) || - isAnnotation(context) + isAnnotation(context) || + isAnnotation(context) } internal fun DeclarationDescriptor.isNamedFunction() = findPsi() is KtNamedFunction internal fun DeclarationDescriptor.isKtClassOrObject() = findPsi() is KtClassOrObject || isValueClass() internal fun DeclarationDescriptor.isCompanion() = findPsi() is KtObjectDeclaration && containingDeclaration?.findPsi() is KtClass -internal fun DeclarationDescriptor.isConstructor() = this is ClassConstructorDescriptor || findPsi() is KtConstructor<*> && containingDeclaration?.findPsi() is KtClass +internal fun DeclarationDescriptor.isConstructor() = + this is ClassConstructorDescriptor || findPsi() is KtConstructor<*> && containingDeclaration?.findPsi() is KtClass internal fun KtNamedDeclaration.specialize(primitive: Primitive<*, *>, collector: MessageCollector): String { val name = name!! @@ -63,3 +68,9 @@ internal fun KtNamedDeclaration.specialize(primitive: Primitive<*, *>, collector internal fun String.specialize(primitive: Primitive<*, *>) = replace("Primitive", primitive.typeName) +internal fun isVectorClass(expr: KtExpression?, context: BindingContext): Boolean { + if (expr == null) return false + val type = context.getType(expr) ?: return false + val fqSupertypes = type.supertypes().map { it.getKotlinTypeFqName(false) } + return VectorReplacementProcessor.opNodeTypename in fqSupertypes +} diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt index 6c75ccc..8aa6e21 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt @@ -25,6 +25,7 @@ internal class RemovalProcessor(private val context: BindingContext) { GenerateNameFromPrimitives::class.qualifiedName, GeneratePrimitives::class.qualifiedName, SpecifyPrimitives::class.qualifiedName, + GenerateVector::class.qualifiedName, ) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 7a0deaf..9b05020 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -109,12 +109,10 @@ internal class ReplacementProcessor( val args = selector.valueArguments val callName = selector.calleeExpression?.text ?: return null - val receiverType = context.getType(receiver) ?: return null - val receiverSuperTypes = receiverType.supertypes().map { it.getKotlinTypeFqName(false) } - if (VectorReplacementProcessor.opNodeTypename !in receiverSuperTypes) return null + if (!isVectorClass(receiver, context)) return null - val vecProcessor = VectorReplacementProcessor(context, primitive) - val (vecReplacement, linReplacement, isValue) = vecProcessor.process(receiver, collector) ?: return "" + val vecProcessor = VectorReplacementProcessor(context, primitive, collector) + val (vecReplacement, linReplacement, isScalar) = vecProcessor.process(receiver) ?: return null val toPrimitive = "${toType(primitive)}()" val vecLen = "_vecLen_$idx" @@ -132,7 +130,7 @@ internal class ReplacementProcessor( if($vecEnabled) { val $vecLen = ${vecProcessor.vecSpecies}.length() val $vecEnd = $len - ($len % $vecLen) - ${vecProcessor.valueDeclarations} + ${vecProcessor.scalarDeclarations} for ($vecIdx in 0 until $vecEnd step $vecLen) { ${vecProcessor.vecDeclarations} $vecReplacement.intoArray($dest, $destOffset + _vec_internal_idx) @@ -170,7 +168,7 @@ internal class ReplacementProcessor( val $vecLen = ${vecProcessor.vecSpecies}.length() val $vecEnd = $len - ($len % $vecLen) var accumulator = ${vecProcessor.vecName}.broadcast(${vecProcessor.vecSpecies}, $neutral) - ${vecProcessor.valueDeclarations} + ${vecProcessor.scalarDeclarations} for ($vecIdx in 0 until $vecEnd step $vecLen) { ${vecProcessor.vecDeclarations} accumulator = accumulator.lanewise(VectorOperators.$handle, $vecReplacement) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index d6b3f8b..fae975a 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -1,8 +1,26 @@ package io.kinference.primitives.generator.processor +import io.kinference.primitives.annotations.BindPrimitives +import io.kinference.primitives.annotations.GenerateVector +import io.kinference.primitives.annotations.SpecifyPrimitives import io.kinference.primitives.generator.Primitive +import io.kinference.primitives.generator.PrimitiveGenerator.PrimitiveContext +import io.kinference.primitives.generator.errors.require +import io.kinference.primitives.generator.getExcludes +import io.kinference.primitives.generator.getIncludes +import io.kinference.primitives.generator.getTypes +import io.kinference.primitives.generator.isVectorClass +import io.kinference.primitives.generator.toPrimitive +import io.kinference.primitives.types.DataType +import io.kinference.primitives.utils.crossProduct +import io.kinference.primitives.utils.psi.KtDefaultVisitor +import io.kinference.primitives.utils.psi.isAnnotatedWith +import io.kinference.primitives.utils.psi.isAnnotation import io.kinference.primitives.vector.* +import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity import org.jetbrains.kotlin.cli.common.messages.MessageCollector +import org.jetbrains.kotlin.com.intellij.psi.PsiWhiteSpace +import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.LeafPsiElement import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtExpression @@ -10,14 +28,27 @@ import org.jetbrains.kotlin.psi.KtSimpleNameExpression import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.idea.references.KtReference import org.jetbrains.kotlin.idea.references.KtSimpleNameReference +import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.load.kotlin.toSourceElement +import org.jetbrains.kotlin.psi.KtAnnotatedExpression +import org.jetbrains.kotlin.psi.KtAnnotationEntry +import org.jetbrains.kotlin.psi.KtClass +import org.jetbrains.kotlin.psi.KtClassOrObject import org.jetbrains.kotlin.psi.KtDeclaration +import org.jetbrains.kotlin.psi.KtDotQualifiedExpression +import org.jetbrains.kotlin.psi.KtElement +import org.jetbrains.kotlin.psi.KtImportDirective +import org.jetbrains.kotlin.psi.KtImportList +import org.jetbrains.kotlin.psi.KtModifierList +import org.jetbrains.kotlin.psi.KtNamedFunction import org.jetbrains.kotlin.psi.KtProperty +import org.jetbrains.kotlin.psi.KtTypeReference import org.jetbrains.kotlin.psi.KtVariableDeclaration +import org.jetbrains.kotlin.psi.psiUtil.visibilityModifier import org.jetbrains.kotlin.resolve.source.getPsi -internal class VectorReplacementProcessor(private val context: BindingContext, val primitive: Primitive<*, *>) { +internal class VectorReplacementProcessor(private val context: BindingContext, val primitive: Primitive<*, *>, val collector: MessageCollector) { val vecName = "${primitive.typeName}Vector" val vecSpecies = "$vecName.SPECIES_PREFERRED" val vecLen = "$vecName.length()" @@ -85,13 +116,10 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v ) val opNodeTypename = OpNode::class.qualifiedName - val opNodeTypes = OpNode::class.sealedSubclasses.map { it.qualifiedName } val unaryOpNames = UnaryOp::class.sealedSubclasses.map { it.qualifiedName } val binaryOpNames = BinaryOp::class.sealedSubclasses.map { it.qualifiedName } - val valueType = Value::class.qualifiedName + val scalarType = Value::class.qualifiedName val primitiveSliceType = PrimitiveSlice::class.qualifiedName - val associativeWrapperType = AssociativeWrapper::class.qualifiedName - val maskTypes = VecMask::class.sealedSubclasses.map { it.qualifiedName } val maskBinaryOpTypes = MaskBinaryOp::class.sealedSubclasses.map { it.qualifiedName } val maskUnaryOpTypes = MaskUnaryOp::class.sealedSubclasses.map { it.qualifiedName } val comparatorTypes = Comparator::class.sealedSubclasses.map { it.qualifiedName } @@ -105,12 +133,12 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v "MAX" to "${primitive.typeName}.MIN_VALUE" ).withDefault { null } - var valueDeclarations: String = "" + var scalarDeclarations: String = "" var vecDeclarations: String = "" var linDeclarations: String = "" var localVariables: Set = emptySet() - fun processDeclaration(expr: KtExpression, collector: MessageCollector): Triple? { + private fun processDeclaration(expr: KtExpression): Triple? { if (expr !is KtSimpleNameExpression) return null val varName = expr.text @@ -120,90 +148,91 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v if (declaration !is KtVariableDeclaration) return null val actualBody = declaration.initializer ?: return null //Triple("NOT_DECL_BODY: ${expr.text}", "NOT_DECL", false) //return Triple("BODY: ${actualBody.text}", "", false) - val (vecReplacement, linReplacement, value) = process(actualBody, collector) ?: return null + val (vecReplacement, linReplacement, scalar) = process(actualBody) ?: return null if (varName !in localVariables) { localVariables = localVariables + varName - if (value) { - valueDeclarations += "val ${varName}_vec = $vecReplacement\n" - valueDeclarations += "val ${varName}_lin = $linReplacement\n" + if (scalar) { + scalarDeclarations += "val ${varName}_vec = $vecReplacement\n" + scalarDeclarations += "val ${varName}_lin = $linReplacement\n" } else { vecDeclarations += "val ${varName}_vec = $vecReplacement\n" linDeclarations += "val ${varName}_lin = $linReplacement\n" } } - return Triple("${varName}_vec", "${varName}_lin", value) + return Triple("${varName}_vec", "${varName}_lin", scalar) } - fun process(expr: KtExpression?, collector: MessageCollector): Triple? { + + fun process(expr: KtExpression?): Triple? { if (expr == null) return null if (expr !is KtCallExpression) { - return processDeclaration(expr, collector) + return processDeclaration(expr) } - val exprType = context.getType(expr) ?: return Triple("NOT_TYPED", "NOT_TYPED", false) + val exprType = context.getType(expr) ?: return null val exprTypename = exprType.getKotlinTypeFqName(false) val shortName = exprTypename.substringAfterLast('.') val args = expr.valueArguments - return when { - exprTypename in unaryOpNames -> { + return when (exprTypename) { + in unaryOpNames -> { if (args.size != 1 && args.size != 2) return null val childExpr = args[0].getArgumentExpression() val masked = args.size == 2 - var (childVector, childLinear, isValue) = process(childExpr, collector) ?: return null + var (childVector, childLinear, isScalar) = process(childExpr) ?: return null val handle = vectorHandles[shortName] ?: return null val linReplace = unaryLinearReplacements[handle] ?: return null var linear = linReplace(childLinear) var vectorized = """$childVector - .lanewise(VectorOperators.$handle""".trimIndent() + .lanewise(VectorOperators.$handle""".trimIndent() if (masked) { - isValue = false - val maskExpr = args[1].getArgumentExpression() ?: return null - val (maskVector, maskLinear) = processMask(maskExpr, collector) ?: return null + isScalar = false + val maskExpr = args[1].getArgumentExpression() + val (maskVector, maskLinear) = processMask(maskExpr) ?: return null linear = "(if($maskLinear) $linear else $childLinear)" vectorized += ", $maskVector)" } else { vectorized += ")" } - Triple(vectorized, linear, isValue) + Triple(vectorized, linear, isScalar) } - exprTypename in binaryOpNames -> { + in binaryOpNames -> { if (args.size != 2 && args.size != 3) return null val masked = args.size == 3 - val leftExpr = args[0].getArgumentExpression() ?: return null - val rightExpr = args[1].getArgumentExpression() ?: return null + val leftExpr = args[0].getArgumentExpression() + val rightExpr = args[1].getArgumentExpression() val handle = vectorHandles[shortName] ?: return null - val (leftVector, leftLinear, leftValue) = process(leftExpr, collector) ?: return null - val (rightVector, rightLinear, rightValue) = process(rightExpr, collector) ?: return null - var isValue = leftValue && rightValue + val (leftVector, leftLinear, leftScalar) = process(leftExpr) ?: return null + val (rightVector, rightLinear, rightScalar) = process(rightExpr) ?: return null + var isScalar = leftScalar && rightScalar var linear = binaryLinearReplacements[handle]?.invoke(leftLinear, rightLinear) ?: return null - var vectorized = if (rightValue) """$leftVector. - lanewise(VectorOperators.$handle, $rightLinear""".trimIndent() + var vectorized = if (rightScalar) """$leftVector. + lanewise(VectorOperators.$handle, $rightLinear""".trimIndent() else """$leftVector - .lanewise(VectorOperators.$handle, $rightVector""".trimIndent() + .lanewise(VectorOperators.$handle, $rightVector""".trimIndent() if (masked) { - isValue = false - val maskExpr = args[2].getArgumentExpression() ?: return null - val (maskVector, maskLinear) = processMask(maskExpr, collector) ?: return null + isScalar = false + val maskExpr = args[2].getArgumentExpression() + val (maskVector, maskLinear) = processMask(maskExpr) ?: return null linear = "(if($maskLinear) $linear else $leftLinear)" vectorized += ", $maskVector)" } else { vectorized += ")" } Triple( - vectorized, linear, isValue + vectorized, linear, isScalar ) } - exprTypename == valueType -> { + scalarType -> { if (args.size != 1) return null - val linear = "${args[0].text}" + val linear = replaceLeaves(args[0].getArgumentExpression()?: return null) val vectorized = "$vecName.broadcast($vecSpecies, $linear)" Triple(vectorized, linear, true) } - exprTypename == primitiveSliceType -> { + primitiveSliceType -> { if (args.size != 2 && args.size != 1) return null val src = args[0].text val offset = when (args.size) { @@ -215,26 +244,26 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v ) } - exprTypename == ifElseType -> { + ifElseType -> { if (args.size != 3) return null val mask = args[0].getArgumentExpression() ?: return null val left = args[1].getArgumentExpression() ?: return null val right = args[2].getArgumentExpression() ?: return null - val (maskVector, maskLinear) = processMask(mask, collector) ?: return null - val (leftVector, leftLinear, leftValue) = process(left, collector) ?: return null - val (rightVector, rightLinear, rightValue) = process(right, collector) ?: return null - val isValue = leftValue && rightValue + val (maskVector, maskLinear) = processMask(mask) ?: return null + val (leftVector, leftLinear, leftScalar) = process(left) ?: return null + val (rightVector, rightLinear, rightScalar) = process(right) ?: return null + val isScalar = leftScalar && rightScalar val linear = "(if($maskLinear) $leftLinear else $rightLinear)" val vectorized = """ - $rightVector.blend($leftVector, $maskVector)""".trimIndent() - Triple(vectorized, linear, isValue) + $rightVector.blend($leftVector, $maskVector)""".trimIndent() + Triple(vectorized, linear, isScalar) } else -> null } } - fun processMask(expr: KtExpression?, collector: MessageCollector): Pair? { + fun processMask(expr: KtExpression?): Pair? { if (expr == null) return null val exprType = context.getType(expr) ?: return null val exprTypename = exprType.getKotlinTypeFqName(false) @@ -245,7 +274,7 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v exprTypename in maskUnaryOpTypes -> { if (args.size != 1) return null val child = args[0].getArgumentExpression() ?: return null - val (vecReplacement, linReplacement) = processMask(child, collector) ?: return null + val (vecReplacement, linReplacement) = processMask(child) ?: return null val handle = maskHandles[shortName] ?: return null val linReplacer = maskUnaryReplacement[handle] ?: return null @@ -257,9 +286,9 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v exprTypename in maskBinaryOpTypes -> { if (args.size != 2) return null val left = args[0].getArgumentExpression() ?: return null - val (leftVecReplacement, leftLinReplacement) = processMask(left, collector) ?: return null + val (leftVecReplacement, leftLinReplacement) = processMask(left) ?: return null val right = args[0].getArgumentExpression() ?: return null - val (rightVecReplacement, rightLinReplacement) = processMask(right, collector) ?: return null + val (rightVecReplacement, rightLinReplacement) = processMask(right) ?: return null val handle = maskHandles[shortName] ?: return null val linReplacer = maskBinaryReplacement[handle] ?: return null Pair( @@ -272,9 +301,9 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v val leftExpr = args[0].getArgumentExpression() ?: return null val rightExpr = args[1].getArgumentExpression() ?: return null val handle = maskHandles[shortName] ?: return null - val (leftVector, leftLinear, leftValue) = process(leftExpr, collector) ?: return null - val (rightVector, rightLinear, rightValue) = process(rightExpr, collector) ?: return null - val isValue = leftValue && rightValue + val (leftVector, leftLinear, leftScalar) = process(leftExpr) ?: return null + val (rightVector, rightLinear, rightScalar) = process(rightExpr) ?: return null + val isScalar = leftScalar && rightScalar val linear = comparatorReplacement[handle]?.invoke(leftLinear, rightLinear) ?: return null val vectorized = """ $leftVector @@ -286,4 +315,56 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v else -> null } } + + private fun replaceLeaves(expr: KtExpression): String { + val builder = StringBuilder() + expr.accept(object : KtDefaultVisitor() { + val replacementProcessor = ReplacementProcessor(context, collector) + private var currentPrimitive = primitive + + private fun KtElement.withPrimitive(primitive: Primitive<*, *>?, body: () -> Unit) { + collector.require(CompilerMessageSeverity.ERROR, this, primitive != null) { + "Primitive was bound with @${BindPrimitives::class.simpleName} sub-annotation," + + " but outer expression is not annotated with @${BindPrimitives::class.simpleName}" + } + + val tmp = currentPrimitive + currentPrimitive = primitive + body() + currentPrimitive = tmp + } + override fun visitClass(klass: KtClass) { + if (primitive.dataType in klass.getExcludes(context)) return + if (klass.isAnnotatedWith(context) && primitive.dataType !in klass.getIncludes(context)!!) return + + super.visitClass(klass) + } + + override fun visitLeafElement(element: LeafPsiElement) { + if (replacementProcessor.haveReplaceText(element)) { + builder.append(replacementProcessor.getReplacement(element)) + return + } + + if (element.elementType != KtTokens.IDENTIFIER) { + builder.append(element.text) + return + } + + when (val parent = element.parent) { + is KtClassOrObject -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) + is KtNamedFunction -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) + else -> builder.append(element.text) + } + } + + + override fun visitSimpleNameExpression(expression: KtSimpleNameExpression) { + val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) + builder.append(replacement ?: expression.text) + } + + }) + return builder.toString() + } } From 111622964e7504c6c02616cecc94dcdebab33c28 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Thu, 31 Jul 2025 10:52:44 +0200 Subject: [PATCH 11/16] refactoring --- .../generator/PrimitiveGenerator.kt | 201 +-------------- .../generator/processor/GenerationVisitor.kt | 244 ++++++++++++++++++ .../processor/ReplacementProcessor.kt | 5 +- .../processor/VectorReplacementProcessor.kt | 97 +------ .../primitives/utils/psi/KtDefaultVisitor.kt | 1 + 5 files changed, 264 insertions(+), 284 deletions(-) create mode 100644 plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt index b55357a..dd46cae 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt @@ -2,6 +2,7 @@ package io.kinference.primitives.generator import io.kinference.primitives.annotations.* import io.kinference.primitives.generator.errors.require +import io.kinference.primitives.generator.processor.GenerationVisitor import io.kinference.primitives.generator.processor.RemovalProcessor import io.kinference.primitives.generator.processor.ReplacementProcessor import io.kinference.primitives.generator.processor.VectorReplacementProcessor @@ -41,208 +42,18 @@ internal class PrimitiveGenerator( } for (primitive in types.flatMap { it.toPrimitive() }.toSet()) { - val builder = StringBuilder() + val visitor = GenerationVisitor(primitive, context, collector, file) + file.accept(visitor) + val text = visitor.text() - val removalProcessor = RemovalProcessor(context) - val replacementProcessor = ReplacementProcessor(context, collector, vectorize) - - file.accept(object : KtDefaultVisitor() { - private var currentPrimitive = primitive - - private fun KtElement.withPrimitive(primitive: Primitive<*, *>?, body: () -> Unit) { - collector.require(CompilerMessageSeverity.ERROR, this, primitive != null) { - "Primitive was bound with @${BindPrimitives::class.simpleName} sub-annotation," + - " but outer expression is not annotated with @${BindPrimitives::class.simpleName}" - } - - val tmp = currentPrimitive - currentPrimitive = primitive - body() - currentPrimitive = tmp - } - - private var primitiveContext = PrimitiveContext() - private fun withContext(context: PrimitiveContext, body: () -> Unit) { - val tmp = primitiveContext - primitiveContext = context - body() - primitiveContext = tmp - } - - override fun visitImportList(importList: KtImportList) { - if (file.isAnnotatedWith(context) && primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) - builder.appendLine("import io.kinference.ndarray.VecUtils.isModuleLoaded") - builder.appendLine("import jdk.incubator.vector.*") - super.visitImportList(importList) - } - - override fun visitModifierList(list: KtModifierList) { - if (replacementProcessor.shouldChangeVisibilityModifier(list)) { - replacementProcessor.prepareReplaceText(list.visibilityModifier(), "public") - } - - super.visitModifierList(list) - } - - override fun visitWhiteSpace(space: PsiWhiteSpace) { - if (removalProcessor.shouldRemoveWhiteSpace(space)) return - - super.visitWhiteSpace(space) - } - - override fun visitAnnotationEntry(annotationEntry: KtAnnotationEntry) { - if (removalProcessor.shouldRemoveAnnotation(annotationEntry)) { - removalProcessor.prepareRemoval(annotationEntry) - return - } - - builder.append(annotationEntry.text) - } - - override fun visitImportDirective(importDirective: KtImportDirective) { - if (removalProcessor.shouldRemoveImport(importDirective)) { - removalProcessor.prepareRemoval(importDirective) - return - } - - return super.visitImportDirective(importDirective) - } - - override fun visitNamedFunction(function: KtNamedFunction) { - if (primitive.dataType in function.getExcludes(context)) return - if (function.isAnnotatedWith(context) && primitive.dataType !in function.getIncludes(context)!!) return - - if (function.isAnnotatedWith(context)) { - for (annotation in function.annotationEntries.filter { it.isAnnotation(context) }) { - val primitives1 = annotation.getTypes(context, BindPrimitives::type1).flatMap { it.toPrimitive() }.toSet() - val primitives2 = annotation.getTypes(context, BindPrimitives::type2).flatMap { it.toPrimitive() }.toSet() - val primitives3 = annotation.getTypes(context, BindPrimitives::type3).flatMap { it.toPrimitive() }.toSet() - - - collector.require( - CompilerMessageSeverity.WARNING, annotation, - primitives1.isNotEmpty() || primitives2.isNotEmpty() || primitives3.isNotEmpty() - ) { - "All arguments of @${BindPrimitives::class.simpleName} are empty. It would lead to omitting of the function during generation." - } - - val combinations = crossProduct(primitives1, primitives2, primitives3) - - for (combination in combinations) { - var index = 0 - val primitive1 = if (primitives1.isEmpty()) null else combination[index++] - val primitive2 = if (primitives2.isEmpty()) null else combination[index++] - val primitive3 = if (primitives3.isEmpty()) null else combination[index] - - withContext(PrimitiveContext(primitive1, primitive2, primitive3)) { - super.visitNamedFunction(function) - } - builder.append('\n') - } - } - } else { - super.visitNamedFunction(function) - } - } - - override fun visitClass(klass: KtClass) { - if (primitive.dataType in klass.getExcludes(context)) return - if (klass.isAnnotatedWith(context) && primitive.dataType !in klass.getIncludes(context)!!) return - - super.visitClass(klass) - } - - override fun visitTypeReference(typeReference: KtTypeReference) { - when { - typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type1) { - super.visitTypeReference(typeReference) - } - - typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type2) { - super.visitTypeReference(typeReference) - } - - typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type3) { - super.visitTypeReference(typeReference) - } - - else -> super.visitTypeReference(typeReference) - } - } - - - override fun visitAnnotatedExpression(expression: KtAnnotatedExpression) { - when { - expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type1) { - super.visitAnnotatedExpression(expression) - } - - expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type2) { - super.visitAnnotatedExpression(expression) - } - - expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type3) { - super.visitAnnotatedExpression(expression) - } - - else -> super.visitAnnotatedExpression(expression) - } - } - - override fun visitLeafElement(element: LeafPsiElement) { - if (replacementProcessor.haveReplaceText(element)) { - builder.append(replacementProcessor.getReplacement(element)) - return - } - - if (element.elementType != KtTokens.IDENTIFIER) { - builder.append(element.text) - return - } - - when (val parent = element.parent) { - is KtClassOrObject -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) - is KtNamedFunction -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) - else -> builder.append(element.text) - } - } - - override fun visitDeclaration(dcl: KtDeclaration) { - if (file.isAnnotatedWith(context) && dcl is KtProperty) { - val init = dcl.initializer - if (isVectorClass(init, context)) return - } - super.visitDeclaration(dcl) - } - - override fun visitSimpleNameExpression(expression: KtSimpleNameExpression) { - val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) - builder.append(replacement ?: expression.text) - } - - override fun visitDotQualifiedExpression(expression: KtDotQualifiedExpression) { - if (!file.isAnnotatedWith(context)) { - super.visitDotQualifiedExpression(expression) - return - } - val replacement = replacementProcessor.getReplacement(expression, currentPrimitive, vecCount) - if (replacement == null) { - super.visitDotQualifiedExpression(expression); return - } else { - vecCount += 1 - builder.append(replacement) - } - } - }) - - if (builder.isNotBlank()) { + if (text.isNotBlank()) { val file = File( output, "${file.packageFqName.asString().replace('.', '/')}/${file.name.replace("Primitive", primitive.typeName)}" ) results.add(file) file.parentFile.mkdirs() - file.writeText(removalProcessor.reformat(builder.toString())) + file.writeText(visitor.removalProcessor.reformat(text)) } } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt new file mode 100644 index 0000000..5588118 --- /dev/null +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt @@ -0,0 +1,244 @@ +package io.kinference.primitives.generator.processor + +import io.kinference.primitives.annotations.BindPrimitives +import io.kinference.primitives.annotations.GenerateVector +import io.kinference.primitives.annotations.SpecifyPrimitives +import io.kinference.primitives.generator.Primitive +import io.kinference.primitives.generator.PrimitiveGenerator.PrimitiveContext +import io.kinference.primitives.generator.errors.require +import io.kinference.primitives.generator.getExcludes +import io.kinference.primitives.generator.getIncludes +import io.kinference.primitives.generator.getTypes +import io.kinference.primitives.generator.isVectorClass +import io.kinference.primitives.generator.toPrimitive +import io.kinference.primitives.types.DataType +import io.kinference.primitives.utils.crossProduct +import io.kinference.primitives.utils.psi.KtDefaultVisitor +import io.kinference.primitives.utils.psi.isAnnotatedWith +import io.kinference.primitives.utils.psi.isAnnotation +import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity +import org.jetbrains.kotlin.cli.common.messages.MessageCollector +import org.jetbrains.kotlin.com.intellij.psi.PsiWhiteSpace +import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.LeafPsiElement +import org.jetbrains.kotlin.lexer.KtTokens +import org.jetbrains.kotlin.psi.KtAnnotatedExpression +import org.jetbrains.kotlin.psi.KtAnnotationEntry +import org.jetbrains.kotlin.psi.KtClass +import org.jetbrains.kotlin.psi.KtClassOrObject +import org.jetbrains.kotlin.psi.KtDeclaration +import org.jetbrains.kotlin.psi.KtDotQualifiedExpression +import org.jetbrains.kotlin.psi.KtElement +import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.psi.KtImportDirective +import org.jetbrains.kotlin.psi.KtImportList +import org.jetbrains.kotlin.psi.KtModifierList +import org.jetbrains.kotlin.psi.KtNamedFunction +import org.jetbrains.kotlin.psi.KtProperty +import org.jetbrains.kotlin.psi.KtSimpleNameExpression +import org.jetbrains.kotlin.psi.KtTypeReference +import org.jetbrains.kotlin.psi.psiUtil.visibilityModifier +import org.jetbrains.kotlin.resolve.BindingContext + +internal class GenerationVisitor( + private val primitive: Primitive<*, *>, + private val context: BindingContext, + private val collector: MessageCollector, + private val file: KtFile +) : + KtDefaultVisitor() { + private data class PrimitiveContext(val type1: Primitive<*, *>? = null, val type2: Primitive<*, *>? = null, val type3: Primitive<*, *>? = null) + + private val vectorize = true + private val builder = StringBuilder() + fun text() = builder.toString() + val removalProcessor = RemovalProcessor(context) + val replacementProcessor = ReplacementProcessor(context, collector, file, vectorize) + private var currentPrimitive = primitive + private var vecCount = 0 + + private fun KtElement.withPrimitive(primitive: Primitive<*, *>?, body: () -> Unit) { + collector.require(CompilerMessageSeverity.ERROR, this, primitive != null) { + "Primitive was bound with @${BindPrimitives::class.simpleName} sub-annotation," + + " but outer expression is not annotated with @${BindPrimitives::class.simpleName}" + } + + val tmp = currentPrimitive + currentPrimitive = primitive + body() + currentPrimitive = tmp + } + + private var primitiveContext = PrimitiveContext() + private fun withContext(context: PrimitiveContext, body: () -> Unit) { + val tmp = primitiveContext + primitiveContext = context + body() + primitiveContext = tmp + } + + override fun visitImportList(importList: KtImportList) { + if (file.isAnnotatedWith(context) && primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) { + builder.appendLine("import io.kinference.ndarray.VecUtils.isModuleLoaded") + builder.appendLine("import jdk.incubator.vector.*") + } + super.visitImportList(importList) + } + + override fun visitModifierList(list: KtModifierList) { + if (replacementProcessor.shouldChangeVisibilityModifier(list)) { + replacementProcessor.prepareReplaceText(list.visibilityModifier(), "public") + } + + super.visitModifierList(list) + } + + override fun visitWhiteSpace(space: PsiWhiteSpace) { + if (removalProcessor.shouldRemoveWhiteSpace(space)) return + + super.visitWhiteSpace(space) + } + + override fun visitAnnotationEntry(annotationEntry: KtAnnotationEntry) { + if (removalProcessor.shouldRemoveAnnotation(annotationEntry)) { + removalProcessor.prepareRemoval(annotationEntry) + return + } + + builder.append(annotationEntry.text) + } + + override fun visitImportDirective(importDirective: KtImportDirective) { + if (removalProcessor.shouldRemoveImport(importDirective)) { + removalProcessor.prepareRemoval(importDirective) + return + } + + return super.visitImportDirective(importDirective) + } + + override fun visitNamedFunction(function: KtNamedFunction) { + if (primitive.dataType in function.getExcludes(context)) return + if (function.isAnnotatedWith(context) && primitive.dataType !in function.getIncludes(context)!!) return + + if (function.isAnnotatedWith(context)) { + for (annotation in function.annotationEntries.filter { it.isAnnotation(context) }) { + val primitives1 = annotation.getTypes(context, BindPrimitives::type1).flatMap { it.toPrimitive() }.toSet() + val primitives2 = annotation.getTypes(context, BindPrimitives::type2).flatMap { it.toPrimitive() }.toSet() + val primitives3 = annotation.getTypes(context, BindPrimitives::type3).flatMap { it.toPrimitive() }.toSet() + + + collector.require( + CompilerMessageSeverity.WARNING, annotation, + primitives1.isNotEmpty() || primitives2.isNotEmpty() || primitives3.isNotEmpty() + ) { + "All arguments of @${BindPrimitives::class.simpleName} are empty. It would lead to omitting of the function during generation." + } + + val combinations = crossProduct(primitives1, primitives2, primitives3) + + for (combination in combinations) { + var index = 0 + val primitive1 = if (primitives1.isEmpty()) null else combination[index++] + val primitive2 = if (primitives2.isEmpty()) null else combination[index++] + val primitive3 = if (primitives3.isEmpty()) null else combination[index] + + withContext(PrimitiveContext(primitive1, primitive2, primitive3)) { + super.visitNamedFunction(function) + } + builder.append('\n') + } + } + } else { + super.visitNamedFunction(function) + } + } + + override fun visitClass(klass: KtClass) { + if (primitive.dataType in klass.getExcludes(context)) return + if (klass.isAnnotatedWith(context) && primitive.dataType !in klass.getIncludes(context)!!) return + + super.visitClass(klass) + } + + override fun visitTypeReference(typeReference: KtTypeReference) { + when { + typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type1) { + super.visitTypeReference(typeReference) + } + + typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type2) { + super.visitTypeReference(typeReference) + } + + typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type3) { + super.visitTypeReference(typeReference) + } + + else -> super.visitTypeReference(typeReference) + } + } + + + override fun visitAnnotatedExpression(expression: KtAnnotatedExpression) { + when { + expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type1) { + super.visitAnnotatedExpression(expression) + } + + expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type2) { + super.visitAnnotatedExpression(expression) + } + + expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type3) { + super.visitAnnotatedExpression(expression) + } + + else -> super.visitAnnotatedExpression(expression) + } + } + + override fun visitLeafElement(element: LeafPsiElement) { + if (replacementProcessor.haveReplaceText(element)) { + builder.append(replacementProcessor.getReplacement(element)) + return + } + + if (element.elementType != KtTokens.IDENTIFIER) { + builder.append(element.text) + return + } + + when (val parent = element.parent) { + is KtClassOrObject -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) + is KtNamedFunction -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) + else -> builder.append(element.text) + } + } + + override fun visitDeclaration(dcl: KtDeclaration) { + if (file.isAnnotatedWith(context) && dcl is KtProperty) { + val init = dcl.initializer + if (isVectorClass(init, context)) return + } + super.visitDeclaration(dcl) + } + + override fun visitSimpleNameExpression(expression: KtSimpleNameExpression) { + val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) + builder.append(replacement ?: expression.text) + } + + override fun visitDotQualifiedExpression(expression: KtDotQualifiedExpression) { + if (!file.isAnnotatedWith(context)) { + super.visitDotQualifiedExpression(expression) + return + } + val replacement = replacementProcessor.getReplacement(expression, currentPrimitive, vecCount) + if (replacement == null) { + super.visitDotQualifiedExpression(expression); return + } else { + vecCount += 1 + builder.append(replacement) + } + } +} diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 9b05020..020a44a 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -25,7 +25,8 @@ import org.jetbrains.kotlin.types.typeUtil.supertypes internal class ReplacementProcessor( private val context: BindingContext, private val collector: MessageCollector, - private val vectorize: Boolean = false + private val file: KtFile, + private val vectorize: Boolean = true ) { companion object { internal fun toType(primitive: Primitive<*, *>): String { @@ -111,7 +112,7 @@ internal class ReplacementProcessor( if (!isVectorClass(receiver, context)) return null - val vecProcessor = VectorReplacementProcessor(context, primitive, collector) + val vecProcessor = VectorReplacementProcessor(context, primitive, collector, file) val (vecReplacement, linReplacement, isScalar) = vecProcessor.process(receiver) ?: return null val toPrimitive = "${toType(primitive)}()" diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index fae975a..8e1f4cb 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -1,54 +1,26 @@ package io.kinference.primitives.generator.processor -import io.kinference.primitives.annotations.BindPrimitives -import io.kinference.primitives.annotations.GenerateVector -import io.kinference.primitives.annotations.SpecifyPrimitives import io.kinference.primitives.generator.Primitive -import io.kinference.primitives.generator.PrimitiveGenerator.PrimitiveContext -import io.kinference.primitives.generator.errors.require -import io.kinference.primitives.generator.getExcludes -import io.kinference.primitives.generator.getIncludes -import io.kinference.primitives.generator.getTypes -import io.kinference.primitives.generator.isVectorClass -import io.kinference.primitives.generator.toPrimitive -import io.kinference.primitives.types.DataType -import io.kinference.primitives.utils.crossProduct -import io.kinference.primitives.utils.psi.KtDefaultVisitor -import io.kinference.primitives.utils.psi.isAnnotatedWith -import io.kinference.primitives.utils.psi.isAnnotation import io.kinference.primitives.vector.* -import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity import org.jetbrains.kotlin.cli.common.messages.MessageCollector -import org.jetbrains.kotlin.com.intellij.psi.PsiWhiteSpace -import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.LeafPsiElement import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtExpression import org.jetbrains.kotlin.psi.KtSimpleNameExpression import org.jetbrains.kotlin.resolve.BindingContext -import org.jetbrains.kotlin.idea.references.KtReference -import org.jetbrains.kotlin.idea.references.KtSimpleNameReference -import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.load.kotlin.toSourceElement -import org.jetbrains.kotlin.psi.KtAnnotatedExpression -import org.jetbrains.kotlin.psi.KtAnnotationEntry -import org.jetbrains.kotlin.psi.KtClass -import org.jetbrains.kotlin.psi.KtClassOrObject -import org.jetbrains.kotlin.psi.KtDeclaration -import org.jetbrains.kotlin.psi.KtDotQualifiedExpression import org.jetbrains.kotlin.psi.KtElement -import org.jetbrains.kotlin.psi.KtImportDirective -import org.jetbrains.kotlin.psi.KtImportList -import org.jetbrains.kotlin.psi.KtModifierList -import org.jetbrains.kotlin.psi.KtNamedFunction -import org.jetbrains.kotlin.psi.KtProperty -import org.jetbrains.kotlin.psi.KtTypeReference +import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.kotlin.psi.KtVariableDeclaration -import org.jetbrains.kotlin.psi.psiUtil.visibilityModifier import org.jetbrains.kotlin.resolve.source.getPsi -internal class VectorReplacementProcessor(private val context: BindingContext, val primitive: Primitive<*, *>, val collector: MessageCollector) { +internal class VectorReplacementProcessor( + private val context: BindingContext, + private val primitive: Primitive<*, *>, + private val collector: MessageCollector, + private val file: KtFile, +) { val vecName = "${primitive.typeName}Vector" val vecSpecies = "$vecName.SPECIES_PREFERRED" val vecLen = "$vecName.length()" @@ -227,7 +199,9 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v scalarType -> { if (args.size != 1) return null - val linear = replaceLeaves(args[0].getArgumentExpression()?: return null) + val visitor = GenerationVisitor(primitive, context, collector, file) + args[0].accept(visitor) + val linear = visitor.text() val vectorized = "$vecName.broadcast($vecSpecies, $linear)" Triple(vectorized, linear, true) } @@ -316,55 +290,4 @@ internal class VectorReplacementProcessor(private val context: BindingContext, v } } - private fun replaceLeaves(expr: KtExpression): String { - val builder = StringBuilder() - expr.accept(object : KtDefaultVisitor() { - val replacementProcessor = ReplacementProcessor(context, collector) - private var currentPrimitive = primitive - - private fun KtElement.withPrimitive(primitive: Primitive<*, *>?, body: () -> Unit) { - collector.require(CompilerMessageSeverity.ERROR, this, primitive != null) { - "Primitive was bound with @${BindPrimitives::class.simpleName} sub-annotation," + - " but outer expression is not annotated with @${BindPrimitives::class.simpleName}" - } - - val tmp = currentPrimitive - currentPrimitive = primitive - body() - currentPrimitive = tmp - } - override fun visitClass(klass: KtClass) { - if (primitive.dataType in klass.getExcludes(context)) return - if (klass.isAnnotatedWith(context) && primitive.dataType !in klass.getIncludes(context)!!) return - - super.visitClass(klass) - } - - override fun visitLeafElement(element: LeafPsiElement) { - if (replacementProcessor.haveReplaceText(element)) { - builder.append(replacementProcessor.getReplacement(element)) - return - } - - if (element.elementType != KtTokens.IDENTIFIER) { - builder.append(element.text) - return - } - - when (val parent = element.parent) { - is KtClassOrObject -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) - is KtNamedFunction -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) - else -> builder.append(element.text) - } - } - - - override fun visitSimpleNameExpression(expression: KtSimpleNameExpression) { - val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) - builder.append(replacement ?: expression.text) - } - - }) - return builder.toString() - } } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/utils/psi/KtDefaultVisitor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/utils/psi/KtDefaultVisitor.kt index ea693f4..b0b4d29 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/utils/psi/KtDefaultVisitor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/utils/psi/KtDefaultVisitor.kt @@ -2,6 +2,7 @@ package io.kinference.primitives.utils.psi import org.jetbrains.kotlin.com.intellij.psi.PsiElement import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.LeafPsiElement +import org.jetbrains.kotlin.psi.KtTreeVisitorVoid import org.jetbrains.kotlin.psi.KtVisitorVoid internal abstract class KtDefaultVisitor : KtVisitorVoid() { From c64d8caf4aaedef674159cdb93a6bd339c7acc89 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Tue, 5 Aug 2025 12:12:11 +0200 Subject: [PATCH 12/16] added sqrt --- .../primitives/vector/OperationNode.kt | 78 ++++++++++++------- .../kinference/primitives/generator/Utils.kt | 11 +++ .../processor/VectorReplacementProcessor.kt | 29 ++++--- 3 files changed, 73 insertions(+), 45 deletions(-) diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt index 0c81ac0..478734d 100644 --- a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt @@ -33,63 +33,81 @@ sealed class BinaryOp(val left: OpNode, val right: OpNode) : OpNode() { class IfElse(val condition: VecMask, val left: OpNode, val right: OpNode) : OpNode() {} -sealed class AssociativeWrapper(){} +sealed class AssociativeWrapper() {} -class Exp(arg: OpNode): UnaryOp(arg){ +class Exp(arg: OpNode) : UnaryOp(arg) { constructor(arg: OpNode, mask: VecMask) : this(arg) } -class Abs(arg: OpNode): UnaryOp(arg){ + +class Abs(arg: OpNode) : UnaryOp(arg) { + constructor(arg: OpNode, mask: VecMask) : this(arg) +} + +class Neg(arg: OpNode) : UnaryOp(arg) { constructor(arg: OpNode, mask: VecMask) : this(arg) } -class Neg(arg: OpNode): UnaryOp(arg){ + +class Log(arg: OpNode) : UnaryOp(arg) { + constructor(arg: OpNode, mask: VecMask) : this(arg) +} + +class Sqrt(arg: OpNode) : UnaryOp(arg) { constructor(arg: OpNode, mask: VecMask) : this(arg) } -class Log(arg: OpNode): UnaryOp(arg){ + +class Cbrt(arg: OpNode) : UnaryOp(arg) { constructor(arg: OpNode, mask: VecMask) : this(arg) } -class Add(left: OpNode, right: OpNode): BinaryOp(left, right){ +class Add(left: OpNode, right: OpNode) : BinaryOp(left, right) { constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) } -class Sub(left: OpNode, right: OpNode): BinaryOp(left, right){ + +class Sub(left: OpNode, right: OpNode) : BinaryOp(left, right) { constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) } -class Mul(left: OpNode, right: OpNode): BinaryOp(left, right){ + +class Mul(left: OpNode, right: OpNode) : BinaryOp(left, right) { constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) } -class Div(left: OpNode, right: OpNode): BinaryOp(left, right){ + +class Div(left: OpNode, right: OpNode) : BinaryOp(left, right) { constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) } -class Pow(left: OpNode, right: OpNode): BinaryOp(left, right){ + +class Pow(left: OpNode, right: OpNode) : BinaryOp(left, right) { constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) } -class Max(left: OpNode, right: OpNode): BinaryOp(left, right){ + +class Max(left: OpNode, right: OpNode) : BinaryOp(left, right) { constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) } -class Min(left: OpNode, right: OpNode): BinaryOp(left, right){ + +class Min(left: OpNode, right: OpNode) : BinaryOp(left, right) { constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) } -object ADD: AssociativeWrapper(){} -object MUL: AssociativeWrapper(){} -object MAX: AssociativeWrapper(){} -object MIN: AssociativeWrapper(){} -sealed class VecMask(){} +object ADD : AssociativeWrapper() {} +object MUL : AssociativeWrapper() {} +object MAX : AssociativeWrapper() {} +object MIN : AssociativeWrapper() {} + +sealed class VecMask() {} -sealed class Comparator(left: OpNode, right: OpNode): VecMask(){} -sealed class MaskBinaryOp(left: VecMask, right: VecMask): VecMask(){} -sealed class MaskUnaryOp(arg: VecMask): VecMask(){} +sealed class Comparator(left: OpNode, right: OpNode) : VecMask() {} +sealed class MaskBinaryOp(left: VecMask, right: VecMask) : VecMask() {} +sealed class MaskUnaryOp(arg: VecMask) : VecMask() {} -class Not(arg: VecMask): VecMask(){} +class Not(arg: VecMask) : VecMask() {} -class Eq(left: OpNode, right: OpNode): Comparator(left, right){} -class Neq(left: OpNode, right: OpNode): Comparator(left, right){} -class LT(left: OpNode, right: OpNode): Comparator(left, right){} -class LE(left: OpNode, right: OpNode): Comparator(left, right){} -class GT(left: OpNode, right: OpNode): Comparator(left, right){} -class GE(left: OpNode, right: OpNode): Comparator(left, right){} +class Eq(left: OpNode, right: OpNode) : Comparator(left, right) {} +class Neq(left: OpNode, right: OpNode) : Comparator(left, right) {} +class LT(left: OpNode, right: OpNode) : Comparator(left, right) {} +class LE(left: OpNode, right: OpNode) : Comparator(left, right) {} +class GT(left: OpNode, right: OpNode) : Comparator(left, right) {} +class GE(left: OpNode, right: OpNode) : Comparator(left, right) {} -class And(left: VecMask, right: VecMask): MaskBinaryOp(left, right){} -class Or(left: VecMask, right: VecMask): MaskBinaryOp(left, right){} -class Xor(left: VecMask, right: VecMask): MaskBinaryOp(left, right){} +class And(left: VecMask, right: VecMask) : MaskBinaryOp(left, right) {} +class Or(left: VecMask, right: VecMask) : MaskBinaryOp(left, right) {} +class Xor(left: VecMask, right: VecMask) : MaskBinaryOp(left, right) {} diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt index 7c2f4eb..f78850a 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt @@ -11,9 +11,11 @@ import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor import org.jetbrains.kotlin.descriptors.DeclarationDescriptor import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.js.resolve.diagnostics.findPsi +import org.jetbrains.kotlin.load.kotlin.toSourceElement import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.isValueClass +import org.jetbrains.kotlin.resolve.source.getPsi import org.jetbrains.kotlin.types.typeUtil.supertypes import kotlin.reflect.KProperty @@ -58,6 +60,15 @@ internal fun DeclarationDescriptor.isCompanion() = findPsi() is KtObjectDeclarat internal fun DeclarationDescriptor.isConstructor() = this is ClassConstructorDescriptor || findPsi() is KtConstructor<*> && containingDeclaration?.findPsi() is KtClass +internal fun KtSimpleNameExpression.initializer(context: BindingContext): KtExpression? { + val descriptor = context.get(BindingContext.REFERENCE_TARGET, this) ?: return null + val declaration = descriptor.toSourceElement.getPsi() ?: return null + if (declaration !is KtProperty) return null + return declaration.initializer +} + +internal fun KtExpression.fqTypename(context: BindingContext): String? = context.getType(this)?.getKotlinTypeFqName(false) + internal fun KtNamedDeclaration.specialize(primitive: Primitive<*, *>, collector: MessageCollector): String { val name = name!! collector.require(CompilerMessageSeverity.WARNING, this, "Primitive" in name) { diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index 8e1f4cb..cdd7863 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -1,6 +1,8 @@ package io.kinference.primitives.generator.processor import io.kinference.primitives.generator.Primitive +import io.kinference.primitives.generator.fqTypename +import io.kinference.primitives.generator.initializer import io.kinference.primitives.vector.* import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName @@ -31,6 +33,8 @@ internal class VectorReplacementProcessor( "ABS" to { x: String -> "abs($x)" }, "NEG" to { x: String -> "(-$x)" }, "LOG" to { x: String -> "ln($x)" }, + "SQRT" to { x: String -> "sqrt($x)" }, + "CBRT" to { x: String -> "cbrt($x)" }, ).withDefault { null } val binaryLinearReplacements = mapOf( @@ -62,6 +66,8 @@ internal class VectorReplacementProcessor( "Log" to "LOG", "Neg" to "NEG", "Pow" to "POW", + "Sqrt" to "SQRT", + "Cbrt" to "CBRT", ).withDefault { null } val maskHandles = mapOf( @@ -110,16 +116,11 @@ internal class VectorReplacementProcessor( var linDeclarations: String = "" var localVariables: Set = emptySet() - private fun processDeclaration(expr: KtExpression): Triple? { + private fun processDeclaration(expr: KtExpression?): Triple? { if (expr !is KtSimpleNameExpression) return null val varName = expr.text - val descriptor = context.get(BindingContext.REFERENCE_TARGET, expr) ?: return null //Triple("NOT_DECL: ${expr.text}", "NOT_DECL", false) - val declaration = descriptor.toSourceElement.getPsi() ?: return null //Triple("NOT_DECL_PSI: ${expr.text}", "NOT_DECL", false) - - if (declaration !is KtVariableDeclaration) return null - val actualBody = declaration.initializer ?: return null //Triple("NOT_DECL_BODY: ${expr.text}", "NOT_DECL", false) - //return Triple("BODY: ${actualBody.text}", "", false) + val actualBody = expr.initializer(context) val (vecReplacement, linReplacement, scalar) = process(actualBody) ?: return null if (varName !in localVariables) { localVariables = localVariables + varName @@ -136,12 +137,10 @@ internal class VectorReplacementProcessor( fun process(expr: KtExpression?): Triple? { - if (expr == null) return null if (expr !is KtCallExpression) { return processDeclaration(expr) } - val exprType = context.getType(expr) ?: return null - val exprTypename = exprType.getKotlinTypeFqName(false) + val exprTypename = expr.fqTypename(context) ?: return null val shortName = exprTypename.substringAfterLast('.') val args = expr.valueArguments return when (exprTypename) { @@ -233,16 +232,16 @@ internal class VectorReplacementProcessor( Triple(vectorized, linear, isScalar) } - else -> null + else -> { + null + } } } fun processMask(expr: KtExpression?): Pair? { - if (expr == null) return null - val exprType = context.getType(expr) ?: return null - val exprTypename = exprType.getKotlinTypeFqName(false) - val shortName = exprTypename.substringAfterLast('.') if (expr !is KtCallExpression) return null + val exprTypename = expr.fqTypename(context) ?: return null + val shortName = exprTypename.substringAfterLast('.') val args = expr.valueArguments return when { exprTypename in maskUnaryOpTypes -> { From c88ef04431cec7818bbc96843618f0a64df65873 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Tue, 5 Aug 2025 17:15:18 +0200 Subject: [PATCH 13/16] small refactoring --- .../processor/ReplacementProcessor.kt | 17 +- .../processor/VectorReplacementProcessor.kt | 145 +++++++++--------- 2 files changed, 77 insertions(+), 85 deletions(-) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 020a44a..ee42317 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -23,10 +23,7 @@ import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe import org.jetbrains.kotlin.types.typeUtil.supertypes internal class ReplacementProcessor( - private val context: BindingContext, - private val collector: MessageCollector, - private val file: KtFile, - private val vectorize: Boolean = true + private val context: BindingContext, private val collector: MessageCollector, private val file: KtFile, private val vectorize: Boolean = true ) { companion object { internal fun toType(primitive: Primitive<*, *>): String { @@ -51,8 +48,7 @@ internal class ReplacementProcessor( (PrimitiveArray::class.qualifiedName!!) to { it.arrayTypeName }, (PrimitiveArray::class.qualifiedName!! + ".") to { it.arrayTypeName }, - (PrimitiveArray::class.qualifiedName!! + ".Companion") to { it.arrayTypeName } - ) + (PrimitiveArray::class.qualifiedName!! + ".Companion") to { it.arrayTypeName }) } @@ -89,8 +85,7 @@ internal class ReplacementProcessor( defaultReplacements[type]!!.invoke(primitive) } - (target.isKtClassOrObject() && target.containingDeclaration!!.isAnnotatedWith()) || - (target.isNamedFunction() || target.isKtClassOrObject()) && target.isAnnotatedWith() -> { + (target.isKtClassOrObject() && target.containingDeclaration!!.isAnnotatedWith()) || (target.isNamedFunction() || target.isKtClassOrObject()) && target.isAnnotatedWith() -> { expression.text.specialize(primitive) } @@ -126,8 +121,7 @@ internal class ReplacementProcessor( val destOffset = args[1].text val len = args[2].text - if (primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) - return """ + if (primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) return """ if($vecEnabled) { val $vecLen = ${vecProcessor.vecSpecies}.length() val $vecEnd = $len - ($len % $vecLen) @@ -147,8 +141,7 @@ internal class ReplacementProcessor( } } """.trimIndent() - else - return """ + else return """ for($vecIdx in 0 until $len) { ${vecProcessor.linDeclarations} $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index cdd7863..e083e87 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -5,16 +5,12 @@ import io.kinference.primitives.generator.fqTypename import io.kinference.primitives.generator.initializer import io.kinference.primitives.vector.* import org.jetbrains.kotlin.cli.common.messages.MessageCollector -import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtExpression import org.jetbrains.kotlin.psi.KtSimpleNameExpression import org.jetbrains.kotlin.resolve.BindingContext -import org.jetbrains.kotlin.load.kotlin.toSourceElement -import org.jetbrains.kotlin.psi.KtElement import org.jetbrains.kotlin.psi.KtFile -import org.jetbrains.kotlin.psi.KtVariableDeclaration -import org.jetbrains.kotlin.resolve.source.getPsi +import org.jetbrains.kotlin.psi.KtValueArgument internal class VectorReplacementProcessor( @@ -47,7 +43,7 @@ internal class VectorReplacementProcessor( "POW" to { x: String, y: String -> "($x).pow($y)" }, ).withDefault { null } - val isAssoc = mapOf( + val isAssoc = mapOf( "ADD" to true, "MUL" to true, "MIN" to true, @@ -135,65 +131,24 @@ internal class VectorReplacementProcessor( return Triple("${varName}_vec", "${varName}_lin", scalar) } - fun process(expr: KtExpression?): Triple? { if (expr !is KtCallExpression) { return processDeclaration(expr) } + val exprTypename = expr.fqTypename(context) ?: return null val shortName = exprTypename.substringAfterLast('.') val args = expr.valueArguments + return when (exprTypename) { in unaryOpNames -> { - if (args.size != 1 && args.size != 2) return null - val childExpr = args[0].getArgumentExpression() - val masked = args.size == 2 - var (childVector, childLinear, isScalar) = process(childExpr) ?: return null val handle = vectorHandles[shortName] ?: return null - val linReplace = unaryLinearReplacements[handle] ?: return null - var linear = linReplace(childLinear) - var vectorized = """$childVector - .lanewise(VectorOperators.$handle""".trimIndent() - if (masked) { - isScalar = false - val maskExpr = args[1].getArgumentExpression() - val (maskVector, maskLinear) = processMask(maskExpr) ?: return null - linear = "(if($maskLinear) $linear else $childLinear)" - vectorized += ", $maskVector)" - } else { - vectorized += ")" - } - Triple(vectorized, linear, isScalar) + processUnaryOperation(handle, args) } in binaryOpNames -> { - if (args.size != 2 && args.size != 3) return null - val masked = args.size == 3 - val leftExpr = args[0].getArgumentExpression() - val rightExpr = args[1].getArgumentExpression() val handle = vectorHandles[shortName] ?: return null - val (leftVector, leftLinear, leftScalar) = process(leftExpr) ?: return null - val (rightVector, rightLinear, rightScalar) = process(rightExpr) ?: return null - var isScalar = leftScalar && rightScalar - var linear = binaryLinearReplacements[handle]?.invoke(leftLinear, rightLinear) ?: return null - - var vectorized = if (rightScalar) """$leftVector. - lanewise(VectorOperators.$handle, $rightLinear""".trimIndent() - else """$leftVector - .lanewise(VectorOperators.$handle, $rightVector""".trimIndent() - - if (masked) { - isScalar = false - val maskExpr = args[2].getArgumentExpression() - val (maskVector, maskLinear) = processMask(maskExpr) ?: return null - linear = "(if($maskLinear) $linear else $leftLinear)" - vectorized += ", $maskVector)" - } else { - vectorized += ")" - } - Triple( - vectorized, linear, isScalar - ) + processBinaryOperation(handle, args) } scalarType -> { @@ -217,38 +172,85 @@ internal class VectorReplacementProcessor( ) } - ifElseType -> { - if (args.size != 3) return null - val mask = args[0].getArgumentExpression() ?: return null - val left = args[1].getArgumentExpression() ?: return null - val right = args[2].getArgumentExpression() ?: return null - val (maskVector, maskLinear) = processMask(mask) ?: return null - val (leftVector, leftLinear, leftScalar) = process(left) ?: return null - val (rightVector, rightLinear, rightScalar) = process(right) ?: return null - val isScalar = leftScalar && rightScalar - val linear = "(if($maskLinear) $leftLinear else $rightLinear)" - val vectorized = """ - $rightVector.blend($leftVector, $maskVector)""".trimIndent() - Triple(vectorized, linear, isScalar) - } + ifElseType -> processIfElse(args) + else -> null + } + } - else -> { - null - } + fun processUnaryOperation(handle: String, args: List): Triple? { + if (args.size != 1 && args.size != 2) return null + val childExpr = args[0].getArgumentExpression() + val masked = args.size == 2 + var (childVector, childLinear, isScalar) = process(childExpr) ?: return null + val linReplace = unaryLinearReplacements[handle] ?: return null + var linear = linReplace(childLinear) + var vectorized = """$childVector + .lanewise(VectorOperators.$handle""".trimIndent() + if (masked) { + isScalar = false + val maskExpr = args[1].getArgumentExpression() + val (maskVector, maskLinear) = processMask(maskExpr) ?: return null + linear = "(if($maskLinear) $linear else $childLinear)" + vectorized += ", $maskVector)" + } else { + vectorized += ")" } + return Triple(vectorized, linear, isScalar) } + fun processBinaryOperation(handle: String, args: List): Triple? { + if (args.size != 2 && args.size != 3) return null + val masked = args.size == 3 + val leftExpr = args[0].getArgumentExpression() + val rightExpr = args[1].getArgumentExpression() + val (leftVector, leftLinear, leftScalar) = process(leftExpr) ?: return null + val (rightVector, rightLinear, rightScalar) = process(rightExpr) ?: return null + var isScalar = leftScalar && rightScalar + var linear = binaryLinearReplacements[handle]?.invoke(leftLinear, rightLinear) ?: return null + + var vectorized = if (rightScalar) """$leftVector. + lanewise(VectorOperators.$handle, $rightLinear""".trimIndent() + else """$leftVector + .lanewise(VectorOperators.$handle, $rightVector""".trimIndent() + + if (masked) { + isScalar = false + val maskExpr = args[2].getArgumentExpression() + val (maskVector, maskLinear) = processMask(maskExpr) ?: return null + linear = "(if($maskLinear) $linear else $leftLinear)" + vectorized += ", $maskVector)" + } else { + vectorized += ")" + } + return Triple(vectorized, linear, isScalar) + } + + fun processIfElse(args: List): Triple? { + if (args.size != 3) return null + val mask = args[0].getArgumentExpression() ?: return null + val left = args[1].getArgumentExpression() ?: return null + val right = args[2].getArgumentExpression() ?: return null + val (maskVector, maskLinear) = processMask(mask) ?: return null + val (leftVector, leftLinear, leftScalar) = process(left) ?: return null + val (rightVector, rightLinear, rightScalar) = process(right) ?: return null + val isScalar = leftScalar && rightScalar + val linear = "(if($maskLinear) $leftLinear else $rightLinear)" + val vectorized = """ + $rightVector.blend($leftVector, $maskVector)""".trimIndent() + return Triple(vectorized, linear, isScalar) + } + + fun processMask(expr: KtExpression?): Pair? { if (expr !is KtCallExpression) return null val exprTypename = expr.fqTypename(context) ?: return null - val shortName = exprTypename.substringAfterLast('.') + val handle = maskHandles[exprTypename.substringAfterLast('.')] ?: return null val args = expr.valueArguments return when { exprTypename in maskUnaryOpTypes -> { if (args.size != 1) return null val child = args[0].getArgumentExpression() ?: return null val (vecReplacement, linReplacement) = processMask(child) ?: return null - val handle = maskHandles[shortName] ?: return null val linReplacer = maskUnaryReplacement[handle] ?: return null Pair( @@ -262,7 +264,6 @@ internal class VectorReplacementProcessor( val (leftVecReplacement, leftLinReplacement) = processMask(left) ?: return null val right = args[0].getArgumentExpression() ?: return null val (rightVecReplacement, rightLinReplacement) = processMask(right) ?: return null - val handle = maskHandles[shortName] ?: return null val linReplacer = maskBinaryReplacement[handle] ?: return null Pair( linReplacer(leftVecReplacement, rightVecReplacement), linReplacer(leftLinReplacement, rightLinReplacement) @@ -273,10 +274,8 @@ internal class VectorReplacementProcessor( if (args.size != 2) return null val leftExpr = args[0].getArgumentExpression() ?: return null val rightExpr = args[1].getArgumentExpression() ?: return null - val handle = maskHandles[shortName] ?: return null val (leftVector, leftLinear, leftScalar) = process(leftExpr) ?: return null val (rightVector, rightLinear, rightScalar) = process(rightExpr) ?: return null - val isScalar = leftScalar && rightScalar val linear = comparatorReplacement[handle]?.invoke(leftLinear, rightLinear) ?: return null val vectorized = """ $leftVector From f1fa7c6a325c0e0cfc423821bebbc412029ec875 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Fri, 8 Aug 2025 15:10:52 +0200 Subject: [PATCH 14/16] updated version to 2.1.0-0 --- README.md | 3 +++ gradle.properties | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9bef627..28728a3 100644 --- a/README.md +++ b/README.md @@ -65,3 +65,6 @@ This code would generate specializations for Float and Int types via replacement of `PrimitiveType` with corresponding type (Float or Int). Also, note that standard functions, like MAX_VALUE, are available for `PrimitiveType` like it would be a real `Number`. +Version 2.1.0 adds the possibility of generating vectorized code using +Java's [vector API](https://download.java.net/java/early_access/jdk25/docs/api/jdk.incubator.vector/module-summary.html). +Currently, this only works inside of KInference and in combination with primitive specialization. diff --git a/gradle.properties b/gradle.properties index 138a42b..16812ce 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,2 +1,2 @@ GROUP=io.kinference.primitives -VERSION=2.1.0-dev +VERSION=2.1.0-0 From 23d1ed5f3add52437f933c4e5b3da6aef247c652 Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Fri, 8 Aug 2025 15:11:43 +0200 Subject: [PATCH 15/16] Refactoring and better error handling with logging --- .../kinference/primitives/PrimitivesTask.kt | 29 +++- .../generator/errors/KtElementError.kt | 2 +- .../processor/ReplacementProcessor.kt | 52 ++++---- .../processor/VectorReplacementProcessor.kt | 125 ++++++++++++------ 4 files changed, 137 insertions(+), 71 deletions(-) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt index b2ea59d..f7156de 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt @@ -11,6 +11,8 @@ import org.gradle.api.provider.Property import org.gradle.api.tasks.* import org.gradle.work.NormalizeLineEndings import org.jetbrains.kotlin.cli.common.config.addKotlinSourceRoot +import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity +import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSourceLocation import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.cli.jvm.config.addJvmClasspathRoots import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension @@ -41,6 +43,23 @@ abstract class PrimitivesTask : DefaultTask() { @get:Internal abstract val primitivesCache: Property + private val messageCollector = object : MessageCollector { + override fun clear() {} + + override fun hasErrors(): Boolean = false + + override fun report(severity: CompilerMessageSeverity, message: String, location: CompilerMessageSourceLocation?) { + when (severity) { + CompilerMessageSeverity.STRONG_WARNING -> logger.error("$message: $location") + CompilerMessageSeverity.ERROR -> logger.error("$message: $location") + CompilerMessageSeverity.INFO -> logger.info(message) + CompilerMessageSeverity.WARNING -> Unit + else -> logger.debug(message) + } + } + } + + init { group = "generate" description = "Generates primitives from sources" @@ -68,9 +87,9 @@ abstract class PrimitivesTask : DefaultTask() { FileWithCommon(source, isCommon) } - val analyzeFun = when(compilation.get().platformType) { - KotlinPlatformType.jvm -> Analyze::analyzeJvmSources - KotlinPlatformType.js -> Analyze::analyzeJsSources + val analyzeFun = when (compilation.get().platformType) { + KotlinPlatformType.jvm -> Analyze::analyzeJvmSources + KotlinPlatformType.js -> Analyze::analyzeJsSources KotlinPlatformType.common -> Analyze::analyzeCommonSources else -> error("Unsupported platform type ${compilation.get().platformType}") } @@ -89,11 +108,13 @@ abstract class PrimitivesTask : DefaultTask() { val annotated = ktSources.filter { it.isAnnotatedWith(result.bindingContext) } val notGeneratedYet = annotated.filterNot { it.virtualFilePath in primitivesCache.get().resolvedPaths } + + for (ktFile in notGeneratedYet) { val sourceSet = findSourceSetName(ktFile.virtualFilePath) val outputDir = generationPath.dir(sourceSet).get().asFile - PrimitiveGenerator(ktFile, result.bindingContext, outputDir, MessageCollector.NONE, vectorize.get()).generate() + PrimitiveGenerator(ktFile, result.bindingContext, outputDir, messageCollector, vectorize.get()).generate() primitivesCache.get().resolvedPaths.add(ktFile.virtualFilePath) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/errors/KtElementError.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/errors/KtElementError.kt index 5ad19dd..73b9861 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/errors/KtElementError.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/errors/KtElementError.kt @@ -21,7 +21,7 @@ internal fun MessageCollector.require(severity: CompilerMessageSeverity, element } } -private fun KtElement.getLocation(): CompilerMessageSourceLocation? { +fun KtElement.getLocation(): CompilerMessageSourceLocation? { val lineToColumn = if (this !is KtFile) StringUtil.offsetToLineColumn(containingKtFile.text, textOffset) else null return CompilerMessageLocation.create(containingKtFile.virtualFilePath, lineToColumn?.line ?: 1, lineToColumn?.line ?: 1, null) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index ee42317..60bb23e 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -4,6 +4,7 @@ import io.kinference.primitives.annotations.GenerateNameFromPrimitives import io.kinference.primitives.annotations.GenerateVector import io.kinference.primitives.annotations.MakePublic import io.kinference.primitives.generator.* +import io.kinference.primitives.generator.errors.getLocation import io.kinference.primitives.generator.errors.require import io.kinference.primitives.types.* import io.kinference.primitives.vector.* @@ -108,7 +109,16 @@ internal class ReplacementProcessor( if (!isVectorClass(receiver, context)) return null val vecProcessor = VectorReplacementProcessor(context, primitive, collector, file) - val (vecReplacement, linReplacement, isScalar) = vecProcessor.process(receiver) ?: return null + val res = vecProcessor.process(receiver) + if (res == null) { + collector.report( + CompilerMessageSeverity.STRONG_WARNING, + "Could not process vectorized expression, the code will not be generated", + expr.getLocation() + ) + return "" + } + val (vecReplacement, linReplacement, isScalar) = res val toPrimitive = "${toType(primitive)}()" val vecLen = "_vecLen_$idx" @@ -120,6 +130,11 @@ internal class ReplacementProcessor( val dest = args[0].text val destOffset = args[1].text val len = args[2].text + val linearCode = """ + for($vecIdx in 0 until $len) { + ${vecProcessor.linDeclarations} + $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive + }""".trimIndent() if (primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) return """ if($vecEnabled) { @@ -135,26 +150,26 @@ internal class ReplacementProcessor( $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive } }else{ - for($vecIdx in 0 until $len) { - ${vecProcessor.linDeclarations} - $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive - } + $linearCode } """.trimIndent() - else return """ - for($vecIdx in 0 until $len) { - ${vecProcessor.linDeclarations} - $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive - }""".trimIndent() + else return linearCode } else if (callName == "reduce" && args.size == 2) { val handle = args[0].text val len = args[1].text - if (VectorReplacementProcessor.isAssoc[handle] != true) return "" val neutral = vecProcessor.neutralElement[handle] ?: return "" val linearOp = VectorReplacementProcessor.binaryLinearReplacements[handle] ?: return "" val linAccumulate = linearOp("ret", linReplacement) + val linearCode = """ + var ret = $neutral + for($vecIdx in 0 until $len) { + ${vecProcessor.linDeclarations} + ret = $linAccumulate.$toPrimitive + } + ret.$toPrimitive + """.trimIndent() if (primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) { return """{ @@ -174,22 +189,13 @@ internal class ReplacementProcessor( } ret.$toPrimitive }else{ - var ret = $neutral - for($vecIdx in 0 until $len) { - ${vecProcessor.linDeclarations} - ret = $linAccumulate.$toPrimitive - } - ret.$toPrimitive} + $linearCode + } }.invoke() """.trimIndent() } else { return """{ - var ret = $neutral - for($vecIdx in 0 until $len) { - ${vecProcessor.linDeclarations} - ret = $linAccumulate.$toPrimitive - } - ret.$toPrimitive + $linearCode }.invoke() """.trimIndent() } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt index e083e87..8bbf9af 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -1,9 +1,12 @@ package io.kinference.primitives.generator.processor import io.kinference.primitives.generator.Primitive +import io.kinference.primitives.generator.errors.require import io.kinference.primitives.generator.fqTypename import io.kinference.primitives.generator.initializer +import io.kinference.primitives.types.DataType import io.kinference.primitives.vector.* +import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtExpression @@ -11,6 +14,7 @@ import org.jetbrains.kotlin.psi.KtSimpleNameExpression import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.kotlin.psi.KtValueArgument +import io.kinference.primitives.generator.errors.* internal class VectorReplacementProcessor( @@ -31,7 +35,7 @@ internal class VectorReplacementProcessor( "LOG" to { x: String -> "ln($x)" }, "SQRT" to { x: String -> "sqrt($x)" }, "CBRT" to { x: String -> "cbrt($x)" }, - ).withDefault { null } + ) val binaryLinearReplacements = mapOf( "ADD" to { x: String, y: String -> "($x + $y)" }, @@ -41,14 +45,7 @@ internal class VectorReplacementProcessor( "MAX" to { x: String, y: String -> "maxOf($x, $y)" }, "MIN" to { x: String, y: String -> "minOf($x, $y)" }, "POW" to { x: String, y: String -> "($x).pow($y)" }, - ).withDefault { null } - - val isAssoc = mapOf( - "ADD" to true, - "MUL" to true, - "MIN" to true, - "MAX" to true, - ).withDefault { false } + ) val vectorHandles = mapOf( "Add" to "ADD", @@ -64,21 +61,37 @@ internal class VectorReplacementProcessor( "Pow" to "POW", "Sqrt" to "SQRT", "Cbrt" to "CBRT", - ).withDefault { null } + ) + + val supportedTypes = mapOf( + "EXP" to setOf(DataType.FLOAT, DataType.DOUBLE), + "LOG" to setOf(DataType.FLOAT, DataType.DOUBLE), + "SQRT" to setOf(DataType.FLOAT, DataType.DOUBLE), + "CBRT" to setOf(DataType.FLOAT, DataType.DOUBLE), + ).withDefault { DataType.ALL.resolve() } val maskHandles = mapOf( - "And" to "AND", "Or" to "OR", "Xor" to "XOR", "Not" to "NOT", "Eq" to "EQ", "Neq" to "NEQ", "LT" to "LT", "LE" to "LE", "GT" to "GT", "GE" to "GE" - ).withDefault { null } + "And" to "AND", + "Or" to "OR", + "Xor" to "XOR", + "Not" to "NOT", + "Eq" to "EQ", + "Neq" to "NEQ", + "LT" to "LT", + "LE" to "LE", + "GT" to "GT", + "GE" to "GE" + ) val maskUnaryReplacement = mapOf( "Not" to ({ x: String -> "$x.not()" }), - ).withDefault { null } + ) val maskBinaryReplacement = mapOf( "And" to ({ x: String, y: String -> "($x.and($y)" }), "Or" to ({ x: String, y: String -> "$x.or($y)" }), "Xor" to ({ x: String, y: String -> "$x.xor($y)" }), - ).withDefault { null } + ) val comparatorReplacement = mapOf( "Eq" to ({ x: String, y: String -> "($x == $y)" }), @@ -105,35 +118,43 @@ internal class VectorReplacementProcessor( "MUL" to "1.${ReplacementProcessor.toType(primitive)}()", "MIN" to "${primitive.typeName}.MAX_VALUE", "MAX" to "${primitive.typeName}.MIN_VALUE" - ).withDefault { null } + ) var scalarDeclarations: String = "" var vecDeclarations: String = "" var linDeclarations: String = "" var localVariables: Set = emptySet() + var scalarVariables: Set = emptySet() - private fun processDeclaration(expr: KtExpression?): Triple? { + private fun processSimpleName(expr: KtExpression?): Triple? { if (expr !is KtSimpleNameExpression) return null val varName = expr.text val actualBody = expr.initializer(context) - val (vecReplacement, linReplacement, scalar) = process(actualBody) ?: return null if (varName !in localVariables) { - localVariables = localVariables + varName - if (scalar) { - scalarDeclarations += "val ${varName}_vec = $vecReplacement\n" - scalarDeclarations += "val ${varName}_lin = $linReplacement\n" - } else { - vecDeclarations += "val ${varName}_vec = $vecReplacement\n" - linDeclarations += "val ${varName}_lin = $linReplacement\n" - } + val success = addVariable(actualBody, varName) + if (!success) return null + } + return Triple("${varName}_vec", "${varName}_lin", varName in scalarVariables) + } + + private fun addVariable(expr: KtExpression?, varName: String): Boolean { + val (vecReplacement, linReplacement, scalar) = process(expr) ?: return false + localVariables = localVariables + varName + if (scalar) { + scalarVariables = scalarVariables + varName + scalarDeclarations += "val ${varName}_vec = $vecReplacement\n" + scalarDeclarations += "val ${varName}_lin = $linReplacement\n" + } else { + vecDeclarations += "val ${varName}_vec = $vecReplacement\n" + linDeclarations += "val ${varName}_lin = $linReplacement\n" } - return Triple("${varName}_vec", "${varName}_lin", scalar) + return true } fun process(expr: KtExpression?): Triple? { if (expr !is KtCallExpression) { - return processDeclaration(expr) + return processSimpleName(expr) } val exprTypename = expr.fqTypename(context) ?: return null @@ -143,11 +164,27 @@ internal class VectorReplacementProcessor( return when (exprTypename) { in unaryOpNames -> { val handle = vectorHandles[shortName] ?: return null + if (!supportedTypes.getValue(handle).contains(primitive.dataType)) { + collector.report( + CompilerMessageSeverity.STRONG_WARNING, + "$handle operation is not supported for ${primitive.dataType} type", + expr.getLocation() + ) + return null + } processUnaryOperation(handle, args) } in binaryOpNames -> { val handle = vectorHandles[shortName] ?: return null + if (!supportedTypes.getValue(handle).contains(primitive.dataType)) { + collector.report( + CompilerMessageSeverity.STRONG_WARNING, + "$handle operation is not supported for ${primitive.dataType} type", + expr.getLocation() + ) + return null + } processBinaryOperation(handle, args) } @@ -167,9 +204,9 @@ internal class VectorReplacementProcessor( 2 -> args[1].text else -> "0" } - Triple( - "${vecName}.fromArray($vecSpecies, $src, $offset + _vec_internal_idx)", "$src[$offset + _vec_internal_idx]", false - ) + val vectorized = "${vecName}.fromArray($vecSpecies, $src, $offset + _vec_internal_idx)" + val linear = "$src[$offset+_vec_internal_idx]" + Triple(vectorized, linear, false) } ifElseType -> processIfElse(args) @@ -177,7 +214,7 @@ internal class VectorReplacementProcessor( } } - fun processUnaryOperation(handle: String, args: List): Triple? { + private fun processUnaryOperation(handle: String, args: List): Triple? { if (args.size != 1 && args.size != 2) return null val childExpr = args[0].getArgumentExpression() val masked = args.size == 2 @@ -198,7 +235,7 @@ internal class VectorReplacementProcessor( return Triple(vectorized, linear, isScalar) } - fun processBinaryOperation(handle: String, args: List): Triple? { + private fun processBinaryOperation(handle: String, args: List): Triple? { if (args.size != 2 && args.size != 3) return null val masked = args.size == 3 val leftExpr = args[0].getArgumentExpression() @@ -225,7 +262,7 @@ internal class VectorReplacementProcessor( return Triple(vectorized, linear, isScalar) } - fun processIfElse(args: List): Triple? { + private fun processIfElse(args: List): Triple? { if (args.size != 3) return null val mask = args[0].getArgumentExpression() ?: return null val left = args[1].getArgumentExpression() ?: return null @@ -240,14 +277,16 @@ internal class VectorReplacementProcessor( return Triple(vectorized, linear, isScalar) } - - fun processMask(expr: KtExpression?): Pair? { - if (expr !is KtCallExpression) return null + private fun processMask(expr: KtExpression?): Pair? { + if (expr !is KtCallExpression) { + val (vecReplacement, linReplacement, _) = processSimpleName(expr) ?: return null + return Pair(vecReplacement, linReplacement) + } val exprTypename = expr.fqTypename(context) ?: return null val handle = maskHandles[exprTypename.substringAfterLast('.')] ?: return null val args = expr.valueArguments - return when { - exprTypename in maskUnaryOpTypes -> { + return when (exprTypename) { + in maskUnaryOpTypes -> { if (args.size != 1) return null val child = args[0].getArgumentExpression() ?: return null val (vecReplacement, linReplacement) = processMask(child) ?: return null @@ -258,7 +297,7 @@ internal class VectorReplacementProcessor( ) } - exprTypename in maskBinaryOpTypes -> { + in maskBinaryOpTypes -> { if (args.size != 2) return null val left = args[0].getArgumentExpression() ?: return null val (leftVecReplacement, leftLinReplacement) = processMask(left) ?: return null @@ -270,7 +309,7 @@ internal class VectorReplacementProcessor( ) } - exprTypename in comparatorTypes -> { + in comparatorTypes -> { if (args.size != 2) return null val leftExpr = args[0].getArgumentExpression() ?: return null val rightExpr = args[1].getArgumentExpression() ?: return null @@ -278,9 +317,9 @@ internal class VectorReplacementProcessor( val (rightVector, rightLinear, rightScalar) = process(rightExpr) ?: return null val linear = comparatorReplacement[handle]?.invoke(leftLinear, rightLinear) ?: return null val vectorized = """ - $leftVector - .compare(VectorOperators.$handle, $rightVector) - """.trimIndent() + $leftVector + .compare(VectorOperators.$handle, $rightVector) + """.trimIndent() Pair(vectorized, linear) } From f17bc89836700e1620ea368f60e09d15f6d3d01f Mon Sep 17 00:00:00 2001 From: Tommaso Dossi Date: Mon, 22 Sep 2025 15:00:16 +0200 Subject: [PATCH 16/16] Removed unused code --- .../kotlin/io/kinference/primitives/PrimitivesExtension.kt | 2 -- .../io/kinference/primitives/PrimitivesGradlePlugin.kt | 1 - .../main/kotlin/io/kinference/primitives/PrimitivesTask.kt | 6 +----- .../kinference/primitives/generator/PrimitiveGenerator.kt | 1 - .../primitives/generator/processor/GenerationVisitor.kt | 5 ++--- .../primitives/generator/processor/ReplacementProcessor.kt | 6 +++--- 6 files changed, 6 insertions(+), 15 deletions(-) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt index 13e86ba..858ef51 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt @@ -12,8 +12,6 @@ open class PrimitivesExtension @Inject constructor( ) { private val objects = project.objects - var vectorize: Boolean = false - val generationPath: DirectoryProperty = objects.directoryProperty().convention( project.layout.buildDirectory.dir("generated/primitives") ) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt index b61afec..429d7f5 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt @@ -40,7 +40,6 @@ class PrimitivesGradlePlugin : Plugin { primitiveTask.inputFiles.from(compileTask.sources) primitiveTask.libraries.from(compileTask.libraries) primitiveTask.compilation.set(compilation) - primitiveTask.vectorize.set(primitivesExt.vectorize) } compileTask.dependsOn(primitivesTask) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt index f7156de..087f71a 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt @@ -21,9 +21,6 @@ import org.jetbrains.kotlin.gradle.plugin.* import java.io.File abstract class PrimitivesTask : DefaultTask() { - @get:Input - abstract val vectorize: Property - @get:Internal abstract val generationPath: DirectoryProperty @@ -63,7 +60,6 @@ abstract class PrimitivesTask : DefaultTask() { init { group = "generate" description = "Generates primitives from sources" - vectorize.convention(false) } @TaskAction @@ -114,7 +110,7 @@ abstract class PrimitivesTask : DefaultTask() { val sourceSet = findSourceSetName(ktFile.virtualFilePath) val outputDir = generationPath.dir(sourceSet).get().asFile - PrimitiveGenerator(ktFile, result.bindingContext, outputDir, messageCollector, vectorize.get()).generate() + PrimitiveGenerator(ktFile, result.bindingContext, outputDir, messageCollector).generate() primitivesCache.get().resolvedPaths.add(ktFile.virtualFilePath) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt index dd46cae..ce85c8a 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt @@ -26,7 +26,6 @@ internal class PrimitiveGenerator( private val context: BindingContext, private val output: File, private val collector: MessageCollector, - private val vectorize: Boolean = false ) { private data class PrimitiveContext(val type1: Primitive<*, *>? = null, val type2: Primitive<*, *>? = null, val type3: Primitive<*, *>? = null) diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt index 5588118..d6ed837 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt @@ -48,11 +48,10 @@ internal class GenerationVisitor( KtDefaultVisitor() { private data class PrimitiveContext(val type1: Primitive<*, *>? = null, val type2: Primitive<*, *>? = null, val type3: Primitive<*, *>? = null) - private val vectorize = true private val builder = StringBuilder() fun text() = builder.toString() val removalProcessor = RemovalProcessor(context) - val replacementProcessor = ReplacementProcessor(context, collector, file, vectorize) + val replacementProcessor = ReplacementProcessor(context, collector, file) private var currentPrimitive = primitive private var vecCount = 0 @@ -77,7 +76,7 @@ internal class GenerationVisitor( } override fun visitImportList(importList: KtImportList) { - if (file.isAnnotatedWith(context) && primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) { + if (file.isAnnotatedWith(context) && primitive.dataType in DataType.VECTORIZABLE.resolve()) { builder.appendLine("import io.kinference.ndarray.VecUtils.isModuleLoaded") builder.appendLine("import jdk.incubator.vector.*") } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 60bb23e..fc9bc43 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -24,7 +24,7 @@ import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe import org.jetbrains.kotlin.types.typeUtil.supertypes internal class ReplacementProcessor( - private val context: BindingContext, private val collector: MessageCollector, private val file: KtFile, private val vectorize: Boolean = true + private val context: BindingContext, private val collector: MessageCollector, private val file: KtFile ) { companion object { internal fun toType(primitive: Primitive<*, *>): String { @@ -136,7 +136,7 @@ internal class ReplacementProcessor( $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive }""".trimIndent() - if (primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) return """ + if (primitive.dataType in DataType.VECTORIZABLE.resolve()) return """ if($vecEnabled) { val $vecLen = ${vecProcessor.vecSpecies}.length() val $vecEnd = $len - ($len % $vecLen) @@ -171,7 +171,7 @@ internal class ReplacementProcessor( ret.$toPrimitive """.trimIndent() - if (primitive.dataType in DataType.VECTORIZABLE.resolve() && vectorize) { + if (primitive.dataType in DataType.VECTORIZABLE.resolve()) { return """{ if($vecEnabled) { val $vecLen = ${vecProcessor.vecSpecies}.length()