Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Optimize counter bitwidth in Foreach control #293

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions apps/src/TestForeachCounter.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import spatial.dsl._

/** Regression app for counter bit-width handling in `Foreach`.
  *
  * Loads `N` zeros from DRAM into an SRAM, adds the loop index to each element
  * inside `Accel`, stores the SRAM back, and then checks on the host that
  * element `i` equals `i`. `N` is deliberately a power of two (128): that is
  * the boundary at which a counter narrowed by one bit too many can no longer
  * hold its own upper bound.
  */
@spatial object TestForeachCounter extends SpatialApp {

  def main(args: Array[String]): Unit = {
    // Loop upper bound; also the number of elements moved through DRAM/SRAM.
    val N = 128

    // Off-chip buffer shared between host and accelerator.
    val d = DRAM[Int](N)

    // Initial DRAM content: all zeros, so the expected output is just the index.
    val data = Array.fill[Int](N)(0)
    setMem(d, data)

    Accel {
      val s = SRAM[Int](N)

      s load d(0::N)

      Foreach(N by 1) { i => s(i) = s(i) + i }

      d(0::N) store s
    }

    // Host-side reference: element i of an all-zero array plus i is i.
    val gold = Array.tabulate[Int](N){ i => i }
    val result = getMem(d)

    printArray(gold, "Expected: ")
    printArray(result, "Result: ")

    // Fail the run instead of relying on someone eyeballing the printed array.
    val cksum = gold.zip(result){_ == _}.reduce{_ && _}
    println(r"PASS: $cksum (TestForeachCounter)")
    assert(cksum)
  }
}
3 changes: 3 additions & 0 deletions src/spatial/Spatial.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ trait Spatial extends Compiler with ParamLoader {
lazy val retiming = RetimingTransformer(state)
lazy val accumTransformer = AccumTransformer(state)
lazy val regReadCSE = RegReadCSE(state)
lazy val counterBitwidth = CounterBitwidthTransformer(state)

// --- Codegen
lazy val chiselCodegen = ChiselGen(state)
Expand Down Expand Up @@ -156,6 +157,8 @@ trait Spatial extends Compiler with ParamLoader {
/** Dead code elimination */
useAnalyzer ==>
transientCleanup ==> printer ==> transformerChecks ==>
/** Counter bitwidth optimization */
counterBitwidth ==> printer ==>
/** Stream controller rewrites */
(spatialConfig.distributeStreamCtr ? streamTransformer) ==> printer ==>
/** Memory analysis */
Expand Down
169 changes: 169 additions & 0 deletions src/spatial/transform/CounterBitwidthTransformer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package spatial.transform

import argon._
import argon.node._
import argon.transform.MutateTransformer

import spatial.lang._
import spatial.node._
import spatial.util.shouldMotionFromConditional
import spatial.traversal.AccelTraversal
import spatial.metadata.control._
import spatial.metadata.memory._
import spatial.metadata.blackbox._

import utils.math.log2Up

import emul.FixedPoint

/** Narrows the integer type of counters that drive `Foreach` controllers inside `Accel`.
  *
  * For every `CounterNew` whose start, stop and step are all fixed-point constants,
  * a replacement counter is staged with a signed integer type that is just wide
  * enough to hold those three constants. Counters that cannot be analysed
  * (non-constant bounds, node types other than `CounterNew`) are left untouched.
  *
  * NOTE(review): only the width of the constants themselves is considered. Whether
  * the generated hardware needs extra headroom (e.g. for `stop + step * par` before
  * the bound comparison) cannot be determined from this file -- confirm.
  */
case class CounterBitwidthTransformer(IR: State) extends MutateTransformer
  with AccelTraversal {

  /** Number of magnitude bits (sign bit excluded) needed to hold `x` in two's complement.
    *
    * Uses `BigInt.bitLength`, so 127 -> 7, 128 -> 8, -128 -> 7 and `Int.MinValue` -> 31.
    * The result is clamped to at least 1 so that 0 and -1 still map to a real type.
    */
  private def getBitwidth(x: Int): Int = math.max(1, scala.math.BigInt(x).bitLength)

  /** True when `x` is a constant whose payload is an emulated `FixedPoint` value. */
  private def isIntConst(x: Sym[_]): Boolean =
    x.isConst && x.c.exists(_.isInstanceOf[FixedPoint])

  /** Extract the content from Const and cast it to Int.
    *
    * Throws if `x` is not a constant; callers are expected to check `isIntConst` first.
    */
  private def constToInt(x: Sym[_]): Int =
    if (x.isConst)
      x.c.get.asInstanceOf[FixedPoint].toInt
    else
      throw new Exception(s"$x is not a Const.")

  /** Create a new CounterNew node with a compact bitwidth.
    *
    * Returns `ctr` itself (same instance) when nothing can be narrowed, which lets
    * callers detect the "unchanged" case with a reference comparison.
    */
  private def getOptimizedCounterNew(ctr: CounterNew[_]): CounterNew[_] = ctr match {
    case CounterNew(start, stop, step, par)
        if isIntConst(start) && isIntConst(stop) && isIntConst(step) =>
      val begin  = constToInt(start)
      val end    = constToInt(stop)
      val stride = constToInt(step)

      // All three constants are cast to the new type below, so the type must be
      // wide enough for the largest magnitude among them -- step included.
      val bits = math.max(getBitwidth(begin), math.max(getBitwidth(end), getBitwidth(stride)))

      // TODO: Find a better way that can map bitwidth to the exact Fix type
      // `bits` magnitude bits plus one sign bit gives a signed type of width bits + 1.
      bits match {
        case 1 =>
          type T = Fix[TRUE,_2 ,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 2 =>
          type T = Fix[TRUE,_3 ,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 3 =>
          type T = Fix[TRUE,_4 ,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 4 =>
          type T = Fix[TRUE,_5 ,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 5 =>
          type T = Fix[TRUE,_6 ,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 6 =>
          type T = Fix[TRUE,_7 ,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 7 =>
          type T = Fix[TRUE,_8 ,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 8 =>
          type T = Fix[TRUE,_9 ,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 9 =>
          type T = Fix[TRUE,_10,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 10 =>
          type T = Fix[TRUE,_11,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 11 =>
          type T = Fix[TRUE,_12,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 12 =>
          type T = Fix[TRUE,_13,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 13 =>
          type T = Fix[TRUE,_14,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 14 =>
          type T = Fix[TRUE,_15,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 15 =>
          type T = Fix[TRUE,_16,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 16 =>
          type T = Fix[TRUE,_17,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 17 =>
          type T = Fix[TRUE,_18,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 18 =>
          type T = Fix[TRUE,_19,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 19 =>
          type T = Fix[TRUE,_20,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 20 =>
          type T = Fix[TRUE,_21,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 21 =>
          type T = Fix[TRUE,_22,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 22 =>
          type T = Fix[TRUE,_23,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 23 =>
          type T = Fix[TRUE,_24,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 24 =>
          type T = Fix[TRUE,_25,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 25 =>
          type T = Fix[TRUE,_26,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 26 =>
          type T = Fix[TRUE,_27,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 27 =>
          type T = Fix[TRUE,_28,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 28 =>
          type T = Fix[TRUE,_29,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 29 =>
          type T = Fix[TRUE,_30,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 30 =>
          type T = Fix[TRUE,_31,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case 31 =>
          type T = Fix[TRUE,_32,_0]
          CounterNew[T](begin.to[T], end.to[T], stride.to[T], par)
        case _ =>
          // getBitwidth on an Int yields 1..31, so this is a defensive fallback:
          // keep the original counter rather than aborting compilation.
          ctr
      }

    // Bounds are not all integer constants: nothing can be proven, keep the counter.
    case _ => ctr
  }

  /** Optimize a list of Counter.
    *
    * Each counter defined by a `CounterNew` node is replaced by a freshly staged,
    * narrowed counter when narrowing is possible; every other counter (and every
    * `CounterNew` that could not be narrowed) is returned as the original symbol.
    */
  private def getOptimizedCounters(ctrs: Seq[Counter[_]]): Seq[Counter[_]] = {
    ctrs.map {
      case sym @ Op(ctr: CounterNew[_]) =>
        val narrowed = getOptimizedCounterNew(ctr)
        // Same instance back means "unchanged"; avoid staging a duplicate node.
        val result: Counter[_] = if (narrowed eq ctr) sym else stage(narrowed)
        result
      case other =>
        other
    }
  }

  /** Rewrites `Foreach` controllers inside `Accel` to use narrowed counters.
    *
    * `AccelScope` is traversed with the in-hardware flag set; a `Foreach` seen while
    * in hardware gets a new counter chain if at least one of its counters was
    * narrowed; everything else is delegated to the parent transformer.
    */
  override def transform[A:Type](lhs: Sym[A], rhs: Op[A])(implicit ctx: SrcCtx): Sym[A] = rhs match {
    case AccelScope(_) =>
      inAccel { super.transform(lhs, rhs) }

    case OpForeach(ens, cchain, blk, iters, stopWhen) if inHw =>
      val oldctrs = cchain.counters
      val newctrs = getOptimizedCounters(oldctrs)
      val unchanged = oldctrs.length == newctrs.length &&
        oldctrs.zip(newctrs).forall { case (a, b) => a eq b }

      if (unchanged) {
        // No counter was narrowed: let the default transformation handle the node.
        super.transform(lhs, rhs)
      }
      else {
        // The new chain inherits the metadata of the chain it replaces
        // (not that of the enclosing Foreach).
        val newcchain = stageWithFlow(CounterChainNew(newctrs)){ lhs2 => transferData(cchain, lhs2) }

        // NOTE(review): `iters` is passed through unchanged, so the bound iterators
        // keep their original type while the counters are narrower -- confirm that
        // downstream passes accept this.
        stageWithFlow(
          OpForeach(
            ens,
            newcchain,
            stageBlock{blk.stms.foreach(visit)},
            iters,
            stopWhen)
        ){ lhs2 => transferData(lhs, lhs2) }
      }

    case _ =>
      dbgs(s"visiting $lhs = $rhs")
      super.transform(lhs, rhs)
  }
}