Skip to content

Commit

Permalink
[query] Push through requiredness on EmitCodes (#10445)
Browse files Browse the repository at this point in the history
* [query] Push through requiredness on EmitCodes

EmitCodes and their EmitTypes, not ptypes, are to be the authoritative
source of requiredness information in the code generator.

This is the next step toward removing top-level requiredness from ptypes.

* fix min/max requiredness

* fix nda req

* remove extra paren

* fix inferptype

* fix stream test
  • Loading branch information
tpoterba authored May 6, 2021
1 parent 51c3627 commit 40e8b6b
Show file tree
Hide file tree
Showing 17 changed files with 205 additions and 126 deletions.
130 changes: 85 additions & 45 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import is.hail.expr.ir.streams.{EmitStream, StreamProducer, StreamUtils}
import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer}
import is.hail.linalg.{BLAS, LAPACK, LinalgCodeUtils}
import is.hail.services.shuffler._
import is.hail.types.TypeWithRequiredness
import is.hail.types.physical._
import is.hail.types.physical.stypes.concrete.{SBaseStructPointerCode, SCanonicalShufflePointer, SCanonicalShufflePointerCode, SCanonicalShufflePointerSettable}
import is.hail.types.physical.stypes.interfaces.{SBaseStructCode, SNDArray, SNDArrayCode, SStreamCode}
Expand Down Expand Up @@ -225,7 +226,7 @@ object IEmitCode {
}

def apply[A](Lmissing: CodeLabel, Lpresent: CodeLabel, value: A, required: Boolean): IEmitCodeGen[A] =
IEmitCodeGen(Lmissing, Lpresent, value, false)
IEmitCodeGen(Lmissing, Lpresent, value, required)

def present[A](cb: EmitCodeBuilder, value: => A): IEmitCodeGen[A] = {
val Lpresent = CodeLabel()
Expand Down Expand Up @@ -375,10 +376,16 @@ case class IEmitCodeGen[+A](Lmissing: CodeLabel, Lpresent: CodeLabel, value: A,

object EmitCode {
def apply(setup: Code[Unit], m: Code[Boolean], pv: PCode): EmitCode = {
val mCC = Code(setup, m).toCCode
val iec = IEmitCode(new CodeLabel(mCC.Ltrue), new CodeLabel(mCC.Lfalse), pv, false)
val result = new EmitCode(new CodeLabel(mCC.entry), iec)
result
Code.constBoolValue(m) match {
case Some(false) =>
val Lpresent = CodeLabel()
new EmitCode(new CodeLabel(Code(setup, Lpresent.goto).start), IEmitCode(CodeLabel(), Lpresent, pv, required = true))
case _ =>
val mCC = Code(setup, m).toCCode
val iec = IEmitCode(new CodeLabel(mCC.Ltrue), new CodeLabel(mCC.Lfalse), pv, required = false)
val result = new EmitCode(new CodeLabel(mCC.entry), iec)
result
}
}

def unapply(ec: EmitCode): Option[(Code[Unit], Code[Boolean], PCode)] =
Expand Down Expand Up @@ -474,9 +481,9 @@ class RichIndexedSeqEmitSettable(is: IndexedSeq[EmitSettable]) {
}

object LoopRef {
def apply(cb: EmitCodeBuilder, L: CodeLabel, args: IndexedSeq[(String, PType)], pool: Value[RegionPool]): LoopRef = {
def apply(cb: EmitCodeBuilder, L: CodeLabel, args: IndexedSeq[(String, PType)], pool: Value[RegionPool], resultType: EmitType): LoopRef = {
val (loopArgs, tmpLoopArgs) = args.zipWithIndex.map { case ((name, pt), i) =>
(cb.emb.newEmitField(s"$name$i", pt), cb.emb.newEmitField(s"tmp$name$i", pt))
(cb.emb.newEmitField(s"$name$i", pt, pt.required), cb.emb.newEmitField(s"tmp$name$i", pt, pt.required))
}.unzip

val r1: Settable[Region] = cb.newLocal[Region]("loop_ref_r1")
Expand All @@ -485,17 +492,18 @@ object LoopRef {
val r2: Settable[Region] = cb.newLocal[Region]("loop_ref_r2")
cb.assign(r2, Region.stagedCreate(Region.REGULAR, pool))

LoopRef(L, args.map(_._2), loopArgs, tmpLoopArgs, r1, r2)
new LoopRef(L, args.map(_._2), loopArgs, tmpLoopArgs, r1, r2, resultType)
}
}

case class LoopRef(
L: CodeLabel,
loopTypes: IndexedSeq[PType],
loopArgs: IndexedSeq[EmitSettable],
tmpLoopArgs: IndexedSeq[EmitSettable],
r1: Settable[Region],
r2: Settable[Region])
class LoopRef(
val L: CodeLabel,
val loopTypes: IndexedSeq[PType],
val loopArgs: IndexedSeq[EmitSettable],
val tmpLoopArgs: IndexedSeq[EmitSettable],
val r1: Settable[Region],
val r2: Settable[Region],
val resultType: EmitType)

abstract class EstimableEmitter[C] {
def emit(mb: EmitMethodBuilder[C]): Code[Unit]
Expand Down Expand Up @@ -777,7 +785,7 @@ class Emit[C](

cb.goto(Lmissing)

IEmitCode(Lmissing, Ldefined, coalescedValue.load(), emittedValues.forall(_.required))
IEmitCode(Lmissing, Ldefined, coalescedValue.load(), emittedValues.exists(_.required))

case If(cond, cnsq, altr) =>
assert(cnsq.typ == altr.typ)
Expand Down Expand Up @@ -1762,8 +1770,8 @@ class Emit[C](
.flatMap(cb) { case (stream: SStreamCode) =>
val producer = stream.producer

val xAcc = mb.newEmitField(accumName, x.accPType) // in future, will choose compatible type for zero/body with requiredness
val xElt = mb.newEmitField(valueName, producer.element.pt)
val xAcc = mb.newEmitField(accumName, x.accPType, x.accPType.required) // in future, will choose compatible type for zero/body with requiredness
val xElt = mb.newEmitField(valueName, producer.element.emitType)

var tmpRegion: Settable[Region] = null

Expand Down Expand Up @@ -1812,9 +1820,9 @@ class Emit[C](

var tmpRegion: Settable[Region] = null

val xElt = mb.newEmitField(valueName, producer.element.pt)
val xElt = mb.newEmitField(valueName, producer.element.emitType)
val names = acc.map(_._1)
val accTypes = x.accPTypes
val accTypes = x.accPTypes.map(pt => EmitType(pt.sType, pt.required))
val accVars = (names, accTypes).zipped.map(mb.newEmitField)

val resEnv = env.bind(names.zip(accVars): _*)
Expand Down Expand Up @@ -1867,6 +1875,20 @@ class Emit[C](
emitI(res, env = resEnv)
}

case Die(m, typ, errorId) =>
val cm = emitI(m)
val msg = cb.newLocal[String]("die_msg")
cm.consume(cb,
cb.assign(msg, "<exception message missing>"),
{ sc => cb.assign(msg, sc.asString.loadString())})
cb._throw[HailException](Code.newInstance[HailException, String, Int](msg, errorId))

val t = PType.canonical(typ, true).deepInnerRequired(true)
IEmitCode.present(cb, t.defaultValue(cb.emb))

case CastToArray(a) =>
emitI(a).map(cb) { ind => ind.asIndexable.castToArray(cb) }.typecast[PCode]

case x@ShuffleWith(
keyFields,
rowType,
Expand Down Expand Up @@ -1954,7 +1976,9 @@ class Emit[C](

val stagedPool = cb.newLocal[RegionPool]("tail_loop_pool_ref")
cb.assign(stagedPool, region.getPool())
val loopRef = LoopRef(cb, loopStartLabel, inits.map { case ((name, _), pt) => (name, pt) }, stagedPool)

val resultEmitType = ctx.req.lookup(body).asInstanceOf[TypeWithRequiredness].canonicalEmitType(body.typ)
val loopRef = LoopRef(cb, loopStartLabel, inits.map { case ((name, _), pt) => (name, pt) }, stagedPool, resultEmitType)

val argEnv = env
.bind((args.map(_._1), loopRef.loopArgs).zipped.toArray: _*)
Expand All @@ -1968,12 +1992,14 @@ class Emit[C](

cb.define(loopStartLabel)

emitI(body, env = argEnv, loopEnv = Some(newLoopEnv.bind(name, loopRef))).map(cb) { pc =>
val result = emitI(body, env = argEnv, loopEnv = Some(newLoopEnv.bind(name, loopRef))).map(cb) { pc =>
val answerInRightRegion = pc.copyToRegion(cb, region)
cb.append(loopRef.r1.clearRegion())
cb.append(loopRef.r2.clearRegion())
answerInRightRegion
}
assert(result.emitType == resultEmitType, s"loop type mismatch: emitted=${ result.emitType }, expected=${ resultEmitType }")
result

case Recur(name, args, _) =>
val loopRef = loopEnv.get.lookup(name)
Expand All @@ -1998,7 +2024,9 @@ class Emit[C](
// after a goto.
val deadLabel = CodeLabel()
cb.define(deadLabel)
IEmitCode.missing(cb, pt.defaultValue(cb.emb))

val rt = loopRef.resultType
IEmitCode(CodeLabel(), CodeLabel(), rt.st.pType.defaultValue(mb), rt.required)

case x@CollectDistributedArray(contexts, globals, cname, gname, body, tsd) =>
val ctxsType = coerce[PStream](contexts.pType)
Expand Down Expand Up @@ -2125,6 +2153,16 @@ class Emit[C](
emitFallback(ir)
}

ctx.req.lookupOpt(ir) match {
case Some(r) =>
if (result.required != r.required) {
throw new RuntimeException(s"requiredness mismatch: EC=${ result.required } / Analysis=${ r.required }\n${ result.pt }\n${ Pretty(ir) }")
}

case _ =>
// Some IRs are generated dynamically during emission and have no entry in the requiredness analysis; skip the check for these.
}

if (result.pt != pt) {
if (!result.pt.equalModuloRequired(pt))
throw new RuntimeException(s"ptype mismatch:\n emitted: ${ result.pt }\n inferred: ${ ir.pType }\n ir: $ir")
Expand Down Expand Up @@ -2307,25 +2345,10 @@ class Emit[C](
sorter.toRegion(cb, x.pType)
})

case CastToArray(a) =>
val et = emit(a)
EmitCode(et.setup, et.m, PCode(pt, et.v))

case In(i, expectedPType) =>
// this, Code[Region], ...
val ev = mb.getEmitParam(2 + i, region)
ev
case Die(m, typ, errorId) =>
val cm = emit(m)
EmitCode(
Code(
cm.setup,
Code._throw[HailException, Unit](Code.newInstance[HailException, String, Int](
cm.m.mux[String](
"<exception message missing>",
coerce[String](StringFunctions.wrapArg(EmitRegion(mb, region), m.pType)(cm.v))), errorId))),
true,
pt.defaultValue(mb))

case ir@Apply(fn, typeArgs, args, rt) =>
val impl = ir.implementation
Expand Down Expand Up @@ -2378,6 +2401,16 @@ class Emit[C](
}
}

ctx.req.lookupOpt(ir) match {
case Some(r) =>
if (result.required != r.required) {
throw new RuntimeException(s"requiredness mismatch: EC=${ result.required } / Analysis=${ r.required }\n${ result.pt }\n${ Pretty(ir) }")
}

case _ =>
// Some IRs are generated dynamically during emission and have no entry in the requiredness analysis; skip the check for these.
}

if (result.pt != pt) {
if (!result.pt.equalModuloRequired(pt))
throw new RuntimeException(s"ptype mismatch:\n emitted: ${ result.pt }\n inferred: ${ ir.pType }\n ir: $ir")
Expand Down Expand Up @@ -2722,13 +2755,20 @@ class Emit[C](
cb._fatal("need at least one ndarray to concatenate")
})

val missing = cb.newLocal[Boolean]("ndarray_concat_result_missing")
cb.assign(missing, false)
// Need to check whether any of the ndarrays are missing.
val missingCheckLoopIdx = cb.newLocal[Int]("ndarray_concat_missing_check_idx")
cb.forLoop(cb.assign(missingCheckLoopIdx, 0), missingCheckLoopIdx < arrLength, cb.assign(missingCheckLoopIdx, missingCheckLoopIdx + 1),
cb.assign(missing, missing | ndsArrayPValue.isElementMissing(missingCheckLoopIdx))
)
val missing: Code[Boolean] = {
if (ndsArrayPValue.st.elementEmitType.required)
const(false)
else {
val missing = cb.newLocal[Boolean]("ndarray_concat_result_missing")
cb.assign(missing, false)
// Need to check whether any of the ndarrays are missing.
val missingCheckLoopIdx = cb.newLocal[Int]("ndarray_concat_missing_check_idx")
cb.forLoop(cb.assign(missingCheckLoopIdx, 0), missingCheckLoopIdx < arrLength, cb.assign(missingCheckLoopIdx, missingCheckLoopIdx + 1),
cb.assign(missing, missing | ndsArrayPValue.isElementMissing(missingCheckLoopIdx))
)
missing
}
}

IEmitCode(cb, missing, {
val loopIdx = cb.newLocal[Int]("ndarray_concat_shape_check_idx")
Expand Down
53 changes: 33 additions & 20 deletions hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import is.hail.expr.ir.orderings.CodeOrdering
import is.hail.io.fs.FS
import is.hail.io.{BufferSpec, InputBuffer, TypedCodecSpec}
import is.hail.lir
import is.hail.types.physical.stypes.SType
import is.hail.types.physical.stypes.{EmitType, SType}
import is.hail.types.physical.stypes.interfaces.PVoidCode.pt
import is.hail.types.physical.{PCanonicalTuple, PCode, PSettable, PStream, PType, PValue, typeToTypeInfo}
import is.hail.types.virtual.Type
import is.hail.utils._
Expand Down Expand Up @@ -89,11 +90,15 @@ trait WrappedEmitClassBuilder[C] extends WrappedEmitModuleBuilder {

def newPField(name: String, pt: PType): PSettable = ecb.newPField(name, pt)

def newEmitField(pt: PType): EmitSettable = ecb.newEmitField(pt)
def newEmitField(et: EmitType): EmitSettable = ecb.newEmitField(et.st.pType, et.required)

def newEmitField(name: String, pt: PType): EmitSettable = ecb.newEmitField(name, pt)
def newEmitField(pt: PType, required: Boolean): EmitSettable = ecb.newEmitField(pt, required)

def newEmitSettable(pt: PType, ms: Settable[Boolean], vs: PSettable): EmitSettable = ecb.newEmitSettable(pt, ms, vs)
def newEmitField(name: String, et: EmitType): EmitSettable = ecb.newEmitField(name, et.st.pType, et.required)

def newEmitField(name: String, pt: PType, required: Boolean): EmitSettable = ecb.newEmitField(name, pt, required)

def newEmitSettable(pt: PType, ms: Settable[Boolean], vs: PSettable, required: Boolean): EmitSettable = ecb.newEmitSettable(pt, ms, vs, required)

def newPresentEmitField(pt: PType): PresentEmitSettable = ecb.newPresentEmitField(pt)

Expand Down Expand Up @@ -226,29 +231,35 @@ class EmitClassBuilder[C](

def newPField(name: String, pt: PType): PSettable = newPSettable(fieldBuilder, pt, name)

def newEmitField(pt: PType): EmitSettable =
newEmitSettable(pt, genFieldThisRef[Boolean](), newPField(pt))
def newEmitField(pt: PType, required: Boolean): EmitSettable =
newEmitSettable(pt, genFieldThisRef[Boolean](), newPField(pt), required)

def newEmitField(name: String, pt: PType): EmitSettable =
newEmitSettable(pt, genFieldThisRef[Boolean](name + "_missing"), newPField(name, pt))
def newEmitField(name: String, emitType: EmitType): EmitSettable = newEmitField(name, emitType.st.pType, emitType.required)

def newEmitSettable(_pt: PType, ms: Settable[Boolean], vs: PSettable): EmitSettable = new EmitSettable {
def newEmitField(name: String, pt: PType, required: Boolean): EmitSettable =
newEmitSettable(pt, genFieldThisRef[Boolean](name + "_missing"), newPField(name, pt), required)

def newEmitSettable(_pt: PType, ms: Settable[Boolean], vs: PSettable, required: Boolean): EmitSettable = new EmitSettable {
if (!_pt.isRealizable) {
throw new UnsupportedOperationException(s"newEmitSettable can only be called on realizable PTypes. Called on ${_pt}")
}

def pt: PType = _pt

def load: EmitCode = EmitCode(Code._empty,
if (_pt.required) false else ms.get,
vs.get)
def load: EmitCode = {
val ec = EmitCode(Code._empty,
if (required) const(false) else ms.get,
vs.get)
assert(ec.required == required)
ec
}

def store(cb: EmitCodeBuilder, ec: EmitCode): Unit = {
store(cb, ec.toI(cb))
}

def store(cb: EmitCodeBuilder, iec: IEmitCode): Unit =
if (_pt.required)
if (required)
cb.assign(vs, iec.get(cb, s"Required EmitSettable cannot be missing ${ _pt }"))
else
iec.consume(cb, {
Expand All @@ -259,7 +270,7 @@ class EmitClassBuilder[C](
})

override def get(cb: EmitCodeBuilder): PCode = {
if (_pt.required) {
if (required) {
vs
} else {
cb.ifx(ms, cb._fatal(s"Can't convert missing ${_pt} to PValue"))
Expand Down Expand Up @@ -1002,11 +1013,13 @@ class EmitMethodBuilder[C](

def newPLocal(name: String, pt: PType): PSettable = newPSettable(localBuilder, pt, name)

def newEmitLocal(pt: PType): EmitSettable =
newEmitSettable(pt, if (pt.required) null else newLocal[Boolean](), newPLocal(pt))
def newEmitLocal(emitType: EmitType): EmitSettable = newEmitLocal(emitType.st.pType, emitType.required)
def newEmitLocal(pt: PType, required: Boolean): EmitSettable =
newEmitSettable(pt, if (required) null else newLocal[Boolean](), newPLocal(pt), required)

def newEmitLocal(name: String, pt: PType): EmitSettable =
newEmitSettable(pt, if (pt.required) null else newLocal[Boolean](name + "_missing"), newPLocal(name, pt))
def newEmitLocal(name: String, emitType: EmitType): EmitSettable = newEmitLocal(name, emitType.st.pType, emitType.required)
def newEmitLocal(name: String, pt: PType, required: Boolean): EmitSettable =
newEmitSettable(pt, if (required) null else newLocal[Boolean](name + "_missing"), newPLocal(name, pt), required)

def newPresentEmitLocal(pt: PType): PresentEmitSettable =
newPresentEmitSettable(newPLocal(pt))
Expand Down Expand Up @@ -1076,9 +1089,9 @@ trait WrappedEmitMethodBuilder[C] extends WrappedEmitClassBuilder[C] {

def newPLocal(name: String, pt: PType): PSettable = emb.newPLocal(name, pt)

def newEmitLocal(pt: PType): EmitSettable = emb.newEmitLocal(pt)
def newEmitLocal(pt: PType, required: Boolean): EmitSettable = emb.newEmitLocal(pt, required)

def newEmitLocal(name: String, pt: PType): EmitSettable = emb.newEmitLocal(name, pt)
def newEmitLocal(name: String, pt: PType, required: Boolean): EmitSettable = emb.newEmitLocal(name, pt, required)

def newPresentEmitLocal(pt: PType): PresentEmitSettable = emb.newPresentEmitLocal(pt)
}
Expand Down
Loading

0 comments on commit 40e8b6b

Please sign in to comment.