From a7056e86511ac0d62534f32124905bcd63e0d221 Mon Sep 17 00:00:00 2001 From: Michel Steuwer Date: Wed, 11 May 2022 16:34:11 +0100 Subject: [PATCH] Avoid anonymous objects, change unapply, remove unicode method names --- src/main/scala/rise/core/DSL/Type.scala | 2 +- src/main/scala/rise/core/DSL/package.scala | 82 ++---------- .../Compilation/AcceptorTranslation.scala | 90 ++++++------- .../Compilation/ContinuationTranslation.scala | 122 +++++++++--------- .../DPIA/Compilation/FedeTranslation.scala | 10 +- .../DPIA/Compilation/StreamTranslation.scala | 49 ++++--- src/main/scala/shine/DPIA/DSL/Core.scala | 45 ++----- .../shine/DPIA/DSL/ImperativePrimitives.scala | 23 ++-- src/main/scala/shine/DPIA/DSL/package.scala | 16 ++- src/main/scala/shine/DPIA/Lifting.scala | 4 +- .../scala/shine/DPIA/Phrases/Phrase.scala | 6 +- .../scala/shine/DPIA/Types/MatchingDSL.scala | 2 +- src/main/scala/shine/DPIA/fromRise.scala | 8 +- src/main/scala/shine/DPIA/package.scala | 42 +----- .../Passes/HoistMemoryAllocations.scala | 8 +- src/main/scala/shine/OpenCL/DSL/package.scala | 6 +- .../OpenMP/DSL/ImperativePrimitives.scala | 6 +- .../shine/OpenMP/TranslationContext.scala | 2 +- .../Passes/HoistMemoryAllocations.scala | 12 +- .../cuda/Compilation/TranslationContext.scala | 4 +- src/main/scala/util/monads.scala | 5 +- src/test/scala/shine/cuda/basic.scala | 10 +- 22 files changed, 222 insertions(+), 332 deletions(-) diff --git a/src/main/scala/rise/core/DSL/Type.scala b/src/main/scala/rise/core/DSL/Type.scala index 4a5161489..12357e669 100644 --- a/src/main/scala/rise/core/DSL/Type.scala +++ b/src/main/scala/rise/core/DSL/Type.scala @@ -187,7 +187,7 @@ object Type { object ->: { def unapply[T <: ExprType, U <: ExprType](funType: FunType[T, U]): Option[(T, U)] = { - FunType.unapply(funType) + Some((funType.inT, funType.outT)) } } diff --git a/src/main/scala/rise/core/DSL/package.scala b/src/main/scala/rise/core/DSL/package.scala index 4d63b4cbe..c15498140 100644 --- a/src/main/scala/rise/core/DSL/package.scala +++ b/src/main/scala/rise/core/DSL/package.scala @@ -37,13 +37,13 @@ package object DSL { def toMemFun(f: ToBeTyped[Expr]): ToBeTyped[Expr] = fun(x => toMem(f(x))) case class `if`(b: ToBeTyped[Expr]) { - def `then`(tE: ToBeTyped[Expr]): Object { - def `else` (eE: ToBeTyped[Expr] ): ToBeTyped[Expr] - } = { - new { - def `else`(eE: ToBeTyped[Expr]): ToBeTyped[Expr] = { - select(b)(tE)(eE) - } + def `then`(tE: ToBeTyped[Expr]): `if`.`then` = `if`.`then`(b, tE) + } + + object `if` { + case class `then`(b: ToBeTyped[Expr], tE: ToBeTyped[Expr]) { + def `else`(eE: ToBeTyped[Expr]): ToBeTyped[Expr] = { + select(b)(tE)(eE) } } } @@ -265,54 +265,8 @@ package object DSL { } // noinspection TypeAnnotation - // scalastyle:off structural.type - def apply(ft: FunType[ExprType, ExprType]): Object { - def apply(f: (ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier]) => ToBeTyped[Expr] - ): ToBeTyped[Expr] - - def apply(f: (ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier] - ) => ToBeTyped[Expr] - ): ToBeTyped[Expr] - - def apply(f: (ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier]) => ToBeTyped[Expr] - ): ToBeTyped[Expr] - - def apply(f: (ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier] - ) => ToBeTyped[Expr] - ): ToBeTyped[Expr] - - def apply(f: (ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier]) => ToBeTyped[Expr] - ): ToBeTyped[Expr] - - def apply(f: (ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier], ToBeTyped[Identifier] - ) => ToBeTyped[Expr] - ): ToBeTyped[Expr] - - def apply(f: (ToBeTyped[Identifier], ToBeTyped[Identifier], - ToBeTyped[Identifier]) => ToBeTyped[Expr] - ): ToBeTyped[Expr] - - def apply(f: (ToBeTyped[Identifier], ToBeTyped[Identifier] - ) => ToBeTyped[Expr] - ): ToBeTyped[Expr] - - def apply(f: ToBeTyped[Identifier] => ToBeTyped[Expr]): ToBeTyped[Expr] - } = new { + def apply(ft: FunType[ExprType, ExprType]): WithFunType = WithFunType(ft) + case class WithFunType(ft: FunType[ExprType, ExprType]) { def apply(f: ToBeTyped[Identifier] => ToBeTyped[Expr]): ToBeTyped[Expr] = fun(f) :: ft @@ -397,7 +351,6 @@ package object DSL { ) => ToBeTyped[Expr] ): ToBeTyped[Expr] = fun(f) :: ft } - // scalastyle:on structural.type } object depFun { @@ -471,19 +424,12 @@ package object DSL { } // noinspection ScalaUnusedSymbol - // scalastyle:off structural.type - object let { - def apply(e: ToBeTyped[Expr]): Object { - def be(in: ToBeTyped[Expr] => ToBeTyped[Expr]): ToBeTyped[Expr] - def be(in: ToBeTyped[Expr]): ToBeTyped[Expr] - } = new { - def be(in: ToBeTyped[Expr] => ToBeTyped[Expr]): ToBeTyped[Expr] = - primitives.let(e)(fun(in)) - def be(in: ToBeTyped[Expr]): ToBeTyped[Expr] = - primitives.let(e)(in) - } + case class let(e: ToBeTyped[Expr]) { + def be(in: ToBeTyped[Expr] => ToBeTyped[Expr]): ToBeTyped[Expr] = + primitives.let(e)(fun(in)) + def be(in: ToBeTyped[Expr]): ToBeTyped[Expr] = + primitives.let(e)(in) } - // scalastyle:on structural.type object letf { def apply(in: ToBeTyped[Expr] => ToBeTyped[Expr]): ToBeTyped[Expr] = { diff --git a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala index e1603b1ba..786022d9d 100644 --- a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala @@ -47,7 +47,7 @@ object AcceptorTranslation { acc(fst)(pairAcc1(dt1, dt2, A)) `;` acc(snd)(pairAcc2(dt1, dt2, A)) case _ => - con(e)(λ(e.t)(a => A :=| e.t.dataType | a)) + con(e)(fun(e.t)(a => A :=| e.t.dataType | a)) } case c: Literal => A :=|c.t.dataType| c @@ -57,13 +57,13 @@ object AcceptorTranslation { case n: Natural => A :=|n.t.dataType| n case u@UnaryOp(op, e) => - con(e)(λ(u.t)(x => + con(e)(fun(u.t)(x => A :=|u.t.dataType| UnaryOp(op, x) )) case b@BinOp(op, e1, e2) => - con(e1)(λ(b.t)(x => - con(e2)(λ(b.t)(y => + con(e1)(fun(b.t)(x => + con(e2)(fun(b.t)(y => A :=|b.t.dataType| BinOp(op, x, y) )) )) @@ -73,7 +73,7 @@ object AcceptorTranslation { case LetNat(binder, defn, body) => LetNat(binder, defn, acc(body)(A)) case IfThenElse(cond, thenP, elseP) => - con(cond)(λ(cond.t) { x => + con(cond)(fun(cond.t) { x => `if` (x) `then` acc(thenP)(A) `else` acc(elseP)(A) }) @@ -95,7 +95,7 @@ object AcceptorTranslation { acc(array)(AsVectorAcc(n, m, dt, A)) case DepIdx(n, ft, index, array) => - con(array)(λ(expT(n`.d`ft, read))(x => + con(array)(fun(expT(n`.d`ft, read))(x => A :=| ft(index) | DepIdx(n, ft, index, x))) case DepJoin(n, lenF, dt, array) => @@ -103,7 +103,7 @@ object AcceptorTranslation { case depMapSeq@DepMapSeq(unroll) => val (n, ft1, ft2, f, array) = depMapSeq.unwrap - con(array)(λ(expT(n`.d`ft1, read))(x => + con(array)(fun(expT(n`.d`ft1, read))(x => forNat(unroll, n, i => acc(f(i)(x `@d` i))(A `@d` i)) )) @@ -113,19 +113,19 @@ object AcceptorTranslation { case DMatch(x, elemT, outT, a, f, input) => // Turn the f imperative by means of forwarding the acceptor translation - con(input)(λ(expT(DepPairType(NatKind, x, elemT), read))(pair => + con(input)(fun(expT(DepPairType(NatKind, x, elemT), read))(pair => DMatchI(x, elemT, outT, - _Λ_(NatKind)((fst: NatIdentifier) => - λ(expT(substituteNatInType(fst, x, elemT), read))(snd => + depFun(NatKind)((fst: NatIdentifier) => + fun(expT(substituteNatInType(fst, x, elemT), read))(snd => acc(f(fst)(snd))(A) )), pair))) case IdxVec(n, st, index, vector) => - con(vector)(λ(expT(vec(n, st), read))(x => + con(vector)(fun(expT(vec(n, st), read))(x => A :=| st | IdxVec(n, st, index, x))) case Iterate(n, m, k, dt, f, array) => - con(array)(λ(expT((m * n.pow(k))`.`dt, read))(x => { + con(array)(fun(expT((m * n.pow(k))`.`dt, read))(x => { val sz = n.pow(k) * m newDoubleBuffer(sz`.`dt, m`.`dt, sz`.`dt, x, A, @@ -144,7 +144,7 @@ object AcceptorTranslation { })) case IterateStream(n, dt1, dt2, f, array) => - val fI = λ(expT(dt1, read))(x => λ(accT(dt2))(o => acc(f(x))(o))) + val fI = fun(expT(dt1, read))(x => fun(accT(dt2))(o => acc(f(x))(o))) val i = NatIdentifier(freshName("i")) str(array)(fun((i: NatIdentifier) ->: (expT(dt1, read) ->: (comm: CommType)) ->: (comm: CommType) @@ -178,7 +178,7 @@ object AcceptorTranslation { val o = Identifier(freshName("fede_o"), otype) acc(array)(MapAcc(n, dt2, dt1, - Lambda(o, fedAcc(scala.Predef.Map((x, o)))(f(x))(λ(otype)(x => x))), + Lambda(o, fedAcc(scala.Predef.Map((x, o)))(f(x))(fun(otype)(x => x))), A)) case MapFst(w, dt1, dt2, dt3, f, record) => @@ -193,7 +193,7 @@ object AcceptorTranslation { case mapSeq@MapSeq(unroll) => val (n, dt1, dt2, f, array) = mapSeq.unwrap - con(array)(λ(expT(n`.`dt1, read))(x => + con(array)(fun(expT(n`.`dt1, read))(x => comment("mapSeq")`;` `for`(unroll, n, i => acc(f(x `@` i))(A `@` i)) )) @@ -209,7 +209,7 @@ object AcceptorTranslation { A)) case MapVec(n, dt1, dt2, f, array) => - con(array)(λ(expT(vec(n, dt1), read))(x => + con(array)(fun(expT(vec(n, dt1), read))(x => shine.OpenMP.DSL.parForVec(n, dt2, A, i => a => acc(f(x `@v` i))(a)) )) @@ -221,15 +221,15 @@ object AcceptorTranslation { case reduceSeq@ReduceSeq(unroll) => val (n, dt1, dt2, f, init, array) = reduceSeq.unwrap - con(reduceSeq)(λ(expT(dt2, write))(r => + con(reduceSeq)(fun(expT(dt2, write))(r => acc(r)(A))) case Reorder(n, dt, access, idxF, idxFinv, input) => acc(input)(ReorderAcc(n, dt, idxFinv, A)) case ScanSeq(n, dt1, dt2, f, init, array) => - con(array)(λ(expT(n`.`dt1, read))(x => - con(init)(λ(expT(dt2, read))(y => + con(array)(fun(expT(n`.`dt1, read))(x => + con(init)(fun(expT(dt2, read))(y => comment("scanSeq")`;` `new`(dt2, accumulator => acc(y)(accumulator.wr) `;` @@ -243,7 +243,7 @@ object AcceptorTranslation { acc(input)(ScatterAcc(n, m, dt, y, A)))) case slide@Slide(n, sz, sp, dt, input) => - con(slide)(λ(expT(n`.`(sz`.`dt), read))(x => + con(slide)(fun(expT(n`.`(sz`.`dt), read))(x => A :=|(n`.`(sz`.`dt))| x )) case Split(n, m, w, dt, array) => @@ -256,7 +256,7 @@ object AcceptorTranslation { acc(e)(UnzipAcc(n, dt1, dt2, A)) case VectorFromScalar(n, dt, arg) => - con(arg)(λ(expT(dt, read))(e => + con(arg)(fun(expT(dt, read))(e => A :=|vec(n, dt)| VectorFromScalar(n, dt, e))) case Zip(n, dt1, dt2, access, e1, e2) => @@ -265,22 +265,22 @@ object AcceptorTranslation { // OpenMP case omp.DepMapPar(n, ft1, ft2, f, array) => - con(array)(λ(expT(n`.d`ft1, read))(x => { + con(array)(fun(expT(n`.d`ft1, read))(x => { shine.OpenMP.DSL.parForNat(n, ft2, A, idx => a => acc(f(idx)(x `@d` idx))(a)) })) case omp.MapPar(n, dt1, dt2, f, array) => - con(array)(λ(expT(n`.`dt1, read))(x => + con(array)(fun(expT(n`.`dt1, read))(x => shine.OpenMP.DSL.parFor(n, dt2, A, i => a => acc(f(x `@` i))(a)))) case reducePar@omp.ReducePar(n, dt1, dt2, f, init, array) => - con(reducePar)(λ(expT(dt2, write))(r => + con(reducePar)(fun(expT(dt2, write))(r => acc(r)(A))) // OpenCL case depMap@ocl.DepMap(level, dim) => val (n, ft1, ft2, f, array) = depMap.unwrap - con(array)(λ(expT(n`.d`ft1, read))(x => { + con(array)(fun(expT(n`.d`ft1, read))(x => { import shine.OpenCL.DSL._ level match { case shine.OpenCL.Global => @@ -294,7 +294,7 @@ object AcceptorTranslation { }})) case ocl.Iterate(a, n, m, k, dt, f, array) => - con(array)(λ(expT({m * n.pow(k)}`.`dt, read))(x => { + con(array)(fun(expT({m * n.pow(k)}`.`dt, read))(x => { import arithexpr.arithmetic.Cst val sz = n.pow(k) * m @@ -318,7 +318,7 @@ object AcceptorTranslation { case Nil => oclImp.KernelCallCmd(name, localSize, globalSize, n)(kc.inTs, kc.outT, kc.args, A) case Seq(arg, tail@_*) => - con(arg)(λ(expT(arg.t.dataType, read))(e => rec(tail, es :+ e))) + con(arg)(fun(expT(arg.t.dataType, read))(e => rec(tail, es :+ e))) } } @@ -326,10 +326,10 @@ object AcceptorTranslation { case map@ocl.Map(level, dim) => val (n, dt1, dt2, f, array) = map.unwrap - con(array)(λ(expT(n `.` dt1, read))(x => { + con(array)(fun(expT(n `.` dt1, read))(x => { comment(s"map${level.toString}") `;` shine.OpenCL.DSL.parFor(level, dim, unroll = false)(n, dt2, A, - λ(expT(idx(n), read))(i => λ(accT(dt2))(a => acc(f(x `@` i))(a)))) + fun(expT(idx(n), read))(i => fun(accT(dt2))(a => acc(f(x `@` i))(a)))) })) case fc@ocl.OpenCLFunctionCall(name, n) => @@ -339,11 +339,11 @@ object AcceptorTranslation { ts match { // with only one argument left to process return the assignment of the OpenCLFunction call case Seq( (arg, inT) ) => - con(arg)(λ(expT(inT, read))(e => + con(arg)(fun(expT(inT, read))(e => A :=|fc.outT| ocl.OpenCLFunctionCall(name, n)(inTs :+ inT, fc.outT, exps :+ e) )) // with a `tail` of arguments left, recurse case Seq( (arg, inT), tail@_* ) => - con(arg)(λ(expT(inT, read))(e => rec(tail, exps :+ e, inTs :+ inT) )) + con(arg)(fun(expT(inT, read))(e => rec(tail, exps :+ e, inTs :+ inT) )) } } @@ -351,23 +351,23 @@ object AcceptorTranslation { // CUDA case cuda.AsFragment(rows, columns, layers, dataType, fragmentKind, layout, matrix) => - con(matrix)(λ(ExpType(ArrayType(rows, ArrayType(columns, dataType)), read))(matrix => + con(matrix)(fun(ExpType(ArrayType(rows, ArrayType(columns, dataType)), read))(matrix => cudaImp.WmmaLoad(rows, columns, layers, dataType, fragmentKind, layout, matrix, A))) case cuda.AsMatrix(rows, columns, layers, dataType, fragment) => - con(fragment)(λ(ExpType(fragment.t.dataType, read))(fragment => + con(fragment)(fun(ExpType(fragment.t.dataType, read))(fragment => cudaImp.WmmaStore(rows, columns, layers, dataType, fragment, A))) case cuda.GenerateFragment(rows, columns, layers, dataType, frag, layout, fill) => - con(fill)(λ(ExpType(dataType, read))(fill => + con(fill)(fun(ExpType(dataType, read))(fill => cudaImp.WmmaFill(rows, columns, layers, dataType, frag, layout, fill, A))) case map@cuda.Map(level, dim) => val (n, dt1, dt2, f, array) = map.unwrap - con(array)(λ(expT(n `.` dt1, read))(x => { + con(array)(fun(expT(n `.` dt1, read))(x => { val forLoop = comment(s"map${level.toString}") `;` shine.cuda.DSL.parFor(level, dim, unroll = false)(n, dt2, A, - λ(expT(idx(n), read))(i => λ(accT(dt2))(a => + fun(expT(idx(n), read))(i => fun(accT(dt2))(a => acc(f(x `@` i))(a)))) //TODO use other InsertMemoryBarrieres-mechanism level match { @@ -378,17 +378,17 @@ object AcceptorTranslation { } })) - case cuda.MapFragment(rows, columns, layers, dt, frag, layout, fun, input) => - con(input)(λ(expT(FragmentType(rows, columns, layers, dt, frag, layout), read))(input => + case cuda.MapFragment(rows, columns, layers, dt, frag, layout, f, input) => + con(input)(fun(expT(FragmentType(rows, columns, layers, dt, frag, layout), read))(input => shine.cuda.primitives.imperative.ForFragment(rows, columns, layers, dt, frag, layout, input, A, - λ(expT(dt, read))(x => - λ(accT(dt))(o => - acc(fun(x))(o)))))) + fun(expT(dt, read))(x => + fun(accT(dt))(o => + acc(f(x))(o)))))) case cuda.TensorMatMultAdd(m, n, k, layoutA, layoutB, dataType, dataTypeAcc, aMatrix, bMatrix, cMatrix) => - con(aMatrix)(λ(ExpType(FragmentType(m, n, k, dataType, Fragment.AMatrix, layoutA), read))(aMatrix => - con(bMatrix)(λ(ExpType(FragmentType(m, n, k, dataType, Fragment.BMatrix, layoutB), read))(bMatrix => - con(cMatrix)(λ(ExpType(FragmentType(m, n, k, dataTypeAcc, Fragment.Accumulator, MatrixLayout.None), read))(cMatrix => + con(aMatrix)(fun(ExpType(FragmentType(m, n, k, dataType, Fragment.AMatrix, layoutA), read))(aMatrix => + con(bMatrix)(fun(ExpType(FragmentType(m, n, k, dataType, Fragment.BMatrix, layoutB), read))(bMatrix => + con(cMatrix)(fun(ExpType(FragmentType(m, n, k, dataTypeAcc, Fragment.Accumulator, MatrixLayout.None), read))(cMatrix => cudaImp.WmmaMMA(m, n, k, layoutA, layoutB, dataType, dataTypeAcc, aMatrix, bMatrix, cMatrix, A))))))) //GAP8 @@ -402,7 +402,7 @@ object AcceptorTranslation { case Nil => shine.GAP8.primitives.imperative.KernelCallCmd(name, cores, n)(kc.inTs, kc.outT, kc.args, A) case Seq(arg, tail@_*) => - con(arg)(λ(expT(arg.t.dataType, read))(e => rec(tail, es :+ e))) + con(arg)(fun(expT(arg.t.dataType, read))(e => rec(tail, es :+ e))) } rec(kc.args, Seq()) diff --git a/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala b/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala index db385a900..7f20236a6 100644 --- a/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/ContinuationTranslation.scala @@ -30,13 +30,13 @@ object ContinuationTranslation { case n: Natural => C(n) case u@UnaryOp(op, e) => - con(e)(λ(u.t)(x => + con(e)(fun(u.t)(x => C(UnaryOp(op, x)) )) case b@BinOp(op, e1, e2) => - con(e1)(λ(b.t)(x => - con(e2)(λ(b.t)(y => + con(e1)(fun(b.t)(x => + con(e2)(fun(b.t)(y => C(BinOp(op, x, y)) )) )) @@ -55,7 +55,7 @@ object ContinuationTranslation { } case IfThenElse(cond, thenP, elseP) => - con(cond)(λ(cond.t) { x => + con(cond)(fun(cond.t) { x => `if`(x) `then` con(thenP)(C) `else` con(elseP)(C) }) @@ -69,15 +69,15 @@ object ContinuationTranslation { (C: Phrase[ExpType ->: CommType]) (implicit context: TranslationContext): Phrase[CommType] = E match { case AsScalar(n, m, dt, access, array) => - con(array)(λ(array.t)(x => + con(array)(fun(array.t)(x => C(AsScalar(n, m, dt, access, x)))) case AsVector(n, m, dt, access, array) => - con(array)(λ(array.t)(x => + con(array)(fun(array.t)(x => C(AsVector(n, m, dt, access, x)))) case AsVectorAligned(n, m, w, dt, array) => - con(array)(λ(array.t)(x => + con(array)(fun(array.t)(x => C(AsVectorAligned(n, m, w, dt, x)) )) case Cast(dt1, dt2, e) => @@ -89,34 +89,34 @@ object ContinuationTranslation { C(Cycle(n, m, dt, x)))) case DepIdx(n, ft, index, array) => - con(array)(λ(expT(n`.d`ft, read))(e => + con(array)(fun(expT(n`.d`ft, read))(e => C(DepIdx(n, ft, index, e)))) case DepJoin(n, lenF, dt, array) => - con(array)(λ(expT(n `.d` { i => lenF(i)`.`dt }, read))(x => + con(array)(fun(expT(n `.d` { i => lenF(i)`.`dt }, read))(x => C(DepJoin(n, lenF, dt, x)))) case depMapSeq: DepMapSeq => val (n, _, ft2, _, _) = depMapSeq.unwrap - `new`(n`.d`ft2, λ(varT(n`.d`ft2))(tmp => + `new`(n`.d`ft2, fun(varT(n`.d`ft2))(tmp => acc(depMapSeq)(tmp.wr) `;` C(tmp.rd) )) case DepZip(n, ft1, ft2, e1, e2) => - con(e1)(λ(ExpType(DepArrayType(n, ft1), read))(x => - con(e2)(λ(ExpType(DepArrayType(n, ft2), read))(y => + con(e1)(fun(ExpType(DepArrayType(n, ft1), read))(x => + con(e2)(fun(ExpType(DepArrayType(n, ft2), read))(y => C(DepZip(n, ft1, ft2, x, y)) )) )) case DMatch(x, elemT, outT, a, f, input) => // Turn the f imperative by means of forwarding the continuation translation - con(input)(λ(expT(DepPairType(NatKind, x, elemT), read))(pair => + con(input)(fun(expT(DepPairType(NatKind, x, elemT), read))(pair => DMatchI(x, elemT, outT, - _Λ_(NatKind)((fst: NatIdentifier) => - λ(expT(substituteNatInType(fst, x, elemT), read))(snd => + depFun(NatKind)((fst: NatIdentifier) => + fun(expT(substituteNatInType(fst, x, elemT), read))(snd => con(f(fst)(snd))(C) )), pair))) case Drop(n, m, dt, array) => - con(array)(λ(expT((n + m)`.` dt, read))(x => + con(array)(fun(expT((n + m)`.` dt, read))(x => C(Drop(n, m, dt, x)))) case ffc@ForeignFunctionCall(funDecl, n) => @@ -126,7 +126,7 @@ object ContinuationTranslation { ts match { // with only one argument left to process return the assignment of the function call case Seq( (arg, inT) ) => - con(arg)(λ(expT(inT, read))(e => + con(arg)(fun(expT(inT, read))(e => ffc.outT match { // TODO: this is an ugly fix to avoid calling the function multiple times // for pair assignment, see: @@ -147,7 +147,7 @@ object ContinuationTranslation { })) // with a `tail` of arguments left, rec case Seq( (arg, inT), tail@_* ) => - con(arg)(λ(expT(inT, read))(e => + con(arg)(fun(expT(inT, read))(e => rec(tail, exps :+ e, inTs :+ inT) )) } } @@ -155,7 +155,7 @@ object ContinuationTranslation { rec(ffc.args zip ffc.inTs, Seq(), Seq()) case Fst(dt1, dt2, pair) => - con(pair)(λ(expT(dt1 x dt2, read))(x => + con(pair)(fun(expT(dt1 x dt2, read))(x => C(Fst(dt1, dt2, x)))) case Gather(n, m, dt, indices, input) => @@ -171,25 +171,25 @@ object ContinuationTranslation { con(f(i))(fun(expT(dt, read))(g => Apply(cont, g))))))) case Idx(n, dt, index, array) => - con(array)(λ(expT(n`.`dt, read))(e => + con(array)(fun(expT(n`.`dt, read))(e => con(index)(fun(index.t)(i => C(Idx(n, dt, i, e)))))) case idx@shine.OpenCL.primitives.imperative.IdxDistribute(level) => val (m, n, stride, dt, array) = idx.unwrap - con(array)(λ(expT(m`.`dt, read))(e => + con(array)(fun(expT(m`.`dt, read))(e => C(shine.OpenCL.primitives.imperative.IdxDistribute(level)(m, n, stride, dt, e)))) case IdxVec(n, st, index, vector) => - con(vector)(λ(expT(vec(n, st), read))(e => + con(vector)(fun(expT(vec(n, st), read))(e => C(IdxVec(n, st, index, e)))) case IndexAsNat(n, e) => - con(e)(λ(expT(idx(n), read))(x => + con(e)(fun(expT(idx(n), read))(x => C(IndexAsNat(n, x)))) case Join(n, m, w, dt, array) => - con(array)(λ(expT(n`.`(m`.`dt), read))(x => + con(array)(fun(expT(n`.`(m`.`dt), read))(x => C(Join(n, m, w, dt, x)))) case Let(dt1, dt2, access, value, f) => @@ -223,12 +223,12 @@ object ContinuationTranslation { }) case MakePair(dt1, dt2, access, fst, snd) => - con(fst)(λ(expT(dt1, read))(x => - con(snd)(λ(expT(dt2, read))(y => + con(fst)(fun(expT(dt1, read))(x => + con(snd)(fun(expT(dt2, read))(y => C(MakePair(dt1, dt2, access, x, y)))))) case Map(n, dt1, dt2, access, f, array) => - con(array)(λ(expT(n`.`dt1, read))(x => + con(array)(fun(expT(n`.`dt1, read))(x => C(MapRead(n, dt1, dt2, fun(expT(dt1, read))(a => fun(expT(dt2, read) ->: (comm: CommType))(cont => @@ -243,7 +243,7 @@ object ContinuationTranslation { val (n, _, dt2, _, _) = mapSeq.unwrap println("WARNING: map loop continuation translation allocates memory") // TODO should be removed - `new`(n`.`dt2, λ(varT(n`.`dt2))(tmp => + `new`(n`.`dt2, fun(varT(n`.`dt2))(tmp => acc(mapSeq)(tmp.wr) `;` C(tmp.rd) )) case MapSnd(w, dt1, dt2, dt3, f, record) => @@ -254,25 +254,25 @@ object ContinuationTranslation { println("WARNING: map loop continuation translation allocates memory") // TODO should be removed `new`(vec(n, dt2), - λ(varT(vec(n, dt2)))(tmp => + fun(varT(vec(n, dt2)))(tmp => acc(mapVec)(tmp.wr) `;` C(tmp.rd))) case NatAsIndex(n, e) => - con(e)(λ(e.t)(x => + con(e)(fun(e.t)(x => C(NatAsIndex(n, x)))) case PadCst(n, l, r, dt, padExp, array) => - con(array)(λ(expT(n`.`dt, read))(x => - con(padExp)(λ(expT(dt, read))(p => + con(array)(fun(expT(n`.`dt, read))(x => + con(padExp)(fun(expT(dt, read))(p => C(PadCst(n, l, r, dt, p, x)))))) case PadClamp(n, l, r, dt, array) => - con(array)(λ(expT(n`.`dt, read))(x => + con(array)(fun(expT(n`.`dt, read))(x => C(PadClamp(n, l, r, dt, x)))) case Partition(n, m, lenF, dt, array) => - con(array)(λ(expT(n`.`dt, read))(x => + con(array)(fun(expT(n`.`dt, read))(x => C(Partition(n, m, lenF, dt, x)))) case PrintType(msg, dt, access, input) => @@ -280,7 +280,7 @@ object ContinuationTranslation { case reduceSeq@ReduceSeq(unroll) => val (n, dt1, dt2, f, init, array) = reduceSeq.unwrap - con(array)(λ(expT(n`.`dt1, read))(X => { + con(array)(fun(expT(n`.`dt1, read))(X => { comment("reduceSeq") `;` `new`(dt2, accum => acc(init)(accum.wr) `;` @@ -290,28 +290,28 @@ object ContinuationTranslation { })) case Reorder(n, dt, access, idxF, idxFinv, input) => - con(input)(λ(expT(n`.`dt, read))(x => + con(input)(fun(expT(n`.`dt, read))(x => C(Reorder(n, dt, access, idxF, idxFinv, x)))) case scanSeq@ScanSeq(n, dt1, dt2, f, init, array) => - `new`(n`.`dt2, λ(varT(n`.`dt2))(tmp => + `new`(n`.`dt2, fun(varT(n`.`dt2))(tmp => acc(scanSeq)(tmp.wr) `;` C(tmp.rd) )) case slide@Slide(n, sz, sp, dt, input) => val inputSize = sp*n+sz - con(input)(λ(expT(inputSize`.`dt, read))(x => + con(input)(fun(expT(inputSize`.`dt, read))(x => C(Slide(n, sz, sp, dt, x)) )) case Snd(dt1, dt2, pair) => - con(pair)(λ(expT(dt1 x dt2, read))(x => + con(pair)(fun(expT(dt1 x dt2, read))(x => C(Snd(dt1, dt2, x)))) case Split(n, m, w, dt, array) => - con(array)(λ(expT((m * n)`.`dt, read))(x => + con(array)(fun(expT((m * n)`.`dt, read))(x => C(Split(n, m, w, dt, x)))) case Take(n, m, dt, array) => - con(array)(λ(expT((n + m)`.`dt, read))(x => + con(array)(fun(expT((n + m)`.`dt, read))(x => C(Take(n, m, dt, x)))) case ToMem(dt, input) => @@ -322,49 +322,49 @@ object ContinuationTranslation { C(Transpose(n, m, dt, access, x)))) case TransposeDepArray(n, m, f, array) => - con(array)(λ(expT(n`.`(m`.d`f), read))(x => + con(array)(fun(expT(n`.`(m`.d`f), read))(x => C(TransposeDepArray(n, m, f, x)))) case Unzip(n, dt1, dt2, access, e) => - con(e)(λ(expT(n`.`(dt1 x dt2), read))(x => + con(e)(fun(expT(n`.`(dt1 x dt2), read))(x => C(Unzip(n, dt1, dt2, access, x)))) case VectorFromScalar(n, dt, arg) => - con(arg)(λ(expT(dt, read))(e => + con(arg)(fun(expT(dt, read))(e => C(VectorFromScalar(n, dt, e)) )) case Zip(n, dt1, dt2, access, e1, e2) => - con(e1)(λ(expT(n`.`dt1, read))(x => - con(e2)(λ(expT(n`.`dt2, read))(y => + con(e1)(fun(expT(n`.`dt1, read))(x => + con(e2)(fun(expT(n`.`dt2, read))(y => C(Zip(n, dt1, dt2, access, x, y)) )) )) // OpenMP case depMapPar@omp.DepMapPar(n, ft1, ft2, f, array) => - `new`(n`.d`ft2, λ(varT(n`.d`ft2))(tmp => + `new`(n`.d`ft2, fun(varT(n`.d`ft2))(tmp => acc(depMapPar)(tmp.wr) `;` C(tmp.rd) )) case mapPar@omp.MapPar(n, dt1, dt2, f, array) => println("WARNING: map loop continuation translation allocates memory") // TODO should be removed - `new`(n`.`dt2, λ(varT(n`.`dt2))(tmp => + `new`(n`.`dt2, fun(varT(n`.`dt2))(tmp => acc(mapPar)(tmp.wr) `;` C(tmp.rd) )) case omp.ReducePar(n, dt1, dt2, f, init, array) => - con(array)(λ(expT(n`.`dt1, read))(X => - con(init)(λ(expT(dt2, read))(Y => + con(array)(fun(expT(n`.`dt1, read))(X => + con(init)(fun(expT(dt2, read))(Y => ??? )))) // OpenCL case depMap: ocl.DepMap => val (n, _, ft2, _, _) = depMap.unwrap - `new`(n`.d`ft2, λ(varT(n`.d`ft2))(tmp => + `new`(n`.d`ft2, fun(varT(n`.d`ft2))(tmp => acc(depMap)(tmp.wr) `;` C(tmp.rd) )) case map@ocl.Map(level, dim) => println("WARNING: map loop continuation translation allocates memory") // TODO should be removed - `new`(map.n `.` map.dt2, λ(varT(map.n `.` map.dt2))(tmp => + `new`(map.n `.` map.dt2, fun(varT(map.n `.` map.dt2))(tmp => acc(map)(tmp.wr) `;` C(tmp.rd))) case fc@ocl.OpenCLFunctionCall(name, n) => @@ -374,11 +374,11 @@ object ContinuationTranslation { ts match { // with only one argument left to process continue with the OpenCLFunction call case Seq( (arg, inT) ) => - con(arg)(λ(expT(inT, read))(e => + con(arg)(fun(expT(inT, read))(e => C(ocl.OpenCLFunctionCall(name, n)(inTs :+ inT, fc.outT, es :+ e)) )) // with a `tail` of arguments left, rec case Seq( (arg, inT), tail@_* ) => - con(arg)(λ(expT(inT, read))(e => rec(tail, es :+ e, inTs :+ inT) )) + con(arg)(fun(expT(inT, read))(e => rec(tail, es :+ e, inTs :+ inT) )) } } @@ -387,7 +387,7 @@ object ContinuationTranslation { case reduceSeq@ocl.ReduceSeq(unroll) => val (n, a, dt1, dt2, f, init, array) = reduceSeq.unwrap - con(array)(λ(expT(n`.`dt1, read))(X => { + con(array)(fun(expT(n`.`dt1, read))(X => { val adj = AdjustArraySizesForAllocations(init, dt2, a) comment("oclReduceSeq") `;` @@ -420,19 +420,19 @@ object ContinuationTranslation { println("WARNING: map loop continuation translation allocates memory") // TODO should be removed - `new`(n `.` dt2, λ(varT(n `.` dt2))(tmp => + `new`(n `.` dt2, fun(varT(n `.` dt2))(tmp => acc(map)(tmp.wr) `;` C(tmp.rd))) - case m@cuda.MapFragment(rows, columns, layers, dt, frag, layout, fun, input) => + case m@cuda.MapFragment(rows, columns, layers, dt, frag, layout, f, input) => val fragType = FragmentType(rows, columns, layers, dt, frag, layout) shine.OpenCL.primitives.imperative.New(AddressSpace.Private, fragType, - λ(VarType(fragType))(fragmentAcc => + fun(VarType(fragType))(fragmentAcc => (if (input.t.accessType.toString == write.toString) acc(input)(fragmentAcc.wr) `;` cudaIm.ForFragment(rows, columns, layers, dt, frag, layout, fragmentAcc.rd, fragmentAcc.wr, - λ(expT(dt, read))(x => - λ(accT(dt))(o => - acc(fun(x))(o)))) + fun(expT(dt, read))(x => + fun(accT(dt))(o => + acc(f(x))(o)))) else acc(m)(fragmentAcc.wr)) `;` C(fragmentAcc.rd))) diff --git a/src/main/scala/shine/DPIA/Compilation/FedeTranslation.scala b/src/main/scala/shine/DPIA/Compilation/FedeTranslation.scala index f24e605a3..c9e96593f 100644 --- a/src/main/scala/shine/DPIA/Compilation/FedeTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/FedeTranslation.scala @@ -54,7 +54,7 @@ object FedeTranslation { AsScalarAcc(n, m, dt, C(o)))) case Join(n, m, _, dt, array) => - fedAcc(env)(array)(λ(accT(C.t.inT.dataType))(o => + fedAcc(env)(array)(fun(accT(C.t.inT.dataType))(o => JoinAcc(n, m, dt, C(o)))) case Map(n, dt1, dt2, access, f, array) => @@ -63,10 +63,10 @@ object FedeTranslation { val otype = AccType(dt2) val o = Identifier(freshName("fede_o"), otype) - fedAcc(env)(array)(λ(env.toList.head._2.t)(y => + fedAcc(env)(array)(fun(env.toList.head._2.t)(y => MapAcc(n, dt2, dt1, Lambda(o, - fedAcc(Predef.Map((x, o)))(f(x))(λ(otype)(x => x))), C(y)))) + fedAcc(Predef.Map((x, o)))(f(x))(fun(otype)(x => x))), C(y)))) case MapFst(w, dt1, dt2, dt3, f, record) => val x = Identifier(freshName("fede_x"), ExpType(dt1, write)) @@ -95,11 +95,11 @@ object FedeTranslation { TakeAcc(n, r, dt, C(o)))) case Reorder(n, dt, _, _, idxFinv, input) => - fedAcc(env)(input)(λ(accT(C.t.inT.dataType))(o => + fedAcc(env)(input)(fun(accT(C.t.inT.dataType))(o => ReorderAcc(n, dt, idxFinv, C(o)))) case Split(n, m, _, dt, array) => - fedAcc(env)(array)(λ(accT(C.t.inT.dataType))(o => + fedAcc(env)(array)(fun(accT(C.t.inT.dataType))(o => SplitAcc(n, m, dt, C(o)))) case Transpose(n, m, dt, _, array) => diff --git a/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala b/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala index 93cfe90be..89984db0a 100644 --- a/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala @@ -65,7 +65,7 @@ object StreamTranslation { // prologue initialisation forNat(size - 1, i => streamNext(nextInput, i, fun(expT(dt1, read))(x => acc(load(x))(bufWr `@` i)))) `;` - C(nFun(i => + C(nFun(arithexpr.arithmetic.RangeAdd(0, n, 1))(i => fun(expT(size `.` dt2, read) ->: (comm: CommType))(k => // load next value streamNext(nextInput, i + size - 1, fun(expT(dt1, read))(x => @@ -75,8 +75,7 @@ object StreamTranslation { // use neighborhood k(Take(size, n - i - size, dt2, Drop(i, n - i, dt2, Cycle(n, size, dt2, bufRd)))) - ), - arithexpr.arithmetic.RangeAdd(0, n, 1) + ) )) }) })) @@ -86,12 +85,12 @@ object StreamTranslation { str(array)(fun((i: NatIdentifier) ->: (expT(dt1, read) ->: (comm: CommType)) ->: (comm: CommType) )(next => - C(nFun(i => + C(nFun(arithexpr.arithmetic.RangeAdd(0, n, 1))(i => fun(expT(dt2, read) ->: (comm: CommType))(k => streamNext(next, i, fun(expT(dt1, read))(x => con(f(x))(k) )) - ), arithexpr.arithmetic.RangeAdd(0, n, 1))))) + ))))) case RotateValues(n, size, dt, write, input) => val i = NatIdentifier(freshName("i")) @@ -103,7 +102,7 @@ object StreamTranslation { // prologue initialisation forNat(unroll = true, size - 1, i => streamNext(nextInput, i, fun(expT(dt, read))(x => acc(write(x))(rs.wr `@` i) ))) `;` - C(nFun(i => + C(nFun(arithexpr.arithmetic.RangeAdd(0, n, 1))(i => fun(expT(size `.` dt, read) ->: (comm: CommType))(k => // load next value streamNext(nextInput, i + size - 1, fun(expT(dt, read))(x => @@ -115,8 +114,8 @@ object StreamTranslation { comment("mapSeq") `;` `for`(unroll = true, size - 1, i => acc(write(Drop(1, size - 1, dt, rs.rd) `@` i))(TakeAcc(size - 1, 1, dt, rs.wr) `@` i)) - ), - arithexpr.arithmetic.RangeAdd(0, n, 1))) + ) + )) })) })) @@ -128,14 +127,14 @@ object StreamTranslation { str(e2)(fun((i: NatIdentifier) ->: (expT(dt2, read) ->: (comm: CommType)) ->: (comm: CommType) )(next2 => - C(nFun(i => fun(expT(dt1 x dt2, read) ->: (comm: CommType))(k => - Apply(DepApply(NatKind, next1, i), - fun(expT(dt1, read))(x1 => - Apply(DepApply(NatKind, next2, i), - fun(expT(dt2, read))(x2 => - k(MakePair(dt1, dt2, read, x1, x2)) - ))))), - arithexpr.arithmetic.RangeAdd(0, n, 1))))))) + C(nFun(arithexpr.arithmetic.RangeAdd(0, n, 1))(i => + fun(expT(dt1 x dt2, read) ->: (comm: CommType))(k => + Apply(DepApply(NatKind, next1, i), + fun(expT(dt1, read))(x1 => + Apply(DepApply(NatKind, next2, i), + fun(expT(dt2, read))(x2 => + k(MakePair(dt1, dt2, read, x1, x2)) + ))))))))))) // OpenCL case ocl.CircularBuffer(a, n, alloc, size, dt1, dt2, load, input) => @@ -166,7 +165,7 @@ object StreamTranslation { // prologue initialisation forNat(size - 1, i => streamNext(nextInput, i, fun(expT(dt1, read))(x => acc(load(x))(bufWr `@` i)))) `;` - C(nFun(i => + C(nFun(arithexpr.arithmetic.RangeAdd(0, n, 1))(i => fun(expT(size`.`dt2, read) ->: (comm: CommType))(k => // load next value streamNext(nextInput, i + size - 1, fun(expT(dt1, read))(x => @@ -176,8 +175,7 @@ object StreamTranslation { // use neighborhood k(Drop(i, size, dt2, Cycle(i + size, alloc, dt2, bufRd))) - ), - arithexpr.arithmetic.RangeAdd(0, n, 1) + ) )) }) })) @@ -192,7 +190,7 @@ object StreamTranslation { // prologue initialisation forNat(unroll = true, size - 1, i => streamNext(nextInput, i, fun(expT(dt, read))(x => acc(write(x))(rs.wr `@` i) ))) `;` - C(nFun(i => + C(nFun(arithexpr.arithmetic.RangeAdd(0, n, 1))(i => fun(expT(size`.`dt, read) ->: (comm: CommType))(k => // load next value streamNext(nextInput, i + size - 1, fun(expT(dt, read))(x => @@ -204,8 +202,8 @@ object StreamTranslation { comment("mapSeq")`;` `for`(unroll = true, size - 1, i => acc(write(Drop(1, size - 1, dt, rs.rd) `@` i))(TakeAcc(size - 1, 1, dt, rs.wr) `@` i)) - ), - arithexpr.arithmetic.RangeAdd(0, n, 1))) + ) + )) }) })) @@ -221,9 +219,10 @@ object StreamTranslation { E.t match { case ExpType(ArrayType(n, dt), read) => - C(nFun(i => fun(expT(dt, read) ->: (comm: CommType))(k => - con(E `@` i)(k) - ), arithexpr.arithmetic.RangeAdd(0, n, 1))) + C(nFun(arithexpr.arithmetic.RangeAdd(0, n, 1))(i => + fun(expT(dt, read) ->: (comm: CommType))(k => + con(E `@` i)(k) + ))) case _ => throw new Exception("this should not happen") } } diff --git a/src/main/scala/shine/DPIA/DSL/Core.scala b/src/main/scala/shine/DPIA/DSL/Core.scala index 0869268bd..88770be49 100644 --- a/src/main/scala/shine/DPIA/DSL/Core.scala +++ b/src/main/scala/shine/DPIA/DSL/Core.scala @@ -6,53 +6,26 @@ import shine.DPIA.Types._ import shine.DPIA._ object identifier { - def apply[T <: PhraseType](name: String, t: T) = Identifier(name, t) + def apply[T <: PhraseType](name: String, t: T): Identifier[T] = Identifier(name, t) } -trait funDef { - - def apply[T1 <: PhraseType, T2 <: PhraseType](t: T1) - (f: Identifier[T1] => Phrase[T2]): Lambda[T1, T2] = { +case class fun[T1 <: PhraseType, T2 <: PhraseType](t: T1) { + def apply(f: Identifier[T1] => Phrase[T2]): Lambda[T1, T2] = { val param = identifier(freshName("x"), t) Lambda(param, f(param)) } - } -object fun extends funDef - -object \ extends funDef - -object λ extends funDef - -object nFun { - def apply[T <: PhraseType](f: NatIdentifier => Phrase[T], - range: arithexpr.arithmetic.Range): DepLambda[Nat, NatIdentifier, T] = { +case class nFun(range: arithexpr.arithmetic.Range) { + def apply[T <: PhraseType](f: NatIdentifier => Phrase[T]): DepLambda[Nat, NatIdentifier, T] = { val x = NatIdentifier(freshName("n"), range) DepLambda(NatKind, x, f(x)) } } -trait depFunDef { - def apply[T, I](kind: Kind[T, I]): Object { - def apply[U <: PhraseType](f: I => Phrase[U]): DepLambda[T, I, U] - } = new { - def apply[U <: PhraseType](f: I => Phrase[U]): DepLambda[T, I, U] = { - val x = Kind.makeIdentifier(kind) - DepLambda(kind, x, f(x)) - } +case class depFun[T, I](kind: Kind[T, I]) { + def apply[U <: PhraseType](f: I => Phrase[U]): DepLambda[T, I, U] = { + val x = Kind.makeIdentifier(kind) + DepLambda(kind, x, f(x)) } } - -object depFun extends depFunDef -object _Λ_ extends depFunDef - -object π1 { - def apply[T1 <: PhraseType, T2 <: PhraseType](pair: Phrase[T1 x T2]) = - Proj1(pair) -} - -object π2 { - def apply[T1 <: PhraseType, T2 <: PhraseType](pair: Phrase[T1 x T2]) = - Proj2(pair) -} diff --git a/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala b/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala index 742c0f3d0..6099b4953 100644 --- a/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala +++ b/src/main/scala/shine/DPIA/DSL/ImperativePrimitives.scala @@ -20,7 +20,7 @@ object `new` { def apply(dt: DataType, f: Phrase[VarType] => Phrase[CommType]): New = - New(dt, λ(varT(dt))( v => f(v) )) + New(dt, fun(varT(dt))(v => f(v) )) } object newDoubleBuffer { @@ -30,7 +30,7 @@ object newDoubleBuffer { in: Phrase[ExpType], out: Phrase[AccType], f: (Phrase[VarType], Phrase[CommType], Phrase[CommType]) => Phrase[CommType]): NewDoubleBuffer = - NewDoubleBuffer(dt1, dt2, dt3.elemType, dt3.size, in, out, λ(varT(dt1) x CommType() x CommType())(ps => { + NewDoubleBuffer(dt1, dt2, dt3.elemType, dt3.size, in, out, fun(varT(dt1) x CommType() x CommType())(ps => { val v: Phrase[VarType] = ps._1._1 val swap: Phrase[CommType] = ps._1._2 val done: Phrase[CommType] = ps._2 @@ -45,11 +45,16 @@ object `if` { IfThenElse(cond, thenP, elseP) //noinspection TypeAnnotation - def apply(cond: Phrase[ExpType]) = new { - def `then`[T <: PhraseType](thenP: Phrase[T]) = new { - def `else`(elseP: Phrase[T]): IfThenElse[T] = { - IfThenElse(cond, thenP, elseP) - } + def apply(cond: Phrase[ExpType]): IfHelper = IfHelper(cond) + + case class IfHelper(cond: Phrase[ExpType]) { + def `then`[T <: PhraseType](thenP: Phrase[T]): ThenHelper[T] = + ThenHelper(cond, thenP) + } + + case class ThenHelper[T <: PhraseType](cond: Phrase[ExpType], thenP: Phrase[T]) { + def `else`(elseP: Phrase[T]): IfThenElse[T] = { + IfThenElse(cond, thenP, elseP) } } } @@ -57,14 +62,14 @@ object `if` { object `for` { def apply(n: Nat, f: Identifier[ExpType] => Phrase[CommType]): For = apply(false, n, f) def apply(unroll: Boolean, n: Nat, f: Identifier[ExpType] => Phrase[CommType]): For = - For(unroll)(n, λ(expT(idx(n), read))( i => f(i) )) + For(unroll)(n, fun(expT(idx(n), read))(i => f(i) )) } object forNat { def apply(n: Nat, f: NatIdentifier => Phrase[CommType]): ForNat = apply(false, n, f) def apply(unroll: Boolean, n: Nat, f: NatIdentifier => Phrase[CommType]): ForNat = { import arithexpr.arithmetic.RangeAdd - ForNat(unroll)(n, nFun(i => f(i), RangeAdd(0, n, 1))) + ForNat(unroll)(n, nFun(RangeAdd(0, n, 1))(i => f(i))) } } diff --git a/src/main/scala/shine/DPIA/DSL/package.scala b/src/main/scala/shine/DPIA/DSL/package.scala index 381a4980b..b384f8199 100644 --- a/src/main/scala/shine/DPIA/DSL/package.scala +++ b/src/main/scala/shine/DPIA/DSL/package.scala @@ -77,7 +77,11 @@ package object DSL { //noinspection TypeAnnotation implicit class AssignmentHelper(lhs: Phrase[AccType]) { - def :=|(dt: DataType) = new { + def :=|(dt: DataType): AssignmentHelper.SyntaxHelper = AssignmentHelper.SyntaxHelper(lhs, dt) + } + + object AssignmentHelper { + case class SyntaxHelper(lhs: Phrase[AccType], dt: DataType) { def |(rhs: Phrase[ExpType]) (implicit context: TranslationContext): Phrase[CommType] = { context.assign(dt, lhs, rhs) @@ -113,7 +117,7 @@ package object DSL { implicit class FunComp[T1 <: PhraseType, T2 <: PhraseType](f: Phrase[T1 ->: T2]) { def o[T3 <: PhraseType](g: Phrase[T3 ->: T1]): Phrase[T3 ->: T2] = { - λ(g.t.inT)(arg => f(g(arg))) + fun(g.t.inT)(arg => f(g(arg))) } } @@ -122,13 +126,13 @@ package object DSL { } implicit class VarExtensions(v: Phrase[VarType]) { - def rd: Proj1[ExpType, AccType] = π1(v) - def wr: Proj2[ExpType, AccType] = π2(v) + def rd: Proj1[ExpType, AccType] = Proj1(v) + def wr: Proj2[ExpType, AccType] = Proj2(v) } implicit class PairExtensions[T1 <: PhraseType, T2 <: PhraseType](v: Phrase[T1 x T2]) { - def _1: Proj1[T1, T2] = π1(v) - def _2: Proj2[T1, T2] = π2(v) + def _1: Proj1[T1, T2] = Proj1(v) + def _2: Proj2[T1, T2] = Proj2(v) } def mapTransientNat(natExpr: Phrase[ExpType], f: Nat => Nat): Phrase[ExpType] = { diff --git a/src/main/scala/shine/DPIA/Lifting.scala b/src/main/scala/shine/DPIA/Lifting.scala index 2f02892d6..b0bffa86e 100644 --- a/src/main/scala/shine/DPIA/Lifting.scala +++ b/src/main/scala/shine/DPIA/Lifting.scala @@ -36,7 +36,7 @@ object Lifting { p match { case l: Lambda[T1, T2] => - Reducing((arg: Phrase[T1]) => l.body `[` arg `/` l.param `]`) + Reducing((arg: Phrase[T1]) => Phrase.substitute(arg, `for` = l.param, in = l.body)) case app: Apply[_, T1 ->: T2] => chain(liftFunction(app.fun).map(lf => lf(app.arg))) case DepApply(_, f, arg) => @@ -56,7 +56,7 @@ object Lifting { def liftFunctionToNatLambda[T <: PhraseType](p: Phrase[ExpType ->: T]): Nat => Phrase[T] = { p match { case l: Lambda[ExpType, T] => - (arg: Nat) => l.body `[` arg `/` NatIdentifier(l.param.name) `]` + (arg: Nat) => Types.substitute(arg, `for` = NatIdentifier(l.param.name), in = l.body) case app: Apply[_, ExpType ->: T] => val fun = liftFunction(app.fun).reducing liftFunctionToNatLambda(fun(app.arg)) diff --git a/src/main/scala/shine/DPIA/Phrases/Phrase.scala b/src/main/scala/shine/DPIA/Phrases/Phrase.scala index 240f927cd..53bfdda69 100644 --- a/src/main/scala/shine/DPIA/Phrases/Phrase.scala +++ b/src/main/scala/shine/DPIA/Phrases/Phrase.scala @@ -47,9 +47,9 @@ final case class DepLambda[T, I, U <: PhraseType](kind: Kind[T, I], } object DepLambda { - def apply[T, I](kind: Kind[T, I], x: I): Object { - def apply[U <: PhraseType](body: Phrase[U]): DepLambda[T, I, U] - } = new { + def apply[T, I](kind: Kind[T, I], x: I): Helper[T, I] = Helper(kind, x) + + case class Helper[T, I](kind: Kind[T, I], x: I) { def apply[U <: PhraseType](body: Phrase[U]): DepLambda[T, I, U] = DepLambda(kind, x, body) } } diff --git a/src/main/scala/shine/DPIA/Types/MatchingDSL.scala b/src/main/scala/shine/DPIA/Types/MatchingDSL.scala index 234116c63..c2fdc51b1 100644 --- a/src/main/scala/shine/DPIA/Types/MatchingDSL.scala +++ b/src/main/scala/shine/DPIA/Types/MatchingDSL.scala @@ -5,7 +5,7 @@ object MatchingDSL { object ->: { def unapply[T <: PhraseType, R <: PhraseType](funType: FunType[T, R] ): Option[(T, R)] = { - FunType.unapply(funType) + Some((funType.inT, funType.outT)) } } diff --git a/src/main/scala/shine/DPIA/fromRise.scala b/src/main/scala/shine/DPIA/fromRise.scala index 2944bf635..1bc9ae792 100644 --- a/src/main/scala/shine/DPIA/fromRise.scala +++ b/src/main/scala/shine/DPIA/fromRise.scala @@ -87,12 +87,8 @@ object fromRise { Lambda(x, f(x)) } - object depFun { - def apply[T, I](kind: rt.Kind[T, I], x: I): Object { - def apply[U <: PhraseType](body: Phrase[U]): DepLambda[T, I, U] - } = new { - def apply[U <: PhraseType](body: Phrase[U]): DepLambda[T, I, U] = DepLambda(kind, x, body) - } + case class depFun[T, I](kind: rt.Kind[T, I], x: I) { + def apply[U <: PhraseType](body: Phrase[U]): DepLambda[T, I, U] = DepLambda(kind, x, body) } private def primitive(p: r.Primitive, diff --git a/src/main/scala/shine/DPIA/package.scala b/src/main/scala/shine/DPIA/package.scala index 96118c8b3..a3b5cce20 100644 --- a/src/main/scala/shine/DPIA/package.scala +++ b/src/main/scala/shine/DPIA/package.scala @@ -10,7 +10,7 @@ import rise.core.{types => rt} package object DPIA { def error(found: String, expected: String): Nothing = { - throw new TypeException(s"Found $found but expected $expected") + throw TypeException(s"Found $found but expected $expected") } def error(msg: String = "This should not happen"): Nothing = { @@ -48,42 +48,6 @@ package object DPIA { def apply(dt: DataType): PhrasePairType[ExpType, AccType] = ExpType(dt, read) x AccType(dt) } - //noinspection TypeAnnotation - implicit class PhraseTypeSubstitutionHelper[T <: PhraseType](t: PhraseType) { - def `[`(e: Nat) = new { - def `/`(a: NatIdentifier) = new { - def `]`: PhraseType = shine.DPIA.Types.substitute(e, `for`=a, in=t) - } - } - - def `[`(e: DataType) = new { - def `/`(a: DataType) = new { - def `]`: PhraseType = shine.DPIA.Types.substitute(e, `for`=a, in=t) - } - } - } - - //noinspection TypeAnnotation - implicit class PhraseSubstitutionHelper[T1 <: PhraseType](in: Phrase[T1]) { - def `[`[T2 <: PhraseType](p: Phrase[T2]) = new { - def `/`(`for`: Phrase[T2]) = new { - def `]`: Phrase[T1] = Phrase.substitute(p, `for`, in) - } - } - - def `[`(e: Nat) = new { - def `/`(`for`: NatIdentifier) = new { - def `]`: Phrase[T1] = shine.DPIA.Types.substitute(e, `for`, in) - } - } - - def `[`(dt: DataType) = new { - def `/`(`for`: DataTypeIdentifier) = new { - def `]`: Phrase[T1] = shine.DPIA.Types.substitute(dt, `for`, in) - } - } - } - implicit class PairTypeConstructor[T1 <: PhraseType](t1: T1) { @inline def x[T2 <: PhraseType](t2: T2): T1 x T2 = PhrasePairType(t1, t2) @@ -109,7 +73,7 @@ package object DPIA { def apply(dt: DataType, a: Access): ExpType = ExpType(dt, a) def unapply(et: ExpType): Option[(DataType, Access)] = { - ExpType.unapply(et) + Some((et.dataType, et.accessType)) } } @@ -117,7 +81,7 @@ package object DPIA { def apply(dt: DataType): AccType = AccType(dt) def unapply(at: AccType): Option[DataType] = { - AccType.unapply(at) + Some(at.dataType) } } diff --git a/src/main/scala/shine/OpenCL/Compilation/Passes/HoistMemoryAllocations.scala b/src/main/scala/shine/OpenCL/Compilation/Passes/HoistMemoryAllocations.scala index d123e97bb..d05dae14b 100644 --- a/src/main/scala/shine/OpenCL/Compilation/Passes/HoistMemoryAllocations.scala +++ b/src/main/scala/shine/OpenCL/Compilation/Passes/HoistMemoryAllocations.scala @@ -109,16 +109,16 @@ object HoistMemoryAllocations { case Left(identExpr) => Phrase.substitute( substitutionMap = Map( - π1(oldVariable) -> (π1(newVariable) `@` identExpr), - π2(oldVariable) -> (π2(newVariable) `@` identExpr) + Proj1(oldVariable) -> (Proj1(newVariable) `@` identExpr), + Proj2(oldVariable) -> (Proj2(newVariable) `@` identExpr) ), in = oldBody ) case Right(identNat) => Phrase.substitute( substitutionMap = Map( - π1(oldVariable) -> (π1(newVariable) `@` identNat), - π2(oldVariable) -> (π2(newVariable) `@` identNat) + Proj1(oldVariable) -> (Proj1(newVariable) `@` identNat), + Proj2(oldVariable) -> (Proj2(newVariable) `@` identNat) ), in = oldBody ) diff --git a/src/main/scala/shine/OpenCL/DSL/package.scala b/src/main/scala/shine/OpenCL/DSL/package.scala index 8b6de3fb8..8c1f041ea 100644 --- a/src/main/scala/shine/OpenCL/DSL/package.scala +++ b/src/main/scala/shine/OpenCL/DSL/package.scala @@ -42,7 +42,7 @@ package object DSL { private def parForBodyFunction(n:Nat, ft:NatToData, f:NatIdentifier => Phrase[AccType] => Phrase[CommType] ): DepLambda[Nat, NatIdentifier, AccType ->: CommType] = { - nFun(idx => λ(accT(ft(idx)))(o => f(idx)(o)), RangeAdd(0, n, 1)) + nFun(RangeAdd(0, n, 1))(idx => fun(accT(ft(idx)))(o => f(idx)(o))) } def parForNatGlobal(dim:Int)(n:Nat, ft:NatToData, out:Phrase[AccType], @@ -60,7 +60,7 @@ package object DSL { object `new` { def apply(addrSpace: AddressSpace) (dt: DataType, f: Phrase[VarType] => Phrase[CommType]): New = - New(addrSpace, dt, λ(varT(dt))(v => f(v) )) + New(addrSpace, dt, fun(varT(dt))(v => f(v) )) } object newDoubleBuffer { @@ -72,7 +72,7 @@ package object DSL { out: Phrase[AccType], f: (Phrase[VarType], Phrase[CommType], Phrase[CommType]) => Phrase[CommType]): NewDoubleBuffer = NewDoubleBuffer(a, dt1, dt2, dt3.elemType, dt3.size, in, out, - λ(varT(dt1) x CommType() x CommType())(ps => { + fun(varT(dt1) x CommType() x CommType())(ps => { val v: Phrase[VarType] = ps._1._1 val swap: Phrase[CommType] = ps._1._2 val done: Phrase[CommType] = ps._2 diff --git a/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala b/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala index f5e0646b2..60708db60 100644 --- a/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala +++ b/src/main/scala/shine/OpenMP/DSL/ImperativePrimitives.scala @@ -15,7 +15,7 @@ object parFor { dt: DataType, out: Phrase[AccType], f: Phrase[ExpType] => Phrase[AccType] => Phrase[CommType]): ParFor = - ParFor(n, dt, out, λ(expT(idx(n), read))(i => λ(accT(dt))(o => f(i)(o) ))) + ParFor(n, dt, out, fun(expT(idx(n), read))(i => fun(accT(dt))(o => f(i)(o) ))) } object `parForVec` { @@ -23,12 +23,12 @@ object `parForVec` { st: DataType, out: Phrase[AccType], f: Phrase[ExpType] => Phrase[AccType] => Phrase[CommType]): ForVec = - ForVec(n, st, out, λ(expT(idx(n), read))(i => λ(accT(st))(o => f(i)(o) ))) + ForVec(n, st, out, fun(expT(idx(n), read))(i => fun(accT(st))(o => f(i)(o) ))) } object parForNat { def apply(n: Nat, ft: NatToData, out: Phrase[AccType], f: NatIdentifier => Phrase[AccType] => Phrase[CommType]): ParForNat = { - ParForNat(n, ft, out, nFun(idx => λ(accT(ft(idx)))(o => f(idx)(o)), RangeAdd(0, n, 1))) + ParForNat(n, ft, out, nFun(RangeAdd(0, n, 1))(idx => fun(accT(ft(idx)))(o => f(idx)(o)))) } } diff --git a/src/main/scala/shine/OpenMP/TranslationContext.scala b/src/main/scala/shine/OpenMP/TranslationContext.scala index 4e2eca392..7bb280390 100644 --- a/src/main/scala/shine/OpenMP/TranslationContext.scala +++ b/src/main/scala/shine/OpenMP/TranslationContext.scala @@ -1,7 +1,7 @@ package shine.OpenMP import shine.C -import shine.DPIA.DSL.{ExpPhraseExtensions, λ} +import shine.DPIA.DSL.{ExpPhraseExtensions, fun} import shine.DPIA.Phrases.Phrase import shine.DPIA.Types.{AccType, CommType, ExpType} import rise.core.types.{DataType, read} diff --git a/src/main/scala/shine/cuda/Compilation/Passes/HoistMemoryAllocations.scala b/src/main/scala/shine/cuda/Compilation/Passes/HoistMemoryAllocations.scala index 48e7c2b71..3e7d43660 100644 --- a/src/main/scala/shine/cuda/Compilation/Passes/HoistMemoryAllocations.scala +++ b/src/main/scala/shine/cuda/Compilation/Passes/HoistMemoryAllocations.scala @@ -1,8 +1,8 @@ package shine.cuda.Compilation.Passes import arithexpr.arithmetic.ArithExpr.Math.Min -import shine.DPIA.DSL.{π1, π2, _} -import shine.DPIA.Phrases.{Identifier, Lambda, Phrase, VisitAndRebuild} +import shine.DPIA.DSL._ +import shine.DPIA.Phrases.{Identifier, Lambda, Phrase, Proj1, Proj2, VisitAndRebuild} import rise.core.types._ import rise.core.DSL.Type._ import shine.DPIA.Types.{CommType, ExpType, PhraseType} @@ -105,16 +105,16 @@ object HoistMemoryAllocations { case Left(identExpr) => Phrase.substitute( substitutionMap = Map( - π1(oldVariable) -> (π1(newVariable) `@` identExpr), - π2(oldVariable) -> (π2(newVariable) `@` identExpr) + Proj1(oldVariable) -> (Proj1(newVariable) `@` identExpr), + Proj2(oldVariable) -> (Proj2(newVariable) `@` identExpr) ), in = oldBody ) case Right(identNat) => Phrase.substitute( substitutionMap = Map( - π1(oldVariable) -> (π1(newVariable) `@` identNat), - π2(oldVariable) -> (π2(newVariable) `@` identNat) + Proj1(oldVariable) -> (Proj1(newVariable) `@` identNat), + Proj2(oldVariable) -> (Proj2(newVariable) `@` identNat) ), in = oldBody ) diff --git a/src/main/scala/shine/cuda/Compilation/TranslationContext.scala b/src/main/scala/shine/cuda/Compilation/TranslationContext.scala index bfb2da93c..9e15bfd10 100644 --- a/src/main/scala/shine/cuda/Compilation/TranslationContext.scala +++ b/src/main/scala/shine/cuda/Compilation/TranslationContext.scala @@ -16,8 +16,8 @@ class TranslationContext() extends shine.OpenCL.Compilation.TranslationContext { dt match { case FragmentType(rows, columns, layers, dt, frag, layout) => ForFragment(rows, columns, layers, dt, frag, layout, rhs, lhs, - λ(expT(dt, read))(x => - λ(accT(dt))(o => + fun(expT(dt, read))(x => + fun(accT(dt))(o => Assign(dt, o, x)))) case _ => diff --git a/src/main/scala/util/monads.scala b/src/main/scala/util/monads.scala index baa07dac4..286d9e4f5 100644 --- a/src/main/scala/util/monads.scala +++ b/src/main/scala/util/monads.scala @@ -12,7 +12,10 @@ object monads { bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) } - implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: Monad[M]) = new { + implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: Monad[M]): MonadicSyntax[M, A] = + new MonadicSyntax[M, A](m, tc) + + class MonadicSyntax[M[_], A](m: M[A], tc: Monad[M]) { def map[B](f: A => B): M[B] = tc.bind(m)(a => tc.return_(f(a)) ) def flatMap[B](f: A => M[B]): M[B] = tc.bind(m)(f) } diff --git a/src/test/scala/shine/cuda/basic.scala b/src/test/scala/shine/cuda/basic.scala index 768bfcc0f..d78412f36 100644 --- a/src/test/scala/shine/cuda/basic.scala +++ b/src/test/scala/shine/cuda/basic.scala @@ -2,7 +2,7 @@ package shine.cuda import rise.core.types.DataType._ import rise.core.types._ -import shine.DPIA.DSL.{depFun, λ} +import shine.DPIA.DSL.{depFun, fun} import shine.DPIA.Types.ExpType import shine.OpenCL.{Global, Local} import util.gen @@ -11,8 +11,8 @@ class basic extends test_util.Tests { test("id with mapThreads compiles to syntactically correct Cuda") { val mapId = depFun(NatKind)(n => - λ(ExpType(ArrayType(n, f32), read))(array => - shine.cuda.primitives.functional.Map(Local, 'x')(n, f32, f32, λ(ExpType(f32, read))(x => x), array)) + fun(ExpType(ArrayType(n, f32), read))(array => + shine.cuda.primitives.functional.Map(Local, 'x')(n, f32, f32, fun(ExpType(f32, read))(x => x), array)) ) val code = gen.cuda.kernel.asStringFromPhrase(mapId) @@ -21,8 +21,8 @@ class basic extends test_util.Tests { test("id with mapGlobal compiles to syntactically correct CUDA") { val mapId = depFun(NatKind)(n => - λ(ExpType(ArrayType(n, f32), read))(array => - shine.cuda.primitives.functional.Map(Global, 'x')(n, f32, f32, λ(ExpType(f32, read))(x => x), array)) + fun(ExpType(ArrayType(n, f32), read))(array => + shine.cuda.primitives.functional.Map(Global, 'x')(n, f32, f32, fun(ExpType(f32, read))(x => x), array)) ) val code = gen.cuda.kernel.asStringFromPhrase(mapId)