Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion atlas-onnx-tracer/src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,12 @@ impl Model {
| Operator::Rsqrt(_)
| Operator::Sigmoid(_)
| Operator::Sin(_) => LOG_K_CHUNK + log_2(node.pow2_padded_num_output_elements()),
Operator::ScalarConstDiv(_) => log_2(node.pow2_padded_num_output_elements()),
Operator::ScalarConstDiv(_) => {
LOG_K_CHUNK + log_2(node.pow2_padded_num_output_elements())
}
Operator::ScalarConstDivPow2(_) => {
LOG_K_CHUNK + log_2(node.pow2_padded_num_output_elements())
}
Operator::SoftmaxAxes(_) => {
LOG_K_CHUNK + log_2(*node.output_dims.last().unwrap_or(&1))
}
Expand Down
8 changes: 8 additions & 0 deletions atlas-onnx-tracer/src/model/shadow_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,14 @@ fn shadow_f64(op: &Operator, inputs: Vec<&Tensor<f64>>, scale: Scale) -> Tensor<
elementwise_f64(inputs[0], |x| x / d)
}
}
Operator::ScalarConstDivPow2(scd) => {
if is_rebase_divisor(scd.divisor, scale) {
inputs[0].clone()
} else {
let d = scd.divisor as f64;
elementwise_f64(inputs[0], |x| x / d)
}
}

// ── Matrix / tensor contraction ─────────────────────────────────
Operator::Einsum(e) => tensor::ops::einsum(&e.equation, &inputs).unwrap(),
Expand Down
12 changes: 6 additions & 6 deletions atlas-onnx-tracer/src/node/handlers/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::collections::HashMap;

use crate::{
node::ComputationNode,
ops::{Constant, Cube, Mul, Operator, ScalarConstDiv},
ops::{Constant, Cube, Mul, Operator, ScalarConstDiv, ScalarConstDivPow2},
simple_handler,
utils::{handler_builder::HandlerBuilder, parser::DecompositionBuilder},
};
Expand Down Expand Up @@ -47,7 +47,7 @@ fn build_mul(hctx: &mut HandlerContext) -> Vec<ComputationNode> {
.simple_op(Operator::Mul(Mul { scale }));

#[cfg(not(feature = "fused-ops"))]
let builder = builder.with_auto_rebase();
let builder = builder.with_auto_rebase_pow2();

builder.build()
}
Expand Down Expand Up @@ -85,11 +85,11 @@ fn handle_square(hctx: &mut HandlerContext) -> Vec<ComputationNode> {
/// Builds Square(x) = x²/S to maintain scale S.
///
/// When `pre_rebase_nonlinear` is **false** (default):
/// Square(x) → ScalarConstDiv(x², S)
/// Square(x) → ScalarConstDivPow2(x², S)
///
/// HACK: When `pre_rebase_nonlinear` is **true** (for large models like GPT-2):
/// Decomposes into existing ops to avoid i32 overflow:
/// x' = x / S (ScalarConstDiv)
/// x' = x / S (ScalarConstDivPow2)
/// result = x' * x (Mul, no rebase)
/// Since x' * x = x²/S, the result is already at scale S.
/// TODO: Remove pre_rebase_nonlinear path once fused i64 ops are default.
Expand All @@ -106,7 +106,7 @@ fn build_square(hctx: &mut HandlerContext) -> Vec<ComputationNode> {
// Node 0: x' = x / S
builder.add_node(ComputationNode {
idx: builder.idx(0),
operator: Operator::ScalarConstDiv(ScalarConstDiv { divisor: s }),
operator: Operator::ScalarConstDivPow2(ScalarConstDivPow2 { divisor: s }),
inputs: vec![x_idx],
output_dims: output_dims.clone(),
});
Expand All @@ -123,7 +123,7 @@ fn build_square(hctx: &mut HandlerContext) -> Vec<ComputationNode> {
HandlerBuilder::new(hctx)
.with_broadcast()
.simple_op(Operator::Square(Default::default()))
.with_auto_rebase()
.with_auto_rebase_pow2()
.build()
}

Expand Down
2 changes: 1 addition & 1 deletion atlas-onnx-tracer/src/node/handlers/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ fn handle_einsum(hctx: &mut HandlerContext) -> Vec<ComputationNode> {
}));

#[cfg(not(feature = "fused-ops"))]
let builder = builder.with_auto_rebase();
let builder = builder.with_auto_rebase_pow2();

builder.build()
}
Expand Down
3 changes: 3 additions & 0 deletions atlas-onnx-tracer/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub mod reshape;
pub mod rsqrt;
/// Division by a scalar constant operator.
pub mod scalar_const_div;
/// Division by a power-of-two scalar constant operator.
pub mod scalar_const_div_pow2;
/// Sigmoid activation operator.
pub mod sigmoid;
/// Element-wise sine operator.
Expand Down Expand Up @@ -145,6 +147,7 @@ define_operators! {
Reshape { shape:Vec<usize> },
Rsqrt { scale: F32 },
ScalarConstDiv {divisor: i32},
ScalarConstDivPow2 {divisor: i32},
Sigmoid { scale: F32, tau: i32, log_table: usize },
Sin { scale: F32 },
Slice { axis: usize, start: usize, end: usize},
Expand Down
1 change: 1 addition & 0 deletions atlas-onnx-tracer/src/ops/scalar_const_div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ impl Op for ScalarConstDiv {
fn f(&self, inputs: Vec<&Tensor<i32>>) -> Tensor<i32> {
let a = inputs[0];
let b = self.divisor;
assert!(b > 0, "ScalarConstDiv requires a positive divisor, got {b}");
let data: Vec<i32> = a
.data()
.iter()
Expand Down
31 changes: 31 additions & 0 deletions atlas-onnx-tracer/src/ops/scalar_const_div_pow2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use super::Op;
use crate::{ops::ScalarConstDivPow2, tensor::Tensor};

impl Op for ScalarConstDivPow2 {
    /// Divides every element of the first input tensor by `self.divisor`,
    /// rounding toward negative infinity (floor division).
    ///
    /// The output tensor has the same dimensions as the first input.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is empty, if `self.divisor` is not a positive power
    /// of two, or if the output tensor cannot be constructed.
    #[tracing::instrument(name = "ScalarConstDivPow2::f", skip_all)]
    fn f(&self, inputs: Vec<&Tensor<i32>>) -> Tensor<i32> {
        let operand = inputs[0];
        let b = self.divisor;
        let divisor_is_valid = b > 0 && (b as u32).is_power_of_two();
        assert!(
            divisor_is_valid,
            "ScalarConstDivPow2 requires a positive power-of-two divisor, got {b}"
        );

        // With a strictly positive divisor, Euclidean division is floor
        // division: the remainder is always in `0..b`, so negative inputs
        // round down instead of toward zero.
        let quotients: Vec<i32> = operand
            .data()
            .iter()
            .map(|&value| value.div_euclid(b))
            .collect();

        Tensor::new(Some(&quotients), operand.dims()).unwrap()
    }

    /// Always reports `true` for this operator.
    fn requires_shape_equality(&self) -> bool {
        true
    }
}
22 changes: 19 additions & 3 deletions atlas-onnx-tracer/src/utils/handler_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

use crate::{
node::{ComputationNode, handlers::HandlerContext},
ops::{Constant, Operator, ScalarConstDiv},
ops::{Constant, Operator, ScalarConstDiv, ScalarConstDivPow2},
tensor::Tensor,
};

Expand All @@ -45,6 +45,7 @@ pub struct HandlerBuilder<'a, 'b> {
broadcast_nodes: Vec<ComputationNode>,
stages: Vec<Stage>,
auto_rebase: bool,
auto_rebase_pow2: bool,
custom_rebase_factor: Option<i32>,
}

Expand Down Expand Up @@ -79,6 +80,7 @@ impl<'a, 'b> HandlerBuilder<'a, 'b> {
broadcast_nodes: vec![],
stages: vec![],
auto_rebase: false,
auto_rebase_pow2: false,
custom_rebase_factor: None,
}
}
Expand Down Expand Up @@ -127,6 +129,15 @@ impl<'a, 'b> HandlerBuilder<'a, 'b> {
self
}

/// Automatically adds a power-of-two rebase node.
///
/// Sets the `auto_rebase_pow2` flag and returns the builder so calls can be
/// chained. When the flag is set, the rebase step emits a
/// `ScalarConstDivPow2` node instead of a `ScalarConstDiv` node.
///
/// This should only be used when the rebase factor is known to be a positive
/// power of two, such as fixed-point scale restoration after Mul/Square/Einsum.
///
/// NOTE(review): the divisor comes from `determine_rebase_factor()`, which
/// appears to return an explicitly configured custom factor before it
/// consults this flag. Confirm no caller combines this flag with a
/// non-power-of-two custom factor, because `ScalarConstDivPow2::f` asserts
/// that its divisor is a positive power of two.
pub fn with_auto_rebase_pow2(mut self) -> Self {
self.auto_rebase_pow2 = true;
self
}

/// Explicitly sets a rebase factor (1 << (scale * factor)).
///
/// Use this when you need custom rebase behavior.
Expand Down Expand Up @@ -260,10 +271,15 @@ impl<'a, 'b> HandlerBuilder<'a, 'b> {
if let Some(factor) = self.determine_rebase_factor() {
let prev_idx = current_output_idx.expect("Rebase requires a previous node");
let output_dims = self.hctx.output_dims.clone();
let operator = if self.auto_rebase_pow2 {
Operator::ScalarConstDivPow2(ScalarConstDivPow2 { divisor: factor })
} else {
Operator::ScalarConstDiv(ScalarConstDiv { divisor: factor })
};
Comment on lines +274 to +278

Copilot AI Apr 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When auto_rebase_pow2 is enabled, this unconditionally emits ScalarConstDivPow2 { divisor: factor }, but determine_rebase_factor() can also return custom_rebase_factor (which may be non-power-of-two). That will trigger runtime assertions/panics later (execution/proof code requires power-of-two). Add a guard here (assert factor > 0 && is_power_of_two) or fall back to ScalarConstDiv when the chosen factor is not a positive power of two.

Copilot uses AI. Check for mistakes.

builder.add_node(ComputationNode {
idx: builder.idx(node_offset),
operator: Operator::ScalarConstDiv(ScalarConstDiv { divisor: factor }),
operator,
inputs: vec![prev_idx],
output_dims,
});
Expand Down Expand Up @@ -302,7 +318,7 @@ impl<'a, 'b> HandlerBuilder<'a, 'b> {
return Some(factor);
}

if self.auto_rebase {
if self.auto_rebase || self.auto_rebase_pow2 {
// Find the last operator that might need rebase
for stage in self.stages.iter().rev() {
let operator = match stage {
Expand Down
72 changes: 67 additions & 5 deletions common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ pub enum CommittedPolynomial {
///
/// `1` - d
NodeOutputRaD(usize, usize),
CosRaD(usize, usize), // One-hot read addresses for Cos lookup
ErfRaD(usize, usize), // One-hot read addresses for Erf lookup
SigmoidRaD(usize, usize), // One-hot read addresses for Sigmoid lookup
SinRaD(usize, usize), // One-hot read addresses for Sin lookup
TanhRaD(usize, usize), // One-hot read addresses for Tanh lookup
CosRaD(usize, usize), // One-hot read addresses for Cos lookup
ErfRaD(usize, usize), // One-hot read addresses for Erf lookup
SigmoidRaD(usize, usize), // One-hot read addresses for Sigmoid lookup
SinRaD(usize, usize), // One-hot read addresses for Sin lookup
ScalarConstDivPow2RaD(usize, usize), // One-hot read addresses for ScalarConstDivPow2 remainder lookup
ScalarConstDivRangeCheckRaD(usize, usize), // Interleaved remainder and constant divisor for ScalarConstDiv
TanhRaD(usize, usize), // One-hot read addresses for Tanh lookup

/// Fields:
///
Expand Down Expand Up @@ -86,10 +88,14 @@ pub enum VirtualPolynomial {
// Those are proven by the ReadRafSumcheckProver,
// from Committed one-hot polynomials.
DivRangeCheckRa(usize),
ScalarConstDivDivisor(usize),
ScalarConstDivRangeCheckRa(usize),
SqrtRangeCheckRa(usize),
TeleportRangeCheckRa(usize),

DivRemainder(usize),
ScalarConstDivPow2Divisor(usize),
ScalarConstDivPow2Ra(usize),
SqrtRemainder(usize),
TeleportQuotient(usize), // Quotient polynomial for neural teleportation lookups
TeleportRemainder(usize), // Remainder polynomial for neural teleportation lookups
Expand Down Expand Up @@ -187,6 +193,16 @@ impl CanonicalSerialize for CommittedPolynomial {
a.serialize_with_mode(&mut writer, compress)?;
b.serialize_with_mode(&mut writer, compress)?;
}
Self::ScalarConstDivPow2RaD(a, b) => {
17u8.serialize_with_mode(&mut writer, compress)?;
a.serialize_with_mode(&mut writer, compress)?;
b.serialize_with_mode(&mut writer, compress)?;
}
Self::ScalarConstDivRangeCheckRaD(a, b) => {
18u8.serialize_with_mode(&mut writer, compress)?;
a.serialize_with_mode(&mut writer, compress)?;
b.serialize_with_mode(&mut writer, compress)?;
}
}
Ok(())
}
Expand All @@ -199,6 +215,8 @@ impl CanonicalSerialize for CommittedPolynomial {
| Self::SigmoidRaD(a, b)
| Self::CosRaD(a, b)
| Self::SinRaD(a, b)
| Self::ScalarConstDivPow2RaD(a, b)
| Self::ScalarConstDivRangeCheckRaD(a, b)
| Self::SoftmaxRemainder(a, b)
| Self::DivRangeCheckRaD(a, b)
| Self::SqrtDivRangeCheckRaD(a, b)
Expand Down Expand Up @@ -306,6 +324,14 @@ impl CanonicalDeserialize for CommittedPolynomial {
usize::deserialize_with_mode(&mut reader, compress, validate)?,
usize::deserialize_with_mode(&mut reader, compress, validate)?,
)),
17 => Ok(Self::ScalarConstDivPow2RaD(
usize::deserialize_with_mode(&mut reader, compress, validate)?,
usize::deserialize_with_mode(&mut reader, compress, validate)?,
)),
18 => Ok(Self::ScalarConstDivRangeCheckRaD(
usize::deserialize_with_mode(&mut reader, compress, validate)?,
usize::deserialize_with_mode(&mut reader, compress, validate)?,
)),
_ => Err(SerializationError::InvalidData),
}
}
Expand Down Expand Up @@ -421,6 +447,22 @@ impl CanonicalSerialize for VirtualPolynomial {
22u8.serialize_with_mode(&mut writer, compress)?;
a.serialize_with_mode(&mut writer, compress)?;
}
Self::ScalarConstDivDivisor(a) => {
23u8.serialize_with_mode(&mut writer, compress)?;
a.serialize_with_mode(&mut writer, compress)?;
}
Self::ScalarConstDivRangeCheckRa(a) => {
24u8.serialize_with_mode(&mut writer, compress)?;
a.serialize_with_mode(&mut writer, compress)?;
}
Self::ScalarConstDivPow2Divisor(a) => {
25u8.serialize_with_mode(&mut writer, compress)?;
a.serialize_with_mode(&mut writer, compress)?;
}
Self::ScalarConstDivPow2Ra(a) => {
26u8.serialize_with_mode(&mut writer, compress)?;
a.serialize_with_mode(&mut writer, compress)?;
}
}
Ok(())
}
Expand All @@ -436,9 +478,13 @@ impl CanonicalSerialize for VirtualPolynomial {
| Self::SinRa(a)
| Self::TanhRa(a)
| Self::DivRangeCheckRa(a)
| Self::ScalarConstDivDivisor(a)
| Self::ScalarConstDivRangeCheckRa(a)
| Self::SqrtRangeCheckRa(a)
| Self::TeleportRangeCheckRa(a)
| Self::DivRemainder(a)
| Self::ScalarConstDivPow2Divisor(a)
| Self::ScalarConstDivPow2Ra(a)
| Self::SqrtRemainder(a)
| Self::TeleportQuotient(a)
| Self::TeleportRemainder(a) => a.serialized_size(compress),
Expand Down Expand Up @@ -573,6 +619,22 @@ impl CanonicalDeserialize for VirtualPolynomial {
compress,
validate,
)?)),
23 => Ok(Self::ScalarConstDivDivisor(usize::deserialize_with_mode(
&mut reader,
compress,
validate,
)?)),
24 => Ok(Self::ScalarConstDivRangeCheckRa(
usize::deserialize_with_mode(&mut reader, compress, validate)?,
)),
25 => Ok(Self::ScalarConstDivPow2Divisor(
usize::deserialize_with_mode(&mut reader, compress, validate)?,
)),
26 => Ok(Self::ScalarConstDivPow2Ra(usize::deserialize_with_mode(
&mut reader,
compress,
validate,
)?)),
_ => Err(SerializationError::InvalidData),
}
}
Expand Down
3 changes: 3 additions & 0 deletions jolt-atlas-core/src/onnx_proof/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ pub mod reshape;
pub mod rsqrt;
/// Division by a scalar constant.
pub mod scalar_const_div;
/// Division by a power-of-two scalar constant.
pub mod scalar_const_div_pow2;
/// Sigmoid activation function.
pub mod sigmoid;
/// Sin trigonometric function.
Expand Down Expand Up @@ -241,6 +243,7 @@ macro_rules! dispatch_operator {
Operator::Reshape($inner) => $body,
Operator::Rsqrt($inner) => $body,
Operator::ScalarConstDiv($inner) => $body,
Operator::ScalarConstDivPow2($inner) => $body,
Operator::Sigmoid($inner) => $body,
Operator::Sin($inner) => $body,
Operator::Slice($inner) => $body,
Expand Down
Loading
Loading