Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Optimize counter bitwidth in Foreach control #293

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions src/spatial/transform/CounterBitwidthTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package spatial.transform
import argon._
import argon.node._
import argon.transform.MutateTransformer

import spatial.lang._
import spatial.node._
import spatial.util.shouldMotionFromConditional
Expand All @@ -11,29 +12,70 @@ import spatial.metadata.control._
import spatial.metadata.memory._
import spatial.metadata.blackbox._

import utils.math.log2Up

import emul.FixedPoint

case class CounterBitwidthTransformer(IR: State) extends MutateTransformer
with AccelTraversal {

private def optimizeCtrs(ctrs: Seq[Sym[_]]): Seq[Sym[_]] = {
/** Calculate the least bitwidth required for an integer. */
private def getBitwidth(x: Int): Int = log2Up(x.abs)

/** Extract the content from Const and cast it to Int. */
private def constToInt(x: Sym[_]): Int =
if (x.isConst)
x.c.get.asInstanceOf[FixedPoint].toInt
else
throw new Exception(s"$x is not a Const.")

/** Create a new CounterNew object with compact bitwidth. */
private def getOptimizedCounterNew(ctr: CounterNew[_]): CounterNew[_] = ctr match {
case CounterNew(start, stop, step, par) =>
// we take the largest magnitude of start and stop to decide the boundary of bitwidth
val begin = constToInt(start)
val end = constToInt(stop)
val bitwidth = math.max(getBitwidth(begin), getBitwidth(end))

// TODO: Find a better way that can map bitwidth to the exact Fix type
if (bitwidth <= 7) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mattfel1 Hi there, I feel that it would be tedious to implement all the bitwidth-to-type cast and there should be a better way that I'm not aware of. Maybe you've met this scenario before and have a good way to deal with it? Thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, I don't think there is a non-tedious way to do this.
Since each bitwidth is its own trait (argon/lang/types/CustomBitWidths.scala), its painful to work with. Some people have used quasiquotes for this problem before but there isn't a nice way that I know of.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mattfel1 ! In my latest update I manually added all the mappings for different bit-width values. Hope it looks fine.

type T = Fix[TRUE,_8,_0]
CounterNew[T](begin.to[T], end.to[T], constToInt(step).to[T], par)
} else if (bitwidth <= 15) {
type T = Fix[TRUE,_16,_0]
CounterNew[T](begin.to[T], end.to[T], constToInt(step).to[T], par)
} else {
type T = Fix[TRUE,_32,_0]
CounterNew[T](begin.to[T], end.to[T], constToInt(step).to[T], par)
}
case _ => ctr
}

/** Optimize a list of Counter. */
private def getOptimizeCounters(ctrs: Seq[Counter[_]]): Seq[Counter[_]] = {
ctrs.map {
case Op(CounterNew(start, stop, step, par)) =>
println(start, stop, step, par)
case Op(ctr: CounterNew[_]) => stage(getOptimizeCounterNew(ctr))
}
ctrs
}

override def transform[A:Type](lhs: Sym[A], rhs: Op[A])(implicit ctx: SrcCtx): Sym[A] = rhs match {
case AccelScope(_) =>
inAccel { super.transform(lhs, rhs) }

case OpForeach(ens, cchain, block, iters, stopWhen) if inHw =>
println(ens, cchain, block, iters, stopWhen)
println(cchain.node)
optimizeCtrs(cchain.counters)
super.transform(lhs, rhs)
case OpForeach(ens, cchain, blk, iters, stopWhen) if inHw =>
val newctrs = getOptimizedCounters(cchain.counters)
val newcchain = stageWithFlow(CounterChainNew(newctrs)){ lhs2 => transferData(lhs, lhs2)}

stageWithFlow(
OpForeach(
ens,
newcchain,
stageBlock{blk.stms.foreach(visit)},
iters,
stopWhen)
){lhs2 => transferData(lhs, lhs2)}

case _ =>
// println(lhs, rhs)
dbgs(s"visiting $lhs = $rhs");
super.transform(lhs, rhs)
}
Expand Down