diff --git a/pkg/ir/air/gadgets/lexicographic_sort.go b/pkg/ir/air/gadgets/lexicographic_sort.go index a03f355ae..e9795268a 100644 --- a/pkg/ir/air/gadgets/lexicographic_sort.go +++ b/pkg/ir/air/gadgets/lexicographic_sort.go @@ -192,7 +192,7 @@ func (p *LexicographicSortingGadget[F]) addLexicographicSelectorBits(deltaIndex // Add binary constraints for selector bits for i := uint(0); i < ncols; i++ { ref := sc.NewRegisterRef(mid, sc.NewRegisterId(bitIndex+i)) - // Add binarity constraints (i.e. to enfoce that this column is a bit). + // Add binarity constraints (i.e. to enforce that this column is a bit). NewBitwidthGadget(schema).Constrain(ref, 1) } // Apply constraints to ensure at most one is set. @@ -252,8 +252,8 @@ func (p *LexicographicSortingGadget[F]) addLexicographicSelectorBits(deltaIndex // Construct the lexicographic delta constraint. This states that the delta // column either holds 0 or the difference Ci[k] - Ci[k-1] (adjusted -// appropriately for the sign) between the ith column whose multiplexor bit is -// set. This is assumes that multiplexor bits are mutually exclusive (i.e. at +// appropriately for the sign) between the ith column whose multiplexer bit is +// set. This is assumes that multiplexer bits are mutually exclusive (i.e. at // most is one). func constructLexicographicDeltaConstraint[F field.Element[F]](deltaIndex sc.RegisterId, columns []sc.RegisterId, signs []bool) air.Term[F] { diff --git a/pkg/ir/air/gadgets/normalisation.go b/pkg/ir/air/gadgets/normalisation.go index 9fe06b267..c2775e2a6 100644 --- a/pkg/ir/air/gadgets/normalisation.go +++ b/pkg/ir/air/gadgets/normalisation.go @@ -15,9 +15,7 @@ package gadgets import ( "fmt" "math" - "math/big" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/ir" "github.com/consensys/go-corset/pkg/ir/air" "github.com/consensys/go-corset/pkg/ir/assignment" @@ -34,7 +32,7 @@ import ( // Normalise constructs an expression representing the normalised value of e. // That is, an expression which is 0 when e is 0, and 1 when e is non-zero. // This is done by introducing a computed column to hold the (pseudo) -// mutliplicative inverse of e. +// multiplicative inverse of e. func Normalise[F field.Element[F]](e air.Term[F], module *air.ModuleBuilder[F]) air.Term[F] { // Construct pseudo multiplicative inverse of e. ie := applyPseudoInverseGadget(e, module) @@ -50,13 +48,13 @@ func Normalise[F field.Element[F]](e air.Term[F], module *air.ModuleBuilder[F]) func applyPseudoInverseGadget[F field.Element[F]](e air.Term[F], module *air.ModuleBuilder[F]) air.Term[F] { var ( // Construct inverse computation - ie = &psuedoInverse[F]{Expr: e} + ie = &pseudoInverse[F]{Expr: e} // Determine computed column name name = ie.Lisp(true, module).String(false) // Look up column index, ok = module.HasRegister(name) // Default padding (for now) - padding big.Int = ir.PaddingFor(ie, module) + padding = ir.PaddingFor(ie, module) ) // Add new column (if it does not already exist) if !ok { @@ -64,8 +62,9 @@ func applyPseudoInverseGadget[F field.Element[F]](e air.Term[F], module *air.Mod var bitwidth uint = math.MaxUint // Add computed register. index = module.NewRegister(sc.NewComputedRegister(name, bitwidth, padding)) - // Add assignment - module.AddAssignment(assignment.NewComputedRegister(sc.NewRegisterRef(module.Id(), index), ie, true)) + target := sc.NewRegisterRef(module.Id(), index) + // Add inverse assignment + module.AddAssignment(assignment.NewPseudoInverse(target, e)) // Construct proof of 1/e inv_e := ir.NewRegisterAccess[F, air.Term[F]](index, 0) // Construct e/e @@ -81,15 +80,16 @@ func applyPseudoInverseGadget[F field.Element[F]](e air.Term[F], module *air.Mod return ir.NewRegisterAccess[F, air.Term[F]](index, 0) } -// psuedoInverse represents a computation which computes the multiplicative -// inverse of a given expression. -type psuedoInverse[F field.Element[F]] struct { +// pseudoInverse represents a computation which computes the multiplicative +// inverse of a given expression. This is only needed now for the padding +// computation. +type pseudoInverse[F field.Element[F]] struct { Expr air.Term[F] } // EvalAt computes the multiplicative inverse of a given expression at a given // row in the table. -func (e *psuedoInverse[F]) EvalAt(k int, tr trace.Module[F], sc schema.Module[F]) (F, error) { +func (e *pseudoInverse[F]) EvalAt(k int, tr trace.Module[F], sc schema.Module[F]) (F, error) { // Convert expression into something which can be evaluated, then evaluate // it. val, err := e.Expr.EvalAt(k, tr, sc) @@ -99,38 +99,26 @@ func (e *psuedoInverse[F]) EvalAt(k int, tr trace.Module[F], sc schema.Module[F] return inv, err } -// AsConstant determines whether or not this is a constant expression. If -// so, the constant is returned; otherwise, nil is returned. NOTE: this -// does not perform any form of simplification to determine this. -func (e *psuedoInverse[F]) AsConstant() *fr.Element { return nil } - // Bounds returns max shift in either the negative (left) or positive // direction (right). -func (e *psuedoInverse[F]) Bounds() util.Bounds { return e.Expr.Bounds() } +func (e *pseudoInverse[F]) Bounds() util.Bounds { return e.Expr.Bounds() } // RequiredRegisters returns the set of registers on which this term depends. // That is, registers whose values may be accessed when evaluating this term on // a given trace. -func (e *psuedoInverse[F]) RequiredRegisters() *set.SortedSet[uint] { +func (e *pseudoInverse[F]) RequiredRegisters() *set.SortedSet[uint] { return e.Expr.RequiredRegisters() } // RequiredCells returns the set of trace cells on which this term depends. // In this case, that is the empty set. -func (e *psuedoInverse[F]) RequiredCells(row int, mid trace.ModuleId) *set.AnySortedSet[trace.CellRef] { +func (e *pseudoInverse[F]) RequiredCells(row int, mid trace.ModuleId) *set.AnySortedSet[trace.CellRef] { return e.Expr.RequiredCells(row, mid) } -// IsDefined implementation for Evaluable interface. -func (e *psuedoInverse[F]) IsDefined() bool { - // NOTE: this is technically safe given the limited way that IsDefined is - // used for lookup selectors. - return true -} - // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. -func (e *psuedoInverse[F]) Lisp(global bool, mapping sc.RegisterMap) sexp.SExp { +func (e *pseudoInverse[F]) Lisp(global bool, mapping sc.RegisterMap) sexp.SExp { return sexp.NewList([]sexp.SExp{ sexp.NewSymbol("inv"), e.Expr.Lisp(global, mapping), @@ -138,12 +126,12 @@ func (e *psuedoInverse[F]) Lisp(global bool, mapping sc.RegisterMap) sexp.SExp { } // Substitute implementation for Substitutable interface. -func (e *psuedoInverse[F]) Substitute(mapping map[string]F) { +func (e *pseudoInverse[F]) Substitute(mapping map[string]F) { panic("unreachable") } // ValueRange implementation for Term interface. -func (e *psuedoInverse[F]) ValueRange(mapping schema.RegisterMap) util_math.Interval { +func (e *pseudoInverse[F]) ValueRange(mapping schema.RegisterMap) util_math.Interval { // This could be managed by having a mechanism for representing infinity // (e.g. nil). For now, this is never actually used, so we can just ignore // it. diff --git a/pkg/ir/assignment/computed_register.go b/pkg/ir/assignment/computed_register.go index 3e6343ac8..0108fe2b5 100644 --- a/pkg/ir/assignment/computed_register.go +++ b/pkg/ir/assignment/computed_register.go @@ -175,9 +175,9 @@ func (p *ComputedRegister[F, E]) Substitute(mapping map[string]F) { //nolint:revive func (p *ComputedRegister[F, E]) Lisp(schema sc.AnySchema[F]) sexp.SExp { var ( - module = schema.Module(p.Target.Module()) - target = module.Register(p.Target.Register()) - datatype string = "𝔽" + module = schema.Module(p.Target.Module()) + target = module.Register(p.Target.Register()) + datatype = "𝔽" ) // if target.Width != math.MaxUint { diff --git a/pkg/ir/assignment/pseudo_inverse.go b/pkg/ir/assignment/pseudo_inverse.go new file mode 100644 index 000000000..28c6f22a2 --- /dev/null +++ b/pkg/ir/assignment/pseudo_inverse.go @@ -0,0 +1,184 @@ +// Copyright Consensys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package assignment + +import ( + "fmt" + "math" + + "github.com/consensys/go-corset/pkg/ir" + "github.com/consensys/go-corset/pkg/ir/air" + "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/trace" + "github.com/consensys/go-corset/pkg/util" + "github.com/consensys/go-corset/pkg/util/collection/array" + "github.com/consensys/go-corset/pkg/util/collection/set" + "github.com/consensys/go-corset/pkg/util/field" + "github.com/consensys/go-corset/pkg/util/source/sexp" +) + +// PseudoInverse represents a computation which computes the multiplicative +// inverse of a given expression. +type PseudoInverse[F field.Element[F]] struct { + // Target index for computed column + Target schema.RegisterRef + + Expr air.Term[F] +} + +// NewPseudoInverse constructs a new pseudo-inverse assignment for the given +// target register and expression. +func NewPseudoInverse[F field.Element[F]](target schema.RegisterRef, expr air.Term[F]) *PseudoInverse[F] { + return &PseudoInverse[F]{ + Target: target, + Expr: expr, + } +} + +// Bounds determines the well-definedness bounds for this assignment. +// It is the same as that of the expression it is inverting. +func (e *PseudoInverse[F]) Bounds(mid schema.ModuleId) util.Bounds { + if mid == e.Target.Module() { + return e.Expr.Bounds() + } + // Not relevant + return util.EMPTY_BOUND +} + +// Compute performs the inversion. +func (e *PseudoInverse[F]) Compute(tr trace.Trace[F], schema schema.AnySchema[F]) ([]array.MutArray[F], error) { + var ( + trModule = tr.Module(e.Target.Module()) + scModule = schema.Module(e.Target.Module()) + err error + ) + // Determine multiplied height + height := trModule.Height() + // FIXME: using a large bitwidth here ensures the underlying data is + // represented using a full field element, rather than e.g. some smaller + // number of bytes. This is needed to handle reject tests which can produce + // values outside the range of the computed register, but which we still + // want to check are actually rejected (i.e. since they are simulating what + // an attacker might do). + data := tr.Builder().NewArray(height, math.MaxUint) + // Expand the trace + err = invert(data, e.Expr, trModule, scModule) + // Sanity check + if err != nil { + return nil, err + } + // Done + return []array.MutArray[F]{data}, err +} + +// Consistent performs some simple checks that the given assignment is +// consistent with its enclosing schema This provides a double check of certain +// key properties, such as that registers used for assignments are valid, +// etc. +func (e *PseudoInverse[F]) Consistent(schema.AnySchema[F]) []error { + return nil +} + +// RegistersExpanded identifies registers expanded by this assignment. +func (e *PseudoInverse[F]) RegistersExpanded() []schema.RegisterRef { + return nil +} + +// RegistersRead returns the set of columns that this assignment depends upon. +// That can include input columns, as well as other computed columns. +func (e *PseudoInverse[F]) RegistersRead() []schema.RegisterRef { + var ( + module = e.Target.Module() + regs = e.Expr.RequiredRegisters() + rids = make([]schema.RegisterRef, regs.Iter().Count()) + ) + // + for i, iter := 0, regs.Iter(); iter.HasNext(); i++ { + rid := schema.NewRegisterId(iter.Next()) + rids[i] = schema.NewRegisterRef(module, rid) + } + // Remove target to allow recursive definitions. Observe this does not + // guarantee they make sense! + return array.RemoveMatching(rids, func(r schema.RegisterRef) bool { + return r == e.Target + }) +} + +// RegistersWritten identifies registers assigned by this assignment. +func (e *PseudoInverse[F]) RegistersWritten() []schema.RegisterRef { + return []schema.RegisterRef{e.Target} +} + +// Lisp converts this constraint into an S-Expression. +// +//nolint:revive +func (e *PseudoInverse[F]) Lisp(schema schema.AnySchema[F]) sexp.SExp { + var ( + module = schema.Module(e.Target.Module()) + target = module.Register(e.Target.Register()) + datatype = "𝔽" + ) + // + if target.Width != math.MaxUint { + datatype = fmt.Sprintf("u%d", target.Width) + } + // + return sexp.NewList( + []sexp.SExp{sexp.NewSymbol("inv"), + sexp.NewList([]sexp.SExp{ + sexp.NewSymbol(target.QualifiedName(module)), + sexp.NewSymbol(datatype)}), + e.Expr.Lisp(false, module), + }) +} + +// RequiredRegisters returns the set of registers on which this term depends. +// That is, registers whose values may be accessed when evaluating this term on +// a given trace. +func (e *PseudoInverse[F]) RequiredRegisters() *set.SortedSet[uint] { + return e.Expr.RequiredRegisters() +} + +// RequiredCells returns the set of trace cells on which this term depends. +// In this case, that is the empty set. +func (e *PseudoInverse[F]) RequiredCells(row int, mid trace.ModuleId) *set.AnySortedSet[trace.CellRef] { + return e.Expr.RequiredCells(row, mid) +} + +// Substitute implementation for Substitutable interface. +func (e *PseudoInverse[F]) Substitute(map[string]F) { + panic("unreachable") +} + +func invert[F field.Element[F]]( + data array.MutArray[F], + expr ir.Evaluable[F], + trMod trace.Module[F], + scMod schema.Module[F], +) error { + // Forwards computation + for i := range data.Len() { + val, err := expr.EvalAt(int(i), trMod, scMod) + // error check + if err != nil { + return err + } + // + data.Set(i, val) + } + + field.BatchInvert(data) + + // + return nil +} diff --git a/pkg/ir/mir/lower.go b/pkg/ir/mir/lower.go index 5d05dab23..9a0d5bd29 100644 --- a/pkg/ir/mir/lower.go +++ b/pkg/ir/mir/lower.go @@ -102,7 +102,7 @@ func (p *AirLowering[F]) InitialiseModule(index uint) { airModule.NewRegisters(mirModule.Registers()...) } -// LowerModule lowers the given MIR module into the correspondind AIR module. +// LowerModule lowers the given MIR module into the corresponding AIR module. // This includes all constraints and assignments. func (p *AirLowering[F]) LowerModule(index uint) { var ( diff --git a/pkg/schema/register.go b/pkg/schema/register.go index 4f0d67ae2..46bb175f3 100644 --- a/pkg/schema/register.go +++ b/pkg/schema/register.go @@ -40,7 +40,7 @@ type RegisterMap interface { // RegisterId captures the notion of a register index. That is, for each // module, every register is allocated a given index starting from 0. The -// purpose of the wrapper is avoid confusion between uint values and things +// purpose of the wrapper is to avoid confusion between uint values and things // which are expected to identify Columns. type RegisterId = trace.ColumnId @@ -103,7 +103,7 @@ type Register struct { Name string // Width (in bits) of this register Width uint - // Determies what value will be used to padd this register. + // Determines what value will be used to padd this register. Padding big.Int } @@ -222,7 +222,7 @@ func (p RegisterType) GobEncode() (data []byte, err error) { gobEncoder = gob.NewEncoder(&buffer) ) // - if err := gobEncoder.Encode(&p.kind); err != nil { + if err = gobEncoder.Encode(&p.kind); err != nil { return nil, err } // Done diff --git a/pkg/util/collection/bit/bit_set.go b/pkg/util/collection/bit/bit_set.go index 599b99eee..a687b77db 100644 --- a/pkg/util/collection/bit/bit_set.go +++ b/pkg/util/collection/bit/bit_set.go @@ -120,12 +120,12 @@ func (p *Set) Count() uint { } // NewSet creates a Set of the given size. -func NewSet(size int) *Set { +func NewSet(size uint) *Set { return &Set{make([]uint64, (size+63)/64)} } // Set the iᵗʰ bit to v -func (p *Set) Set(i int, v bool) { +func (p *Set) Set(i uint, v bool) { x := uint64(1) << (i % 64) i = i / 64 @@ -137,7 +137,7 @@ func (p *Set) Set(i int, v bool) { } // Get the value of the iᵗʰ bit -func (p *Set) Get(i int) bool { +func (p *Set) Get(i uint) bool { return p.words[i/64]&(1<<(i%64)) != 0 } diff --git a/pkg/util/field/batch_invert.go b/pkg/util/field/batch_invert.go index 0e35e8fab..c3ee5ae23 100644 --- a/pkg/util/field/batch_invert.go +++ b/pkg/util/field/batch_invert.go @@ -13,12 +13,13 @@ package field import ( + "github.com/consensys/go-corset/pkg/util/collection/array" "github.com/consensys/go-corset/pkg/util/collection/bit" ) // BatchInvert efficiently inverts the list of elements s, in place. -func BatchInvert[T Element[T]](s []T) { - if len(s) == 0 { +func BatchInvert[T Element[T]](s array.MutArray[T]) { + if s.Len() == 0 { return } // @@ -26,42 +27,45 @@ func BatchInvert[T Element[T]](s []T) { zero = Zero[T]() one = One[T]() // identifies entries which are zero - isZero = bit.NewSet(len(s)) + isZero = bit.NewSet(s.Len()) - m = make([]T, len(s)) // m[i] = s[i] * s[i+1] * ... + m = make([]T, s.Len()) // m[i] = s[i] * s[i+1] * ... ) // - isZero.Set(len(s)-1, s[len(s)-1].IsZero()) + isZero.Set(s.Len()-1, s.Get(s.Len()-1).IsZero()) - if isZero.Get(len(s) - 1) { - s[len(s)-1] = one + if isZero.Get(s.Len() - 1) { + s.Set(s.Len()-1, one) } - m[len(s)-1] = s[len(s)-1] + m[s.Len()-1] = s.Get(s.Len() - 1) - for i := len(s) - 2; i >= 0; i-- { - isZero.Set(i, s[i].IsZero()) + for i := int(s.Len()) - 2; i >= 0; i-- { + isZero.Set(uint(i), s.Get(uint(i)).IsZero()) - if isZero.Get(i) { - s[i] = one + if isZero.Get(uint(i)) { + s.Set(uint(i), one) } - m[i] = m[i+1].Mul(s[i]) + m[i] = m[i+1].Mul(s.Get(uint(i))) } inv := m[0].Inverse() // inv = s[0]⁻¹ * s[1]⁻¹ * ... - for i := range len(s) - 1 { + for i := range s.Len() - 1 { // inv = s[i]⁻¹ * s[i+1]⁻¹ * ... - s[i], inv = inv.Mul(m[i+1]), inv.Mul(s[i]) + newInv := inv.Mul(s.Get(i)) + s.Set(i, inv.Mul(m[i+1])) + inv = newInv // inv = s[i+1]⁻¹ * s[i+2]⁻¹ * ... if isZero.Get(i) { - s[i] = zero + s.Set(i, zero) } } - s[len(s)-1] = inv - if isZero.Get(len(s) - 1) { - s[len(s)-1] = zero + s.Set(s.Len()-1, inv) + + if isZero.Get(s.Len() - 1) { + s.Set(s.Len()-1, zero) } } diff --git a/pkg/util/field/element_test.go b/pkg/util/field/element_test.go index e21585193..a817d80a0 100644 --- a/pkg/util/field/element_test.go +++ b/pkg/util/field/element_test.go @@ -14,9 +14,11 @@ package field import ( "math/rand" + "slices" "testing" "github.com/consensys/go-corset/pkg/util/assert" + "github.com/consensys/go-corset/pkg/util/collection/array" "github.com/consensys/go-corset/pkg/util/field/bls12_377" "github.com/consensys/go-corset/pkg/util/field/koalabear" ) @@ -28,9 +30,9 @@ func init() { } func TestBatchInvert(t *testing.T) { - s := make([]koalabear.Element, 4000) - sInv := make([]koalabear.Element, len(s)) - scratch := make([]koalabear.Element, len(s)) + s := make(elementArray, 4000) + sInv := make(elementArray, len(s)) + scratch := make(elementArray, len(s)) for i := range s { s[i] = koalabear.Element{rand.Uint32()} @@ -48,3 +50,37 @@ func TestBatchInvert(t *testing.T) { } } } + +type elementArray []koalabear.Element + +func (e elementArray) BitWidth() uint { + panic("not implemented") +} + +func (e elementArray) Clone() array.MutArray[koalabear.Element] { + return slices.Clone(e) +} + +func (e elementArray) Get(u uint) koalabear.Element { + return e[u] +} + +func (e elementArray) Len() uint { + return uint(len(e)) +} + +func (e elementArray) Slice(u uint, u2 uint) array.Array[koalabear.Element] { + return e[u:u2] +} + +func (e elementArray) Append(t koalabear.Element) { + panic("not implemented") +} + +func (e elementArray) Set(u uint, t koalabear.Element) { + e[u] = t +} + +func (e elementArray) Pad(u uint, u2 uint, t koalabear.Element) { + panic("not implemented") +}