Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/ir/air/gadgets/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func (p *LexicographicSortingGadget[F]) addLexicographicSelectorBits(deltaIndex
// Add binary constraints for selector bits
for i := uint(0); i < ncols; i++ {
ref := sc.NewRegisterRef(mid, sc.NewRegisterId(bitIndex+i))
// Add binarity constraints (i.e. to enfoce that this column is a bit).
// Add binarity constraints (i.e. to enforce that this column is a bit).
NewBitwidthGadget(schema).Constrain(ref, 1)
}
// Apply constraints to ensure at most one is set.
Expand Down Expand Up @@ -252,8 +252,8 @@ func (p *LexicographicSortingGadget[F]) addLexicographicSelectorBits(deltaIndex

// Construct the lexicographic delta constraint. This states that the delta
// column either holds 0 or the difference Ci[k] - Ci[k-1] (adjusted
// appropriately for the sign) between the ith column whose multiplexor bit is
// set. This is assumes that multiplexor bits are mutually exclusive (i.e. at
// appropriately for the sign) between the ith column whose multiplexer bit is
// set. This is assumes that multiplexer bits are mutually exclusive (i.e. at
// most is one).
func constructLexicographicDeltaConstraint[F field.Element[F]](deltaIndex sc.RegisterId, columns []sc.RegisterId,
signs []bool) air.Term[F] {
Expand Down
46 changes: 17 additions & 29 deletions pkg/ir/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ package gadgets
import (
"fmt"
"math"
"math/big"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/ir"
"github.com/consensys/go-corset/pkg/ir/air"
"github.com/consensys/go-corset/pkg/ir/assignment"
Expand All @@ -34,7 +32,7 @@ import (
// Normalise constructs an expression representing the normalised value of e.
// That is, an expression which is 0 when e is 0, and 1 when e is non-zero.
// This is done by introducing a computed column to hold the (pseudo)
// mutliplicative inverse of e.
// multiplicative inverse of e.
func Normalise[F field.Element[F]](e air.Term[F], module *air.ModuleBuilder[F]) air.Term[F] {
// Construct pseudo multiplicative inverse of e.
ie := applyPseudoInverseGadget(e, module)
Expand All @@ -50,22 +48,23 @@ func Normalise[F field.Element[F]](e air.Term[F], module *air.ModuleBuilder[F])
func applyPseudoInverseGadget[F field.Element[F]](e air.Term[F], module *air.ModuleBuilder[F]) air.Term[F] {
var (
// Construct inverse computation
ie = &psuedoInverse[F]{Expr: e}
ie = &pseudoInverse[F]{Expr: e}
// Determine computed column name
name = ie.Lisp(true, module).String(false)
// Look up column
index, ok = module.HasRegister(name)
// Default padding (for now)
padding big.Int = ir.PaddingFor(ie, module)
padding = ir.PaddingFor(ie, module)
)
// Add new column (if it does not already exist)
if !ok {
// Indicate column has "field element width".
var bitwidth uint = math.MaxUint
// Add computed register.
index = module.NewRegister(sc.NewComputedRegister(name, bitwidth, padding))
// Add assignment
module.AddAssignment(assignment.NewComputedRegister(sc.NewRegisterRef(module.Id(), index), ie, true))
target := sc.NewRegisterRef(module.Id(), index)
// Add inverse assignment
module.AddAssignment(assignment.NewPseudoInverse(target, e))
// Construct proof of 1/e
inv_e := ir.NewRegisterAccess[F, air.Term[F]](index, 0)
// Construct e/e
Expand All @@ -81,15 +80,16 @@ func applyPseudoInverseGadget[F field.Element[F]](e air.Term[F], module *air.Mod
return ir.NewRegisterAccess[F, air.Term[F]](index, 0)
}

// psuedoInverse represents a computation which computes the multiplicative
// inverse of a given expression.
type psuedoInverse[F field.Element[F]] struct {
// pseudoInverse represents a computation which computes the multiplicative
// inverse of a given expression. This is only needed now for the padding
// computation.
type pseudoInverse[F field.Element[F]] struct {
Expr air.Term[F]
}

// EvalAt computes the multiplicative inverse of a given expression at a given
// row in the table.
func (e *psuedoInverse[F]) EvalAt(k int, tr trace.Module[F], sc schema.Module[F]) (F, error) {
func (e *pseudoInverse[F]) EvalAt(k int, tr trace.Module[F], sc schema.Module[F]) (F, error) {
// Convert expression into something which can be evaluated, then evaluate
// it.
val, err := e.Expr.EvalAt(k, tr, sc)
Expand All @@ -99,51 +99,39 @@ func (e *psuedoInverse[F]) EvalAt(k int, tr trace.Module[F], sc schema.Module[F]
return inv, err
}

// AsConstant determines whether or not this is a constant expression. If
// so, the constant is returned; otherwise, nil is returned. NOTE: this
// does not perform any form of simplification to determine this.
func (e *psuedoInverse[F]) AsConstant() *fr.Element { return nil }

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (e *psuedoInverse[F]) Bounds() util.Bounds { return e.Expr.Bounds() }
func (e *pseudoInverse[F]) Bounds() util.Bounds { return e.Expr.Bounds() }

// RequiredRegisters returns the set of registers on which this term depends.
// That is, registers whose values may be accessed when evaluating this term on
// a given trace.
func (e *psuedoInverse[F]) RequiredRegisters() *set.SortedSet[uint] {
func (e *pseudoInverse[F]) RequiredRegisters() *set.SortedSet[uint] {
return e.Expr.RequiredRegisters()
}

// RequiredCells returns the set of trace cells on which this term depends.
// In this case, that is the empty set.
func (e *psuedoInverse[F]) RequiredCells(row int, mid trace.ModuleId) *set.AnySortedSet[trace.CellRef] {
func (e *pseudoInverse[F]) RequiredCells(row int, mid trace.ModuleId) *set.AnySortedSet[trace.CellRef] {
return e.Expr.RequiredCells(row, mid)
}

// IsDefined implementation for Evaluable interface.
func (e *psuedoInverse[F]) IsDefined() bool {
// NOTE: this is technically safe given the limited way that IsDefined is
// used for lookup selectors.
return true
}

// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.
func (e *psuedoInverse[F]) Lisp(global bool, mapping sc.RegisterMap) sexp.SExp {
func (e *pseudoInverse[F]) Lisp(global bool, mapping sc.RegisterMap) sexp.SExp {
return sexp.NewList([]sexp.SExp{
sexp.NewSymbol("inv"),
e.Expr.Lisp(global, mapping),
})
}

// Substitute implementation for Substitutable interface.
func (e *psuedoInverse[F]) Substitute(mapping map[string]F) {
func (e *pseudoInverse[F]) Substitute(mapping map[string]F) {
panic("unreachable")
}

// ValueRange implementation for Term interface.
func (e *psuedoInverse[F]) ValueRange(mapping schema.RegisterMap) util_math.Interval {
func (e *pseudoInverse[F]) ValueRange(mapping schema.RegisterMap) util_math.Interval {
// This could be managed by having a mechanism for representing infinity
// (e.g. nil). For now, this is never actually used, so we can just ignore
// it.
Expand Down
6 changes: 3 additions & 3 deletions pkg/ir/assignment/computed_register.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ func (p *ComputedRegister[F, E]) Substitute(mapping map[string]F) {
//nolint:revive
func (p *ComputedRegister[F, E]) Lisp(schema sc.AnySchema[F]) sexp.SExp {
var (
module = schema.Module(p.Target.Module())
target = module.Register(p.Target.Register())
datatype string = "𝔽"
module = schema.Module(p.Target.Module())
target = module.Register(p.Target.Register())
datatype = "𝔽"
)
//
if target.Width != math.MaxUint {
Expand Down
184 changes: 184 additions & 0 deletions pkg/ir/assignment/pseudo_inverse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// Copyright Consensys Software Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package assignment

import (
"fmt"
"math"

"github.com/consensys/go-corset/pkg/ir"
"github.com/consensys/go-corset/pkg/ir/air"
"github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
"github.com/consensys/go-corset/pkg/util/collection/array"
"github.com/consensys/go-corset/pkg/util/collection/set"
"github.com/consensys/go-corset/pkg/util/field"
"github.com/consensys/go-corset/pkg/util/source/sexp"
)

// PseudoInverse represents a computation which computes the multiplicative
// inverse of a given expression.
type PseudoInverse[F field.Element[F]] struct {
// Target index for computed column
Target schema.RegisterRef

Expr air.Term[F]
}

// NewPseudoInverse constructs a new pseudo-inverse assignment for the given
// target register and expression.
func NewPseudoInverse[F field.Element[F]](target schema.RegisterRef, expr air.Term[F]) *PseudoInverse[F] {
return &PseudoInverse[F]{
Target: target,
Expr: expr,
}
}

// Bounds determines the well-definedness bounds for this assignment.
// It is the same as that of the expression it is inverting.
func (e *PseudoInverse[F]) Bounds(mid schema.ModuleId) util.Bounds {
if mid == e.Target.Module() {
return e.Expr.Bounds()
}
// Not relevant
return util.EMPTY_BOUND
}

// Compute performs the inversion.
func (e *PseudoInverse[F]) Compute(tr trace.Trace[F], schema schema.AnySchema[F]) ([]array.MutArray[F], error) {
var (
trModule = tr.Module(e.Target.Module())
scModule = schema.Module(e.Target.Module())
err error
)
// Determine multiplied height
height := trModule.Height()
// FIXME: using a large bitwidth here ensures the underlying data is
// represented using a full field element, rather than e.g. some smaller
// number of bytes. This is needed to handle reject tests which can produce
// values outside the range of the computed register, but which we still
// want to check are actually rejected (i.e. since they are simulating what
// an attacker might do).
data := tr.Builder().NewArray(height, math.MaxUint)
// Expand the trace
err = invert(data, e.Expr, trModule, scModule)
// Sanity check
if err != nil {
return nil, err
}
// Done
return []array.MutArray[F]{data}, err
}

// Consistent performs some simple checks that the given assignment is
// consistent with its enclosing schema This provides a double check of certain
// key properties, such as that registers used for assignments are valid,
// etc.
func (e *PseudoInverse[F]) Consistent(schema.AnySchema[F]) []error {
return nil
}

// RegistersExpanded identifies registers expanded by this assignment.
func (e *PseudoInverse[F]) RegistersExpanded() []schema.RegisterRef {
return nil
}

// RegistersRead returns the set of columns that this assignment depends upon.
// That can include input columns, as well as other computed columns.
func (e *PseudoInverse[F]) RegistersRead() []schema.RegisterRef {
var (
module = e.Target.Module()
regs = e.Expr.RequiredRegisters()
rids = make([]schema.RegisterRef, regs.Iter().Count())
)
//
for i, iter := 0, regs.Iter(); iter.HasNext(); i++ {
rid := schema.NewRegisterId(iter.Next())
rids[i] = schema.NewRegisterRef(module, rid)
}
// Remove target to allow recursive definitions. Observe this does not
// guarantee they make sense!
return array.RemoveMatching(rids, func(r schema.RegisterRef) bool {
return r == e.Target
})
}

// RegistersWritten identifies registers assigned by this assignment.
func (e *PseudoInverse[F]) RegistersWritten() []schema.RegisterRef {
return []schema.RegisterRef{e.Target}
}

// Lisp converts this constraint into an S-Expression.
//
//nolint:revive
func (e *PseudoInverse[F]) Lisp(schema schema.AnySchema[F]) sexp.SExp {
var (
module = schema.Module(e.Target.Module())
target = module.Register(e.Target.Register())
datatype = "𝔽"
)
//
if target.Width != math.MaxUint {
datatype = fmt.Sprintf("u%d", target.Width)
}
//
return sexp.NewList(
[]sexp.SExp{sexp.NewSymbol("inv"),
sexp.NewList([]sexp.SExp{
sexp.NewSymbol(target.QualifiedName(module)),
sexp.NewSymbol(datatype)}),
e.Expr.Lisp(false, module),
})
}

// RequiredRegisters returns the set of registers on which this term depends.
// That is, registers whose values may be accessed when evaluating this term on
// a given trace.
func (e *PseudoInverse[F]) RequiredRegisters() *set.SortedSet[uint] {
return e.Expr.RequiredRegisters()
}

// RequiredCells returns the set of trace cells on which this term depends.
// In this case, that is the empty set.
func (e *PseudoInverse[F]) RequiredCells(row int, mid trace.ModuleId) *set.AnySortedSet[trace.CellRef] {
return e.Expr.RequiredCells(row, mid)
}

// Substitute implementation for Substitutable interface.
func (e *PseudoInverse[F]) Substitute(map[string]F) {
panic("unreachable")
}

func invert[F field.Element[F]](
data array.MutArray[F],
expr ir.Evaluable[F],
trMod trace.Module[F],
scMod schema.Module[F],
) error {
// Forwards computation
for i := range data.Len() {
val, err := expr.EvalAt(int(i), trMod, scMod)
// error check
if err != nil {
return err
}
//
data.Set(i, val)
}

field.BatchInvert(data)

//
return nil
}
2 changes: 1 addition & 1 deletion pkg/ir/mir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (p *AirLowering[F]) InitialiseModule(index uint) {
airModule.NewRegisters(mirModule.Registers()...)
}

// LowerModule lowers the given MIR module into the correspondind AIR module.
// LowerModule lowers the given MIR module into the corresponding AIR module.
// This includes all constraints and assignments.
func (p *AirLowering[F]) LowerModule(index uint) {
var (
Expand Down
6 changes: 3 additions & 3 deletions pkg/schema/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type RegisterMap interface {

// RegisterId captures the notion of a register index. That is, for each
// module, every register is allocated a given index starting from 0. The
// purpose of the wrapper is avoid confusion between uint values and things
// purpose of the wrapper is to avoid confusion between uint values and things
// which are expected to identify Columns.
type RegisterId = trace.ColumnId

Expand Down Expand Up @@ -103,7 +103,7 @@ type Register struct {
Name string
// Width (in bits) of this register
Width uint
// Determies what value will be used to padd this register.
// Determines what value will be used to padd this register.
Padding big.Int
}

Expand Down Expand Up @@ -222,7 +222,7 @@ func (p RegisterType) GobEncode() (data []byte, err error) {
gobEncoder = gob.NewEncoder(&buffer)
)
//
if err := gobEncoder.Encode(&p.kind); err != nil {
if err = gobEncoder.Encode(&p.kind); err != nil {
return nil, err
}
// Done
Expand Down
Loading
Loading