diff --git a/README.md b/README.md index 9bef627..28728a3 100644 --- a/README.md +++ b/README.md @@ -65,3 +65,6 @@ This code would generate specializations for Float and Int types via replacement of `PrimitiveType` with corresponding type (Float or Int). Also, note that standard functions, like MAX_VALUE, are available for `PrimitiveType` like it would be a real `Number`. +Version 2.1.0 adds the possibility of generating vectorized code using +Java's [vector API](https://download.java.net/java/early_access/jdk25/docs/api/jdk.incubator.vector/module-summary.html). +Currently, this only works inside of KInference and in combination with primitive specialization. diff --git a/build.gradle.kts b/build.gradle.kts index df09ca6..f84e7b0 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -12,6 +12,7 @@ plugins { allprojects { repositories { + mavenLocal() mavenCentral() gradlePluginPortal() } @@ -19,14 +20,14 @@ allprojects { subprojects { tasks.withType(JavaCompile::class.java).all { - sourceCompatibility = JavaVersion.VERSION_1_8.toString() - targetCompatibility = JavaVersion.VERSION_1_8.toString() + sourceCompatibility = JavaVersion.VERSION_21.toString() + targetCompatibility = JavaVersion.VERSION_21.toString() } tasks.withType(KotlinCompilationTask::class.java).all { compilerOptions { if (this is KotlinJvmCompilerOptions) { - jvmTarget.set(JvmTarget.JVM_1_8) + jvmTarget.set(JvmTarget.JVM_21) } apiVersion.set(KotlinVersion.KOTLIN_2_0) diff --git a/gradle.properties b/gradle.properties index 571db53..16812ce 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,2 +1,2 @@ GROUP=io.kinference.primitives -VERSION=2.0.0-1 +VERSION=2.1.0-0 diff --git a/plugin-build/primitives-annotations/build.gradle.kts b/plugin-build/primitives-annotations/build.gradle.kts index 5a73798..1b7287f 100644 --- a/plugin-build/primitives-annotations/build.gradle.kts +++ b/plugin-build/primitives-annotations/build.gradle.kts @@ -13,3 +13,6 @@ kotlin { } } +repositories { + mavenCentral() +} diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/annotations/GenerateVector.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/annotations/GenerateVector.kt new file mode 100644 index 0000000..fd013d2 --- /dev/null +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/annotations/GenerateVector.kt @@ -0,0 +1,8 @@ +package io.kinference.primitives.annotations + +import io.kinference.primitives.types.* +import io.kinference.primitives.vector.* + +/** Specify that any usage of [OperationNode] is subject for replacement in this file */ +@Target(AnnotationTarget.FILE) +annotation class GenerateVector(vararg val types: DataType) diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt index dc4292c..f469292 100644 --- a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/types/DataType.kt @@ -31,16 +31,18 @@ enum class DataType { BOOLEAN, ALL, - NUMBER; + NUMBER, + VECTORIZABLE; /** * Resolve DataType into actual primitives -- would flatten groups into collection of primitives. * Primitives would remain the same */ fun resolve(): Set { - return when(this) { + return when (this) { ALL -> setOf(BYTE, SHORT, INT, LONG, UBYTE, USHORT, UINT, ULONG, FLOAT, DOUBLE, BOOLEAN) NUMBER -> setOf(BYTE, SHORT, INT, LONG, UBYTE, USHORT, UINT, ULONG, FLOAT, DOUBLE) + VECTORIZABLE -> setOf(BYTE, SHORT, INT, LONG, FLOAT, DOUBLE) else -> setOf(this) } } diff --git a/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt new file mode 100644 index 0000000..478734d --- /dev/null +++ b/plugin-build/primitives-annotations/src/commonMain/kotlin/io/kinference/primitives/vector/OperationNode.kt @@ -0,0 +1,113 @@ +@file:Suppress("Unused", "UnusedReceiverParameter") + +package io.kinference.primitives.vector + +import io.kinference.primitives.types.PrimitiveArray +import io.kinference.primitives.types.PrimitiveType +import io.kinference.primitives.types.toPrimitive +import io.kinference.primitives.vector.Add +import io.kinference.primitives.vector.BinaryOp +import io.kinference.primitives.vector.Sub +import io.kinference.primitives.vector.UnaryOp + +sealed class OpNode() { + public fun into(dest: PrimitiveArray, offset: Int, len: Int): Unit = + throw UnsupportedOperationException() + + public fun reduce(operation: AssociativeWrapper, len: Int): PrimitiveType = + throw UnsupportedOperationException() + +} + +final class PrimitiveSlice(val src: PrimitiveArray, val offset: Int = 0) : OpNode() {} + +final class Value(val value: PrimitiveType) : OpNode() {} + +sealed class UnaryOp(val arg: OpNode) : OpNode() { + constructor(arg: OpNode, mask: VecMask) : this(arg) +} + +sealed class BinaryOp(val left: OpNode, val right: OpNode) : OpNode() { + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} + +class IfElse(val condition: VecMask, val left: OpNode, val right: OpNode) : OpNode() {} + +sealed class AssociativeWrapper() {} + +class Exp(arg: OpNode) : UnaryOp(arg) { + constructor(arg: OpNode, mask: VecMask) : this(arg) +} + +class Abs(arg: OpNode) : UnaryOp(arg) { + constructor(arg: OpNode, mask: VecMask) : this(arg) +} + +class Neg(arg: OpNode) : UnaryOp(arg) { + constructor(arg: OpNode, mask: VecMask) : this(arg) +} + +class Log(arg: OpNode) : UnaryOp(arg) { + constructor(arg: OpNode, mask: VecMask) : this(arg) +} + +class Sqrt(arg: OpNode) : UnaryOp(arg) { + constructor(arg: OpNode, mask: VecMask) : this(arg) +} + +class Cbrt(arg: OpNode) : UnaryOp(arg) { + constructor(arg: OpNode, mask: VecMask) : this(arg) +} + +class Add(left: OpNode, right: OpNode) : BinaryOp(left, right) { + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} + +class Sub(left: OpNode, right: OpNode) : BinaryOp(left, right) { + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} + +class Mul(left: OpNode, right: OpNode) : BinaryOp(left, right) { + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} + +class Div(left: OpNode, right: OpNode) : BinaryOp(left, right) { + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} + +class Pow(left: OpNode, right: OpNode) : BinaryOp(left, right) { + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} + +class Max(left: OpNode, right: OpNode) : BinaryOp(left, right) { + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} + +class Min(left: OpNode, right: OpNode) : BinaryOp(left, right) { + constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right) +} + + +object ADD : AssociativeWrapper() {} +object MUL : AssociativeWrapper() {} +object MAX : AssociativeWrapper() {} +object MIN : AssociativeWrapper() {} + +sealed class VecMask() {} + +sealed class Comparator(left: OpNode, right: OpNode) : VecMask() {} +sealed class MaskBinaryOp(left: VecMask, right: VecMask) : VecMask() {} +sealed class MaskUnaryOp(arg: VecMask) : VecMask() {} + +class Not(arg: VecMask) : VecMask() {} + +class Eq(left: OpNode, right: OpNode) : Comparator(left, right) {} +class Neq(left: OpNode, right: OpNode) : Comparator(left, right) {} +class LT(left: OpNode, right: OpNode) : Comparator(left, right) {} +class LE(left: OpNode, right: OpNode) : Comparator(left, right) {} +class GT(left: OpNode, right: OpNode) : Comparator(left, right) {} +class GE(left: OpNode, right: OpNode) : Comparator(left, right) {} + +class And(left: VecMask, right: VecMask) : MaskBinaryOp(left, right) {} +class Or(left: VecMask, right: VecMask) : MaskBinaryOp(left, right) {} +class Xor(left: VecMask, right: VecMask) : MaskBinaryOp(left, right) {} diff --git a/plugin-build/primitives-plugin/build.gradle.kts b/plugin-build/primitives-plugin/build.gradle.kts index b00aa4b..3fc2c11 100644 --- a/plugin-build/primitives-plugin/build.gradle.kts +++ b/plugin-build/primitives-plugin/build.gradle.kts @@ -24,3 +24,6 @@ gradlePlugin { } } } +repositories { + mavenCentral() +} diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt index c537ca6..858ef51 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesExtension.kt @@ -1,7 +1,10 @@ package io.kinference.primitives +import com.sun.org.apache.xpath.internal.operations.Bool import org.gradle.api.Project import org.gradle.api.file.DirectoryProperty +import org.gradle.api.provider.Property +import org.jetbrains.kotlin.ir.declarations.DescriptorMetadataSource import javax.inject.Inject open class PrimitivesExtension @Inject constructor( diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt index 243f94f..429d7f5 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesGradlePlugin.kt @@ -12,7 +12,6 @@ class PrimitivesGradlePlugin : Plugin { val primitivesExt = project.extensions.create(extensionName, PrimitivesExtension::class.java) - val primitivesCache = project.gradle.sharedServices.registerIfAbsent("${project.path}_${primitivesCacheName}", PrimitivesCache::class.java) { it.maxParallelUsages.set(1) } @@ -21,6 +20,7 @@ class PrimitivesGradlePlugin : Plugin { it.group = "generate" } + kotlinExt.sourceSets.all { sourceSet -> sourceSet.kotlin.srcDir(primitivesExt.generationPath.dir(sourceSet.name)) primitivesCache.get().sourceSetToResolved[sourceSet.name] = false diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt index bc3b02c..087f71a 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/PrimitivesTask.kt @@ -11,6 +11,8 @@ import org.gradle.api.provider.Property import org.gradle.api.tasks.* import org.gradle.work.NormalizeLineEndings import org.jetbrains.kotlin.cli.common.config.addKotlinSourceRoot +import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity +import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSourceLocation import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.cli.jvm.config.addJvmClasspathRoots import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension @@ -38,6 +40,23 @@ abstract class PrimitivesTask : DefaultTask() { @get:Internal abstract val primitivesCache: Property + private val messageCollector = object : MessageCollector { + override fun clear() {} + + override fun hasErrors(): Boolean = false + + override fun report(severity: CompilerMessageSeverity, message: String, location: CompilerMessageSourceLocation?) { + when (severity) { + CompilerMessageSeverity.STRONG_WARNING -> logger.error("$message: $location") + CompilerMessageSeverity.ERROR -> logger.error("$message: $location") + CompilerMessageSeverity.INFO -> logger.info(message) + CompilerMessageSeverity.WARNING -> Unit + else -> logger.debug(message) + } + } + } + + init { group = "generate" description = "Generates primitives from sources" @@ -64,9 +83,9 @@ abstract class PrimitivesTask : DefaultTask() { FileWithCommon(source, isCommon) } - val analyzeFun = when(compilation.get().platformType) { - KotlinPlatformType.jvm -> Analyze::analyzeJvmSources - KotlinPlatformType.js -> Analyze::analyzeJsSources + val analyzeFun = when (compilation.get().platformType) { + KotlinPlatformType.jvm -> Analyze::analyzeJvmSources + KotlinPlatformType.js -> Analyze::analyzeJsSources KotlinPlatformType.common -> Analyze::analyzeCommonSources else -> error("Unsupported platform type ${compilation.get().platformType}") } @@ -85,11 +104,13 @@ abstract class PrimitivesTask : DefaultTask() { val annotated = ktSources.filter { it.isAnnotatedWith(result.bindingContext) } val notGeneratedYet = annotated.filterNot { it.virtualFilePath in primitivesCache.get().resolvedPaths } + + for (ktFile in notGeneratedYet) { val sourceSet = findSourceSetName(ktFile.virtualFilePath) val outputDir = generationPath.dir(sourceSet).get().asFile - PrimitiveGenerator(ktFile, result.bindingContext, outputDir, MessageCollector.NONE).generate() + PrimitiveGenerator(ktFile, result.bindingContext, outputDir, messageCollector).generate() primitivesCache.get().resolvedPaths.add(ktFile.virtualFilePath) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt index a55494e..ce85c8a 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/PrimitiveGenerator.kt @@ -2,27 +2,36 @@ package io.kinference.primitives.generator import io.kinference.primitives.annotations.* import io.kinference.primitives.generator.errors.require +import io.kinference.primitives.generator.processor.GenerationVisitor import io.kinference.primitives.generator.processor.RemovalProcessor import io.kinference.primitives.generator.processor.ReplacementProcessor +import io.kinference.primitives.generator.processor.VectorReplacementProcessor +import io.kinference.primitives.types.DataType import io.kinference.primitives.utils.crossProduct import io.kinference.primitives.utils.psi.* import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.com.intellij.psi.PsiWhiteSpace import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.LeafPsiElement +import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.psi.psiUtil.visibilityModifier import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.types.typeUtil.supertypes import java.io.File internal class PrimitiveGenerator( - private val file: KtFile, private val context: BindingContext, private val output: File, - private val collector: MessageCollector + private val file: KtFile, + private val context: BindingContext, + private val output: File, + private val collector: MessageCollector, ) { private data class PrimitiveContext(val type1: Primitive<*, *>? = null, val type2: Primitive<*, *>? = null, val type3: Primitive<*, *>? = null) + private var vecCount = 0; + fun generate(): Set { val results = HashSet() @@ -32,173 +41,18 @@ internal class PrimitiveGenerator( } for (primitive in types.flatMap { it.toPrimitive() }.toSet()) { - val builder = StringBuilder() - - val removalProcessor = RemovalProcessor(context) - val replacementProcessor = ReplacementProcessor(context, collector) - - file.accept(object : KtDefaultVisitor() { - private var currentPrimitive = primitive - - private fun KtElement.withPrimitive(primitive: Primitive<*, *>?, body: () -> Unit) { - collector.require(CompilerMessageSeverity.ERROR, this, primitive != null) { - "Primitive was bound with @${BindPrimitives::class.simpleName} sub-annotation," + - " but outer expression is not annotated with @${BindPrimitives::class.simpleName}" - } - - val tmp = currentPrimitive - currentPrimitive = primitive - body() - currentPrimitive = tmp - } - - private var primitiveContext = PrimitiveContext() - private fun withContext(context: PrimitiveContext, body: () -> Unit) { - val tmp = primitiveContext - primitiveContext = context - body() - primitiveContext = tmp - } - - override fun visitModifierList(list: KtModifierList) { - if (replacementProcessor.shouldChangeVisibilityModifier(list)) { - replacementProcessor.prepareReplaceText(list.visibilityModifier(), "public") - } - - super.visitModifierList(list) - } - - override fun visitWhiteSpace(space: PsiWhiteSpace) { - if (removalProcessor.shouldRemoveWhiteSpace(space)) return - - super.visitWhiteSpace(space) - } - - override fun visitAnnotationEntry(annotationEntry: KtAnnotationEntry) { - if (removalProcessor.shouldRemoveAnnotation(annotationEntry)) { - removalProcessor.prepareRemoval(annotationEntry) - return - } - - builder.append(annotationEntry.text) - } - - override fun visitImportDirective(importDirective: KtImportDirective) { - if (removalProcessor.shouldRemoveImport(importDirective)) { - removalProcessor.prepareRemoval(importDirective) - return - } - - return super.visitImportDirective(importDirective) - } - - override fun visitNamedFunction(function: KtNamedFunction) { - if (primitive.dataType in function.getExcludes(context)) return - if (function.isAnnotatedWith(context) && primitive.dataType !in function.getIncludes(context)!!) return - - if (function.isAnnotatedWith(context)) { - for (annotation in function.annotationEntries.filter { it.isAnnotation(context) }) { - val primitives1 = annotation.getTypes(context, BindPrimitives::type1).flatMap { it.toPrimitive() }.toSet() - val primitives2 = annotation.getTypes(context, BindPrimitives::type2).flatMap { it.toPrimitive() }.toSet() - val primitives3 = annotation.getTypes(context, BindPrimitives::type3).flatMap { it.toPrimitive() }.toSet() - - - collector.require( - CompilerMessageSeverity.WARNING, annotation, - primitives1.isNotEmpty() || primitives2.isNotEmpty() || primitives3.isNotEmpty() - ) { - "All arguments of @${BindPrimitives::class.simpleName} are empty. It would lead to omitting of the function during generation." - } - - val combinations = crossProduct(primitives1, primitives2, primitives3) - - for (combination in combinations) { - var index = 0 - val primitive1 = if (primitives1.isEmpty()) null else combination[index++] - val primitive2 = if (primitives2.isEmpty()) null else combination[index++] - val primitive3 = if (primitives3.isEmpty()) null else combination[index] - - withContext(PrimitiveContext(primitive1, primitive2, primitive3)) { - super.visitNamedFunction(function) - } - builder.append('\n') - } - } - } else { - super.visitNamedFunction(function) - } - } - - override fun visitClass(klass: KtClass) { - if (primitive.dataType in klass.getExcludes(context)) return - if (klass.isAnnotatedWith(context) && primitive.dataType !in klass.getIncludes(context)!!) return - - super.visitClass(klass) - } - - override fun visitTypeReference(typeReference: KtTypeReference) { - when { - typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type1) { - super.visitTypeReference(typeReference) - } - typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type2) { - super.visitTypeReference(typeReference) - } - typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type3) { - super.visitTypeReference(typeReference) - } - else -> super.visitTypeReference(typeReference) - } - } - - - override fun visitAnnotatedExpression(expression: KtAnnotatedExpression) { - when { - expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type1) { - super.visitAnnotatedExpression(expression) - } - expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type2) { - super.visitAnnotatedExpression(expression) - } - expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type3) { - super.visitAnnotatedExpression(expression) - } - else -> super.visitAnnotatedExpression(expression) - } - } - - override fun visitLeafElement(element: LeafPsiElement) { - if (replacementProcessor.haveReplaceText(element)) { - builder.append(replacementProcessor.getReplacement(element)) - return - } - - if (element.elementType != KtTokens.IDENTIFIER) { - builder.append(element.text) - return - } - - when (val parent = element.parent) { - is KtClassOrObject -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) - is KtNamedFunction -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) - else -> builder.append(element.text) - } - } - - override fun visitSimpleNameExpression(expression: KtSimpleNameExpression) { - val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) - builder.append(replacement ?: expression.text) - } - }) + val visitor = GenerationVisitor(primitive, context, collector, file) + file.accept(visitor) + val text = visitor.text() - if (builder.isNotBlank()) { + if (text.isNotBlank()) { val file = File( output, "${file.packageFqName.asString().replace('.', '/')}/${file.name.replace("Primitive", primitive.typeName)}" ) results.add(file) file.parentFile.mkdirs() - file.writeText(removalProcessor.reformat(builder.toString())) + file.writeText(visitor.removalProcessor.reformat(text)) } } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt index 45602a3..f78850a 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/Utils.kt @@ -2,16 +2,21 @@ package io.kinference.primitives.generator import io.kinference.primitives.annotations.* import io.kinference.primitives.generator.errors.require +import io.kinference.primitives.generator.processor.VectorReplacementProcessor import io.kinference.primitives.types.DataType import io.kinference.primitives.utils.psi.* import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor import org.jetbrains.kotlin.descriptors.DeclarationDescriptor +import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.js.resolve.diagnostics.findPsi +import org.jetbrains.kotlin.load.kotlin.toSourceElement import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.isValueClass +import org.jetbrains.kotlin.resolve.source.getPsi +import org.jetbrains.kotlin.types.typeUtil.supertypes import kotlin.reflect.KProperty @@ -45,13 +50,24 @@ internal fun KtAnnotationEntry.isPluginAnnotation(context: BindingContext): Bool isAnnotation(context) || isAnnotation(context) || isAnnotation(context) || - isAnnotation(context) + isAnnotation(context) || + isAnnotation(context) } internal fun DeclarationDescriptor.isNamedFunction() = findPsi() is KtNamedFunction internal fun DeclarationDescriptor.isKtClassOrObject() = findPsi() is KtClassOrObject || isValueClass() internal fun DeclarationDescriptor.isCompanion() = findPsi() is KtObjectDeclaration && containingDeclaration?.findPsi() is KtClass -internal fun DeclarationDescriptor.isConstructor() = this is ClassConstructorDescriptor || findPsi() is KtConstructor<*> && containingDeclaration?.findPsi() is KtClass +internal fun DeclarationDescriptor.isConstructor() = + this is ClassConstructorDescriptor || findPsi() is KtConstructor<*> && containingDeclaration?.findPsi() is KtClass + +internal fun KtSimpleNameExpression.initializer(context: BindingContext): KtExpression? { + val descriptor = context.get(BindingContext.REFERENCE_TARGET, this) ?: return null + val declaration = descriptor.toSourceElement.getPsi() ?: return null + if (declaration !is KtProperty) return null + return declaration.initializer +} + +internal fun KtExpression.fqTypename(context: BindingContext): String? = context.getType(this)?.getKotlinTypeFqName(false) internal fun KtNamedDeclaration.specialize(primitive: Primitive<*, *>, collector: MessageCollector): String { val name = name!! @@ -63,3 +79,9 @@ internal fun KtNamedDeclaration.specialize(primitive: Primitive<*, *>, collector internal fun String.specialize(primitive: Primitive<*, *>) = replace("Primitive", primitive.typeName) +internal fun isVectorClass(expr: KtExpression?, context: BindingContext): Boolean { + if (expr == null) return false + val type = context.getType(expr) ?: return false + val fqSupertypes = type.supertypes().map { it.getKotlinTypeFqName(false) } + return VectorReplacementProcessor.opNodeTypename in fqSupertypes +} diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/errors/KtElementError.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/errors/KtElementError.kt index 5ad19dd..73b9861 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/errors/KtElementError.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/errors/KtElementError.kt @@ -21,7 +21,7 @@ internal fun MessageCollector.require(severity: CompilerMessageSeverity, element } } -private fun KtElement.getLocation(): CompilerMessageSourceLocation? { +fun KtElement.getLocation(): CompilerMessageSourceLocation? { val lineToColumn = if (this !is KtFile) StringUtil.offsetToLineColumn(containingKtFile.text, textOffset) else null return CompilerMessageLocation.create(containingKtFile.virtualFilePath, lineToColumn?.line ?: 1, lineToColumn?.line ?: 1, null) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt new file mode 100644 index 0000000..d6ed837 --- /dev/null +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/GenerationVisitor.kt @@ -0,0 +1,243 @@ +package io.kinference.primitives.generator.processor + +import io.kinference.primitives.annotations.BindPrimitives +import io.kinference.primitives.annotations.GenerateVector +import io.kinference.primitives.annotations.SpecifyPrimitives +import io.kinference.primitives.generator.Primitive +import io.kinference.primitives.generator.PrimitiveGenerator.PrimitiveContext +import io.kinference.primitives.generator.errors.require +import io.kinference.primitives.generator.getExcludes +import io.kinference.primitives.generator.getIncludes +import io.kinference.primitives.generator.getTypes +import io.kinference.primitives.generator.isVectorClass +import io.kinference.primitives.generator.toPrimitive +import io.kinference.primitives.types.DataType +import io.kinference.primitives.utils.crossProduct +import io.kinference.primitives.utils.psi.KtDefaultVisitor +import io.kinference.primitives.utils.psi.isAnnotatedWith +import io.kinference.primitives.utils.psi.isAnnotation +import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity +import org.jetbrains.kotlin.cli.common.messages.MessageCollector +import org.jetbrains.kotlin.com.intellij.psi.PsiWhiteSpace +import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.LeafPsiElement +import org.jetbrains.kotlin.lexer.KtTokens +import org.jetbrains.kotlin.psi.KtAnnotatedExpression +import org.jetbrains.kotlin.psi.KtAnnotationEntry +import org.jetbrains.kotlin.psi.KtClass +import org.jetbrains.kotlin.psi.KtClassOrObject +import org.jetbrains.kotlin.psi.KtDeclaration +import org.jetbrains.kotlin.psi.KtDotQualifiedExpression +import org.jetbrains.kotlin.psi.KtElement +import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.psi.KtImportDirective +import org.jetbrains.kotlin.psi.KtImportList +import org.jetbrains.kotlin.psi.KtModifierList +import org.jetbrains.kotlin.psi.KtNamedFunction +import org.jetbrains.kotlin.psi.KtProperty +import org.jetbrains.kotlin.psi.KtSimpleNameExpression +import org.jetbrains.kotlin.psi.KtTypeReference +import org.jetbrains.kotlin.psi.psiUtil.visibilityModifier +import org.jetbrains.kotlin.resolve.BindingContext + +internal class GenerationVisitor( + private val primitive: Primitive<*, *>, + private val context: BindingContext, + private val collector: MessageCollector, + private val file: KtFile +) : + KtDefaultVisitor() { + private data class PrimitiveContext(val type1: Primitive<*, *>? = null, val type2: Primitive<*, *>? = null, val type3: Primitive<*, *>? = null) + + private val builder = StringBuilder() + fun text() = builder.toString() + val removalProcessor = RemovalProcessor(context) + val replacementProcessor = ReplacementProcessor(context, collector, file) + private var currentPrimitive = primitive + private var vecCount = 0 + + private fun KtElement.withPrimitive(primitive: Primitive<*, *>?, body: () -> Unit) { + collector.require(CompilerMessageSeverity.ERROR, this, primitive != null) { + "Primitive was bound with @${BindPrimitives::class.simpleName} sub-annotation," + + " but outer expression is not annotated with @${BindPrimitives::class.simpleName}" + } + + val tmp = currentPrimitive + currentPrimitive = primitive + body() + currentPrimitive = tmp + } + + private var primitiveContext = PrimitiveContext() + private fun withContext(context: PrimitiveContext, body: () -> Unit) { + val tmp = primitiveContext + primitiveContext = context + body() + primitiveContext = tmp + } + + override fun visitImportList(importList: KtImportList) { + if (file.isAnnotatedWith(context) && primitive.dataType in DataType.VECTORIZABLE.resolve()) { + builder.appendLine("import io.kinference.ndarray.VecUtils.isModuleLoaded") + builder.appendLine("import jdk.incubator.vector.*") + } + super.visitImportList(importList) + } + + override fun visitModifierList(list: KtModifierList) { + if (replacementProcessor.shouldChangeVisibilityModifier(list)) { + replacementProcessor.prepareReplaceText(list.visibilityModifier(), "public") + } + + super.visitModifierList(list) + } + + override fun visitWhiteSpace(space: PsiWhiteSpace) { + if (removalProcessor.shouldRemoveWhiteSpace(space)) return + + super.visitWhiteSpace(space) + } + + override fun visitAnnotationEntry(annotationEntry: KtAnnotationEntry) { + if (removalProcessor.shouldRemoveAnnotation(annotationEntry)) { + removalProcessor.prepareRemoval(annotationEntry) + return + } + + builder.append(annotationEntry.text) + } + + override fun visitImportDirective(importDirective: KtImportDirective) { + if (removalProcessor.shouldRemoveImport(importDirective)) { + removalProcessor.prepareRemoval(importDirective) + return + } + + return super.visitImportDirective(importDirective) + } + + override fun visitNamedFunction(function: KtNamedFunction) { + if (primitive.dataType in function.getExcludes(context)) return + if (function.isAnnotatedWith(context) && primitive.dataType !in function.getIncludes(context)!!) return + + if (function.isAnnotatedWith(context)) { + for (annotation in function.annotationEntries.filter { it.isAnnotation(context) }) { + val primitives1 = annotation.getTypes(context, BindPrimitives::type1).flatMap { it.toPrimitive() }.toSet() + val primitives2 = annotation.getTypes(context, BindPrimitives::type2).flatMap { it.toPrimitive() }.toSet() + val primitives3 = annotation.getTypes(context, BindPrimitives::type3).flatMap { it.toPrimitive() }.toSet() + + + collector.require( + CompilerMessageSeverity.WARNING, annotation, + primitives1.isNotEmpty() || primitives2.isNotEmpty() || primitives3.isNotEmpty() + ) { + "All arguments of @${BindPrimitives::class.simpleName} are empty. It would lead to omitting of the function during generation." + } + + val combinations = crossProduct(primitives1, primitives2, primitives3) + + for (combination in combinations) { + var index = 0 + val primitive1 = if (primitives1.isEmpty()) null else combination[index++] + val primitive2 = if (primitives2.isEmpty()) null else combination[index++] + val primitive3 = if (primitives3.isEmpty()) null else combination[index] + + withContext(PrimitiveContext(primitive1, primitive2, primitive3)) { + super.visitNamedFunction(function) + } + builder.append('\n') + } + } + } else { + super.visitNamedFunction(function) + } + } + + override fun visitClass(klass: KtClass) { + if (primitive.dataType in klass.getExcludes(context)) return + if (klass.isAnnotatedWith(context) && primitive.dataType !in klass.getIncludes(context)!!) return + + super.visitClass(klass) + } + + override fun visitTypeReference(typeReference: KtTypeReference) { + when { + typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type1) { + super.visitTypeReference(typeReference) + } + + typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type2) { + super.visitTypeReference(typeReference) + } + + typeReference.isAnnotatedWith(context) -> typeReference.withPrimitive(primitiveContext.type3) { + super.visitTypeReference(typeReference) + } + + else -> super.visitTypeReference(typeReference) + } + } + + + override fun visitAnnotatedExpression(expression: KtAnnotatedExpression) { + when { + expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type1) { + super.visitAnnotatedExpression(expression) + } + + expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type2) { + super.visitAnnotatedExpression(expression) + } + + expression.isAnnotatedWith(context) -> expression.withPrimitive(primitiveContext.type3) { + super.visitAnnotatedExpression(expression) + } + + else -> super.visitAnnotatedExpression(expression) + } + } + + override fun visitLeafElement(element: LeafPsiElement) { + if (replacementProcessor.haveReplaceText(element)) { + builder.append(replacementProcessor.getReplacement(element)) + return + } + + if (element.elementType != KtTokens.IDENTIFIER) { + builder.append(element.text) + return + } + + when (val parent = element.parent) { + is KtClassOrObject -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) + is KtNamedFunction -> builder.append(replacementProcessor.getReplacement(parent, currentPrimitive) ?: element.text) + else -> builder.append(element.text) + } + } + + override fun visitDeclaration(dcl: KtDeclaration) { + if (file.isAnnotatedWith(context) && dcl is KtProperty) { + val init = dcl.initializer + if (isVectorClass(init, context)) return + } + super.visitDeclaration(dcl) + } + + override fun visitSimpleNameExpression(expression: KtSimpleNameExpression) { + val replacement = replacementProcessor.getReplacement(expression, currentPrimitive) + builder.append(replacement ?: expression.text) + } + + override fun visitDotQualifiedExpression(expression: KtDotQualifiedExpression) { + if (!file.isAnnotatedWith(context)) { + super.visitDotQualifiedExpression(expression) + return + } + val replacement = replacementProcessor.getReplacement(expression, currentPrimitive, vecCount) + if (replacement == null) { + super.visitDotQualifiedExpression(expression); return + } else { + vecCount += 1 + builder.append(replacement) + } + } +} diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt index 2b9dbb4..8aa6e21 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/RemovalProcessor.kt @@ -4,6 +4,7 @@ import io.kinference.primitives.annotations.* import io.kinference.primitives.generator.isPluginAnnotation import io.kinference.primitives.types.PrimitiveArray import io.kinference.primitives.types.PrimitiveType +import io.kinference.primitives.vector.* import org.jetbrains.kotlin.com.intellij.openapi.util.Key import org.jetbrains.kotlin.com.intellij.psi.PsiElement import org.jetbrains.kotlin.com.intellij.psi.PsiWhiteSpace @@ -24,6 +25,7 @@ internal class RemovalProcessor(private val context: BindingContext) { GenerateNameFromPrimitives::class.qualifiedName, GeneratePrimitives::class.qualifiedName, SpecifyPrimitives::class.qualifiedName, + GenerateVector::class.qualifiedName, ) } diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt index 9d7ec89..fc9bc43 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/ReplacementProcessor.kt @@ -1,10 +1,13 @@ package io.kinference.primitives.generator.processor import io.kinference.primitives.annotations.GenerateNameFromPrimitives +import io.kinference.primitives.annotations.GenerateVector import io.kinference.primitives.annotations.MakePublic import io.kinference.primitives.generator.* +import io.kinference.primitives.generator.errors.getLocation import io.kinference.primitives.generator.errors.require import io.kinference.primitives.types.* +import io.kinference.primitives.vector.* import io.kinference.primitives.utils.psi.forced import io.kinference.primitives.utils.psi.isAnnotatedWith import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity @@ -12,15 +15,19 @@ import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.com.intellij.openapi.util.Key import org.jetbrains.kotlin.com.intellij.psi.PsiElement import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.LeafPsiElement +import org.jetbrains.kotlin.js.descriptorUtils.getKotlinTypeFqName import org.jetbrains.kotlin.psi.* import org.jetbrains.kotlin.psi.psiUtil.isExtensionDeclaration import org.jetbrains.kotlin.psi.psiUtil.visibilityModifier import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe +import org.jetbrains.kotlin.types.typeUtil.supertypes -internal class ReplacementProcessor(private val context: BindingContext, private val collector: MessageCollector) { +internal class ReplacementProcessor( + private val context: BindingContext, private val collector: MessageCollector, private val file: KtFile +) { companion object { - private fun toType(primitive: Primitive<*, *>): String { + internal fun toType(primitive: Primitive<*, *>): String { return when (primitive.dataType) { DataType.BYTE -> "toInt().toByte" DataType.SHORT -> "toInt().toShort" @@ -42,19 +49,21 @@ internal class ReplacementProcessor(private val context: BindingContext, private (PrimitiveArray::class.qualifiedName!!) to { it.arrayTypeName }, (PrimitiveArray::class.qualifiedName!! + ".") to { it.arrayTypeName }, - (PrimitiveArray::class.qualifiedName!! + ".Companion") to { it.arrayTypeName } - ) + (PrimitiveArray::class.qualifiedName!! + ".Companion") to { it.arrayTypeName }) + } fun getReplacement(klass: KtClassOrObject, primitive: Primitive<*, *>): String? { - if (klass.isAnnotatedWith(context)) { - return klass.specialize(primitive, collector) - } - collector.require(CompilerMessageSeverity.WARNING, klass, !klass.isTopLevel()) { - "Class is not annotated with ${GenerateNameFromPrimitives::class.simpleName}, so its name would not be specialized. It may lead to redeclaration compile error." + if (!klass.isAnnotatedWith(context)) { + collector.require(CompilerMessageSeverity.WARNING, klass, !klass.isTopLevel()) { + "Class is not annotated with ${GenerateNameFromPrimitives::class.simpleName}, so its name would not be specialized. It may lead to redeclaration compile error." + } + return null } - return null + + return klass.specialize(primitive, collector) + } fun getReplacement(function: KtNamedFunction, primitive: Primitive<*, *>): String? { @@ -77,18 +86,124 @@ internal class ReplacementProcessor(private val context: BindingContext, private defaultReplacements[type]!!.invoke(primitive) } - (target.isKtClassOrObject() && target.containingDeclaration!!.isAnnotatedWith()) || - (target.isNamedFunction() || target.isKtClassOrObject()) && target.isAnnotatedWith() -> { + (target.isKtClassOrObject() && target.containingDeclaration!!.isAnnotatedWith()) || (target.isNamedFunction() || target.isKtClassOrObject()) && target.isAnnotatedWith() -> { expression.text.specialize(primitive) } (target.isCompanion() || target.isConstructor()) && target.containingDeclaration!!.isAnnotatedWith() -> { name.specialize(primitive) } + else -> null } } + fun getReplacement(expr: KtDotQualifiedExpression, primitive: Primitive<*, *>, idx: Int): String? { + val receiver = expr.receiverExpression + val selector = expr.selectorExpression ?: return null + if (selector !is KtCallExpression) return null + + val args = selector.valueArguments + val callName = selector.calleeExpression?.text ?: return null + + if (!isVectorClass(receiver, context)) return null + + val vecProcessor = VectorReplacementProcessor(context, primitive, collector, file) + val res = vecProcessor.process(receiver) + if (res == null) { + collector.report( + CompilerMessageSeverity.STRONG_WARNING, + "Could not process vectorized expression, the code will not be generated", + expr.getLocation() + ) + return "" + } + val (vecReplacement, linReplacement, isScalar) = res + + val toPrimitive = "${toType(primitive)}()" + val vecLen = "_vecLen_$idx" + val vecEnd = "_vecEnd_$idx" + val vecIdx = "_vec_internal_idx" + val vecEnabled = "isModuleLoaded" + + if (callName == "into" && args.size == 3) { + val dest = args[0].text + val destOffset = args[1].text + val len = args[2].text + val linearCode = """ + for($vecIdx in 0 until $len) { + ${vecProcessor.linDeclarations} + $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive + }""".trimIndent() + + if (primitive.dataType in DataType.VECTORIZABLE.resolve()) return """ + if($vecEnabled) { + val $vecLen = ${vecProcessor.vecSpecies}.length() + val $vecEnd = $len - ($len % $vecLen) + ${vecProcessor.scalarDeclarations} + for ($vecIdx in 0 until $vecEnd step $vecLen) { + ${vecProcessor.vecDeclarations} + $vecReplacement.intoArray($dest, $destOffset + _vec_internal_idx) + } + for($vecIdx in $vecEnd until $len) { + ${vecProcessor.linDeclarations} + $dest[$destOffset + $vecIdx] = $linReplacement.$toPrimitive + } + }else{ + $linearCode + } + """.trimIndent() + else return linearCode + } else if (callName == "reduce" && args.size == 2) { + val handle = args[0].text + val len = args[1].text + + val neutral = vecProcessor.neutralElement[handle] ?: return "" + + val linearOp = VectorReplacementProcessor.binaryLinearReplacements[handle] ?: return "" + val linAccumulate = linearOp("ret", linReplacement) + val linearCode = """ + var ret = $neutral + for($vecIdx in 0 until $len) { + ${vecProcessor.linDeclarations} + ret = $linAccumulate.$toPrimitive + } + ret.$toPrimitive + """.trimIndent() + + if (primitive.dataType in DataType.VECTORIZABLE.resolve()) { + return """{ + if($vecEnabled) { + val $vecLen = ${vecProcessor.vecSpecies}.length() + val $vecEnd = $len - ($len % $vecLen) + var accumulator = ${vecProcessor.vecName}.broadcast(${vecProcessor.vecSpecies}, $neutral) + ${vecProcessor.scalarDeclarations} + for ($vecIdx in 0 until $vecEnd step $vecLen) { + ${vecProcessor.vecDeclarations} + accumulator = accumulator.lanewise(VectorOperators.$handle, $vecReplacement) + } + var ret = accumulator.reduceLanes(VectorOperators.$handle) + for($vecIdx in $vecEnd until $len) { + ${vecProcessor.linDeclarations} + ret = $linAccumulate.$toPrimitive + } + ret.$toPrimitive + }else{ + $linearCode + } + }.invoke() + """.trimIndent() + } else { + return """{ + $linearCode + }.invoke() + """.trimIndent() + } + + } else return "" + } + + fun shouldChangeVisibilityModifier(list: KtModifierList): Boolean { val owner = list.owner val visibilityModifier = list.visibilityModifier() diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt new file mode 100644 index 0000000..8bbf9af --- /dev/null +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/generator/processor/VectorReplacementProcessor.kt @@ -0,0 +1,330 @@ +package io.kinference.primitives.generator.processor + +import io.kinference.primitives.generator.Primitive +import io.kinference.primitives.generator.errors.require +import io.kinference.primitives.generator.fqTypename +import io.kinference.primitives.generator.initializer +import io.kinference.primitives.types.DataType +import io.kinference.primitives.vector.* +import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity +import org.jetbrains.kotlin.cli.common.messages.MessageCollector +import org.jetbrains.kotlin.psi.KtCallExpression +import org.jetbrains.kotlin.psi.KtExpression +import org.jetbrains.kotlin.psi.KtSimpleNameExpression +import org.jetbrains.kotlin.resolve.BindingContext +import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.psi.KtValueArgument +import io.kinference.primitives.generator.errors.* + + +internal class VectorReplacementProcessor( + private val context: BindingContext, + private val primitive: Primitive<*, *>, + private val collector: MessageCollector, + private val file: KtFile, +) { + val vecName = "${primitive.typeName}Vector" + val vecSpecies = "$vecName.SPECIES_PREFERRED" + val vecLen = "$vecName.length()" + + companion object { + val unaryLinearReplacements = mapOf( + "EXP" to { x: String -> "FastMath.exp($x)" }, + "ABS" to { x: String -> "abs($x)" }, + "NEG" to { x: String -> "(-$x)" }, + "LOG" to { x: String -> "ln($x)" }, + "SQRT" to { x: String -> "sqrt($x)" }, + "CBRT" to { x: String -> "cbrt($x)" }, + ) + + val binaryLinearReplacements = mapOf( + "ADD" to { x: String, y: String -> "($x + $y)" }, + "SUB" to { x: String, y: String -> "($x - $y)" }, + "MUL" to { x: String, y: String -> "($x * $y)" }, + "DIV" to { x: String, y: String -> "($x / $y)" }, + "MAX" to { x: String, y: String -> "maxOf($x, $y)" }, + "MIN" to { x: String, y: String -> "minOf($x, $y)" }, + "POW" to { x: String, y: String -> "($x).pow($y)" }, + ) + + val vectorHandles = mapOf( + "Add" to "ADD", + "Sub" to "SUB", + "Mul" to "MUL", + "Div" to "DIV", + "Exp" to "EXP", + "Max" to "MAX", + "Min" to "MIN", + "Abs" to "ABS", + "Log" to "LOG", + "Neg" to "NEG", + "Pow" to "POW", + "Sqrt" to "SQRT", + "Cbrt" to "CBRT", + ) + + val supportedTypes = mapOf( + "EXP" to setOf(DataType.FLOAT, DataType.DOUBLE), + "LOG" to setOf(DataType.FLOAT, DataType.DOUBLE), + "SQRT" to setOf(DataType.FLOAT, DataType.DOUBLE), + "CBRT" to setOf(DataType.FLOAT, DataType.DOUBLE), + ).withDefault { DataType.ALL.resolve() } + + val maskHandles = mapOf( + "And" to "AND", + "Or" to "OR", + "Xor" to "XOR", + "Not" to "NOT", + "Eq" to "EQ", + "Neq" to "NEQ", + "LT" to "LT", + "LE" to "LE", + "GT" to "GT", + "GE" to "GE" + ) + + val maskUnaryReplacement = mapOf( + "Not" to ({ x: String -> "$x.not()" }), + ) + + val maskBinaryReplacement = mapOf( + "And" to ({ x: String, y: String -> "($x.and($y)" }), + "Or" to ({ x: String, y: String -> "$x.or($y)" }), + "Xor" to ({ x: String, y: String -> "$x.xor($y)" }), + ) + + val comparatorReplacement = mapOf( + "Eq" to ({ x: String, y: String -> "($x == $y)" }), + "Neq" to ({ x: String, y: String -> "($x != $y)" }), + "LT" to ({ x: String, y: String -> "($x < $y)" }), + "LE" to ({ x: String, y: String -> "($x <= $y)" }), + "GT" to ({ x: String, y: String -> "($x > $y)" }), + "GE" to ({ x: String, y: String -> "($x >= $y)" }), + ) + + val opNodeTypename = OpNode::class.qualifiedName + val unaryOpNames = UnaryOp::class.sealedSubclasses.map { it.qualifiedName } + val binaryOpNames = BinaryOp::class.sealedSubclasses.map { it.qualifiedName } + val scalarType = Value::class.qualifiedName + val primitiveSliceType = PrimitiveSlice::class.qualifiedName + val maskBinaryOpTypes = MaskBinaryOp::class.sealedSubclasses.map { it.qualifiedName } + val maskUnaryOpTypes = MaskUnaryOp::class.sealedSubclasses.map { it.qualifiedName } + val comparatorTypes = Comparator::class.sealedSubclasses.map { it.qualifiedName } + val ifElseType = IfElse::class.qualifiedName + } + + val neutralElement = mapOf( + "ADD" to "0.${ReplacementProcessor.toType(primitive)}()", + "MUL" to "1.${ReplacementProcessor.toType(primitive)}()", + "MIN" to "${primitive.typeName}.MAX_VALUE", + "MAX" to "${primitive.typeName}.MIN_VALUE" + ) + + var scalarDeclarations: String = "" + var vecDeclarations: String = "" + var linDeclarations: String = "" + var localVariables: Set = emptySet() + var scalarVariables: Set = emptySet() + + private fun processSimpleName(expr: KtExpression?): Triple? { + if (expr !is KtSimpleNameExpression) return null + val varName = expr.text + + val actualBody = expr.initializer(context) + if (varName !in localVariables) { + val success = addVariable(actualBody, varName) + if (!success) return null + } + return Triple("${varName}_vec", "${varName}_lin", varName in scalarVariables) + } + + private fun addVariable(expr: KtExpression?, varName: String): Boolean { + val (vecReplacement, linReplacement, scalar) = process(expr) ?: return false + localVariables = localVariables + varName + if (scalar) { + scalarVariables = scalarVariables + varName + scalarDeclarations += "val ${varName}_vec = $vecReplacement\n" + scalarDeclarations += "val ${varName}_lin = $linReplacement\n" + } else { + vecDeclarations += "val ${varName}_vec = $vecReplacement\n" + linDeclarations += "val ${varName}_lin = $linReplacement\n" + } + return true + } + + fun process(expr: KtExpression?): Triple? { + if (expr !is KtCallExpression) { + return processSimpleName(expr) + } + + val exprTypename = expr.fqTypename(context) ?: return null + val shortName = exprTypename.substringAfterLast('.') + val args = expr.valueArguments + + return when (exprTypename) { + in unaryOpNames -> { + val handle = vectorHandles[shortName] ?: return null + if (!supportedTypes.getValue(handle).contains(primitive.dataType)) { + collector.report( + CompilerMessageSeverity.STRONG_WARNING, + "$handle operation is not supported for ${primitive.dataType} type", + expr.getLocation() + ) + return null + } + processUnaryOperation(handle, args) + } + + in binaryOpNames -> { + val handle = vectorHandles[shortName] ?: return null + if (!supportedTypes.getValue(handle).contains(primitive.dataType)) { + collector.report( + CompilerMessageSeverity.STRONG_WARNING, + "$handle operation is not supported for ${primitive.dataType} type", + expr.getLocation() + ) + return null + } + processBinaryOperation(handle, args) + } + + scalarType -> { + if (args.size != 1) return null + val visitor = GenerationVisitor(primitive, context, collector, file) + args[0].accept(visitor) + val linear = visitor.text() + val vectorized = "$vecName.broadcast($vecSpecies, $linear)" + Triple(vectorized, linear, true) + } + + primitiveSliceType -> { + if (args.size != 2 && args.size != 1) return null + val src = args[0].text + val offset = when (args.size) { + 2 -> args[1].text + else -> "0" + } + val vectorized = "${vecName}.fromArray($vecSpecies, $src, $offset + _vec_internal_idx)" + val linear = "$src[$offset+_vec_internal_idx]" + Triple(vectorized, linear, false) + } + + ifElseType -> processIfElse(args) + else -> null + } + } + + private fun processUnaryOperation(handle: String, args: List): Triple? { + if (args.size != 1 && args.size != 2) return null + val childExpr = args[0].getArgumentExpression() + val masked = args.size == 2 + var (childVector, childLinear, isScalar) = process(childExpr) ?: return null + val linReplace = unaryLinearReplacements[handle] ?: return null + var linear = linReplace(childLinear) + var vectorized = """$childVector + .lanewise(VectorOperators.$handle""".trimIndent() + if (masked) { + isScalar = false + val maskExpr = args[1].getArgumentExpression() + val (maskVector, maskLinear) = processMask(maskExpr) ?: return null + linear = "(if($maskLinear) $linear else $childLinear)" + vectorized += ", $maskVector)" + } else { + vectorized += ")" + } + return Triple(vectorized, linear, isScalar) + } + + private fun processBinaryOperation(handle: String, args: List): Triple? { + if (args.size != 2 && args.size != 3) return null + val masked = args.size == 3 + val leftExpr = args[0].getArgumentExpression() + val rightExpr = args[1].getArgumentExpression() + val (leftVector, leftLinear, leftScalar) = process(leftExpr) ?: return null + val (rightVector, rightLinear, rightScalar) = process(rightExpr) ?: return null + var isScalar = leftScalar && rightScalar + var linear = binaryLinearReplacements[handle]?.invoke(leftLinear, rightLinear) ?: return null + + var vectorized = if (rightScalar) """$leftVector. + lanewise(VectorOperators.$handle, $rightLinear""".trimIndent() + else """$leftVector + .lanewise(VectorOperators.$handle, $rightVector""".trimIndent() + + if (masked) { + isScalar = false + val maskExpr = args[2].getArgumentExpression() + val (maskVector, maskLinear) = processMask(maskExpr) ?: return null + linear = "(if($maskLinear) $linear else $leftLinear)" + vectorized += ", $maskVector)" + } else { + vectorized += ")" + } + return Triple(vectorized, linear, isScalar) + } + + private fun processIfElse(args: List): Triple? { + if (args.size != 3) return null + val mask = args[0].getArgumentExpression() ?: return null + val left = args[1].getArgumentExpression() ?: return null + val right = args[2].getArgumentExpression() ?: return null + val (maskVector, maskLinear) = processMask(mask) ?: return null + val (leftVector, leftLinear, leftScalar) = process(left) ?: return null + val (rightVector, rightLinear, rightScalar) = process(right) ?: return null + val isScalar = leftScalar && rightScalar + val linear = "(if($maskLinear) $leftLinear else $rightLinear)" + val vectorized = """ + $rightVector.blend($leftVector, $maskVector)""".trimIndent() + return Triple(vectorized, linear, isScalar) + } + + private fun processMask(expr: KtExpression?): Pair? { + if (expr !is KtCallExpression) { + val (vecReplacement, linReplacement, _) = processSimpleName(expr) ?: return null + return Pair(vecReplacement, linReplacement) + } + val exprTypename = expr.fqTypename(context) ?: return null + val handle = maskHandles[exprTypename.substringAfterLast('.')] ?: return null + val args = expr.valueArguments + return when (exprTypename) { + in maskUnaryOpTypes -> { + if (args.size != 1) return null + val child = args[0].getArgumentExpression() ?: return null + val (vecReplacement, linReplacement) = processMask(child) ?: return null + val linReplacer = maskUnaryReplacement[handle] ?: return null + + Pair( + linReplacer(vecReplacement), linReplacer(linReplacement) + ) + } + + in maskBinaryOpTypes -> { + if (args.size != 2) return null + val left = args[0].getArgumentExpression() ?: return null + val (leftVecReplacement, leftLinReplacement) = processMask(left) ?: return null + val right = args[0].getArgumentExpression() ?: return null + val (rightVecReplacement, rightLinReplacement) = processMask(right) ?: return null + val linReplacer = maskBinaryReplacement[handle] ?: return null + Pair( + linReplacer(leftVecReplacement, rightVecReplacement), linReplacer(leftLinReplacement, rightLinReplacement) + ) + } + + in comparatorTypes -> { + if (args.size != 2) return null + val leftExpr = args[0].getArgumentExpression() ?: return null + val rightExpr = args[1].getArgumentExpression() ?: return null + val (leftVector, leftLinear, leftScalar) = process(leftExpr) ?: return null + val (rightVector, rightLinear, rightScalar) = process(rightExpr) ?: return null + val linear = comparatorReplacement[handle]?.invoke(leftLinear, rightLinear) ?: return null + val vectorized = """ + $leftVector + .compare(VectorOperators.$handle, $rightVector) + """.trimIndent() + Pair(vectorized, linear) + } + + else -> null + } + } + +} diff --git a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/utils/psi/KtDefaultVisitor.kt b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/utils/psi/KtDefaultVisitor.kt index ea693f4..b0b4d29 100644 --- a/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/utils/psi/KtDefaultVisitor.kt +++ b/plugin-build/primitives-plugin/src/main/kotlin/io/kinference/primitives/utils/psi/KtDefaultVisitor.kt @@ -2,6 +2,7 @@ package io.kinference.primitives.utils.psi import org.jetbrains.kotlin.com.intellij.psi.PsiElement import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.LeafPsiElement +import org.jetbrains.kotlin.psi.KtTreeVisitorVoid import org.jetbrains.kotlin.psi.KtVisitorVoid internal abstract class KtDefaultVisitor : KtVisitorVoid() {