Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@ This code would generate specializations for Float and Int types via replacement
of `PrimitiveType` with corresponding type (Float or Int). Also, note that standard functions, like MAX_VALUE, are available for `PrimitiveType` like it would
be a real `Number`.

Version 2.1.0 adds the possibility of generating vectorized code using
Java's [vector API](https://download.java.net/java/early_access/jdk25/docs/api/jdk.incubator.vector/module-summary.html).
Currently, this only works inside of KInference and in combination with primitive specialization.
7 changes: 4 additions & 3 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,22 @@ plugins {

allprojects {
repositories {
mavenLocal()
mavenCentral()
gradlePluginPortal()
}
}

subprojects {
tasks.withType(JavaCompile::class.java).all {
sourceCompatibility = JavaVersion.VERSION_1_8.toString()
targetCompatibility = JavaVersion.VERSION_1_8.toString()
sourceCompatibility = JavaVersion.VERSION_21.toString()
targetCompatibility = JavaVersion.VERSION_21.toString()
}

tasks.withType(KotlinCompilationTask::class.java).all {
compilerOptions {
if (this is KotlinJvmCompilerOptions) {
jvmTarget.set(JvmTarget.JVM_1_8)
jvmTarget.set(JvmTarget.JVM_21)
}

apiVersion.set(KotlinVersion.KOTLIN_2_0)
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
GROUP=io.kinference.primitives
VERSION=2.0.0-1
VERSION=2.1.0-0
3 changes: 3 additions & 0 deletions plugin-build/primitives-annotations/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ kotlin {
}
}

repositories {
mavenCentral()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package io.kinference.primitives.annotations

import io.kinference.primitives.types.*
import io.kinference.primitives.vector.*

/** Specify that any usage of [OperationNode] is subject for replacement in this file */
@Target(AnnotationTarget.FILE)
annotation class GenerateVector(vararg val types: DataType)
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,18 @@ enum class DataType {
BOOLEAN,

ALL,
NUMBER;
NUMBER,
VECTORIZABLE;

/**
* Resolve DataType into actual primitives -- would flatten groups into collection of primitives.
* Primitives would remain the same
*/
fun resolve(): Set<DataType> {
return when(this) {
return when (this) {
ALL -> setOf(BYTE, SHORT, INT, LONG, UBYTE, USHORT, UINT, ULONG, FLOAT, DOUBLE, BOOLEAN)
NUMBER -> setOf(BYTE, SHORT, INT, LONG, UBYTE, USHORT, UINT, ULONG, FLOAT, DOUBLE)
VECTORIZABLE -> setOf(BYTE, SHORT, INT, LONG, FLOAT, DOUBLE)
else -> setOf(this)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
@file:Suppress("Unused", "UnusedReceiverParameter")

package io.kinference.primitives.vector

import io.kinference.primitives.types.PrimitiveArray
import io.kinference.primitives.types.PrimitiveType
import io.kinference.primitives.types.toPrimitive
import io.kinference.primitives.vector.Add
import io.kinference.primitives.vector.BinaryOp
import io.kinference.primitives.vector.Sub
import io.kinference.primitives.vector.UnaryOp

sealed class OpNode() {
public fun into(dest: PrimitiveArray, offset: Int, len: Int): Unit =
throw UnsupportedOperationException()

public fun reduce(operation: AssociativeWrapper, len: Int): PrimitiveType =
throw UnsupportedOperationException()

}

final class PrimitiveSlice(val src: PrimitiveArray, val offset: Int = 0) : OpNode() {}

final class Value(val value: PrimitiveType) : OpNode() {}

sealed class UnaryOp(val arg: OpNode) : OpNode() {
constructor(arg: OpNode, mask: VecMask) : this(arg)
}

sealed class BinaryOp(val left: OpNode, val right: OpNode) : OpNode() {
constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right)
}

class IfElse(val condition: VecMask, val left: OpNode, val right: OpNode) : OpNode() {}

sealed class AssociativeWrapper() {}

class Exp(arg: OpNode) : UnaryOp(arg) {
constructor(arg: OpNode, mask: VecMask) : this(arg)
}

class Abs(arg: OpNode) : UnaryOp(arg) {
constructor(arg: OpNode, mask: VecMask) : this(arg)
}

class Neg(arg: OpNode) : UnaryOp(arg) {
constructor(arg: OpNode, mask: VecMask) : this(arg)
}

class Log(arg: OpNode) : UnaryOp(arg) {
constructor(arg: OpNode, mask: VecMask) : this(arg)
}

class Sqrt(arg: OpNode) : UnaryOp(arg) {
constructor(arg: OpNode, mask: VecMask) : this(arg)
}

class Cbrt(arg: OpNode) : UnaryOp(arg) {
constructor(arg: OpNode, mask: VecMask) : this(arg)
}

class Add(left: OpNode, right: OpNode) : BinaryOp(left, right) {
constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right)
}

class Sub(left: OpNode, right: OpNode) : BinaryOp(left, right) {
constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right)
}

class Mul(left: OpNode, right: OpNode) : BinaryOp(left, right) {
constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right)
}

class Div(left: OpNode, right: OpNode) : BinaryOp(left, right) {
constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right)
}

class Pow(left: OpNode, right: OpNode) : BinaryOp(left, right) {
constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right)
}

class Max(left: OpNode, right: OpNode) : BinaryOp(left, right) {
constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right)
}

class Min(left: OpNode, right: OpNode) : BinaryOp(left, right) {
constructor(left: OpNode, right: OpNode, mask: VecMask) : this(left, right)
}


object ADD : AssociativeWrapper() {}
object MUL : AssociativeWrapper() {}
object MAX : AssociativeWrapper() {}
object MIN : AssociativeWrapper() {}

sealed class VecMask() {}

sealed class Comparator(left: OpNode, right: OpNode) : VecMask() {}
sealed class MaskBinaryOp(left: VecMask, right: VecMask) : VecMask() {}
sealed class MaskUnaryOp(arg: VecMask) : VecMask() {}

class Not(arg: VecMask) : VecMask() {}

class Eq(left: OpNode, right: OpNode) : Comparator(left, right) {}
class Neq(left: OpNode, right: OpNode) : Comparator(left, right) {}
class LT(left: OpNode, right: OpNode) : Comparator(left, right) {}
class LE(left: OpNode, right: OpNode) : Comparator(left, right) {}
class GT(left: OpNode, right: OpNode) : Comparator(left, right) {}
class GE(left: OpNode, right: OpNode) : Comparator(left, right) {}

class And(left: VecMask, right: VecMask) : MaskBinaryOp(left, right) {}
class Or(left: VecMask, right: VecMask) : MaskBinaryOp(left, right) {}
class Xor(left: VecMask, right: VecMask) : MaskBinaryOp(left, right) {}
3 changes: 3 additions & 0 deletions plugin-build/primitives-plugin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ gradlePlugin {
}
}
}
repositories {
mavenCentral()
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package io.kinference.primitives

import com.sun.org.apache.xpath.internal.operations.Bool
import org.gradle.api.Project
import org.gradle.api.file.DirectoryProperty
import org.gradle.api.provider.Property
import org.jetbrains.kotlin.ir.declarations.DescriptorMetadataSource
import javax.inject.Inject

open class PrimitivesExtension @Inject constructor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ class PrimitivesGradlePlugin : Plugin<Project> {

val primitivesExt = project.extensions.create(extensionName, PrimitivesExtension::class.java)


val primitivesCache = project.gradle.sharedServices.registerIfAbsent("${project.path}_${primitivesCacheName}", PrimitivesCache::class.java) {
it.maxParallelUsages.set(1)
}
Expand All @@ -21,6 +20,7 @@ class PrimitivesGradlePlugin : Plugin<Project> {
it.group = "generate"
}


kotlinExt.sourceSets.all { sourceSet ->
sourceSet.kotlin.srcDir(primitivesExt.generationPath.dir(sourceSet.name))
primitivesCache.get().sourceSetToResolved[sourceSet.name] = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import org.gradle.api.provider.Property
import org.gradle.api.tasks.*
import org.gradle.work.NormalizeLineEndings
import org.jetbrains.kotlin.cli.common.config.addKotlinSourceRoot
import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity
import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSourceLocation
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.cli.jvm.config.addJvmClasspathRoots
import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension
Expand Down Expand Up @@ -38,6 +40,23 @@ abstract class PrimitivesTask : DefaultTask() {
@get:Internal
abstract val primitivesCache: Property<PrimitivesCache>

private val messageCollector = object : MessageCollector {
override fun clear() {}

override fun hasErrors(): Boolean = false

override fun report(severity: CompilerMessageSeverity, message: String, location: CompilerMessageSourceLocation?) {
when (severity) {
CompilerMessageSeverity.STRONG_WARNING -> logger.error("$message: $location")
CompilerMessageSeverity.ERROR -> logger.error("$message: $location")
CompilerMessageSeverity.INFO -> logger.info(message)
CompilerMessageSeverity.WARNING -> Unit
else -> logger.debug(message)
}
}
}


init {
group = "generate"
description = "Generates primitives from sources"
Expand All @@ -64,9 +83,9 @@ abstract class PrimitivesTask : DefaultTask() {
FileWithCommon(source, isCommon)
}

val analyzeFun = when(compilation.get().platformType) {
KotlinPlatformType.jvm -> Analyze::analyzeJvmSources
KotlinPlatformType.js -> Analyze::analyzeJsSources
val analyzeFun = when (compilation.get().platformType) {
KotlinPlatformType.jvm -> Analyze::analyzeJvmSources
KotlinPlatformType.js -> Analyze::analyzeJsSources
KotlinPlatformType.common -> Analyze::analyzeCommonSources
else -> error("Unsupported platform type ${compilation.get().platformType}")
}
Expand All @@ -85,11 +104,13 @@ abstract class PrimitivesTask : DefaultTask() {
val annotated = ktSources.filter { it.isAnnotatedWith<GeneratePrimitives>(result.bindingContext) }
val notGeneratedYet = annotated.filterNot { it.virtualFilePath in primitivesCache.get().resolvedPaths }



for (ktFile in notGeneratedYet) {
val sourceSet = findSourceSetName(ktFile.virtualFilePath)
val outputDir = generationPath.dir(sourceSet).get().asFile

PrimitiveGenerator(ktFile, result.bindingContext, outputDir, MessageCollector.NONE).generate()
PrimitiveGenerator(ktFile, result.bindingContext, outputDir, messageCollector).generate()

primitivesCache.get().resolvedPaths.add(ktFile.virtualFilePath)
}
Expand Down
Loading