This document provides a high-level overview of VAJAX's architecture for developers new to the codebase.
VAJAX is built on three core principles:
- Functional Device Models: Devices are pure JAX functions compiled from Verilog-A
- Automatic Differentiation: Jacobians computed via JAX autodiff, no explicit derivatives
- Vectorization: Same-type devices evaluated in parallel via `jax.vmap`
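The following is a minimal illustration of these three ideas. It is not VAJAX code, and the diode model and its constants `IS` and `VT` are assumed values chosen for the example. The device is written as a pure function, `jax.grad` supplies its conductance, and `jax.vmap` evaluates a batch of devices:

```python
import jax
import jax.numpy as jnp

# Illustrative diode constants (assumed values): saturation current, thermal voltage
IS, VT = 1e-14, 0.025851

def diode_current(v):
    # Pure function of the branch voltage: no state, no side effects
    return IS * (jnp.exp(v / VT) - 1.0)

# Autodiff gives the small-signal conductance dI/dV; no hand-written derivative
diode_conductance = jax.grad(diode_current)

# vmap evaluates many same-type devices in one batched call
v = jnp.array([0.5, 0.6, 0.7])
i = jax.vmap(diode_current)(v)
g = jax.vmap(diode_conductance)(v)
```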
```
┌─────────────────────────────────────────────────────────────────────┐
│ User Code                                                           │
│   from vajax import CircuitEngine                                   │
│   engine = CircuitEngine("circuit.sim")                             │
│   engine.parse()                                                    │
│   result = engine.run_transient()                                   │
└─────────────────────────────────────────────────────────────────────┘
                                   │
                                   ▼
┌─────────────────────────────────────────────────────────────────────┐
│ CircuitEngine (analysis/engine.py)                                  │
│ ┌───────────────────┐ ┌───────────────────┐ ┌─────────────────────┐ │
│ │ run_transient     │ │ run_ac            │ │ run_noise           │ │
│ │ (lax.scan)        │ │ (AC analysis)     │ │ (noise analysis)    │ │
│ └───────────────────┘ └───────────────────┘ └─────────────────────┘ │
│ ┌───────────────────┐ ┌───────────────────┐ ┌─────────────────────┐ │
│ │ run_corners       │ │ run_dcinc         │ │ run_dcxf/run_acxf   │ │
│ │ (PVT sweep)       │ │ (transfer funcs)  │ │ (transfer functions)│ │
│ └───────────────────┘ └───────────────────┘ └─────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
                                   │
                                   ▼
┌─────────────────────────────────────────────────────────────────────┐
│ Device Layer                                                        │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ OpenVAF Compiled Verilog-A Models                               │ │
│ │ resistor.va  capacitor.va  diode.va  psp103.va  ...             │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Built-in Sources (vsource.py)                                   │ │
│ │ DC, Pulse, Sine, PWL voltage/current sources                    │ │
│ └─────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
                                   │
                                   ▼
┌─────────────────────────────────────────────────────────────────────┐
│ JAX Runtime                                                         │
│ ┌───────────────────┐ ┌───────────────────┐ ┌─────────────────────┐ │
│ │ JIT               │ │ vmap              │ │ Autodiff            │ │
│ │ Compilation       │ │ Batched           │ │ (Jacobians via      │ │
│ │ (lax.scan)        │ │ Device Eval       │ │ jacfwd/jvp)         │ │
│ └───────────────────┘ └───────────────────┘ └─────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
```
Simulation setup flow:

```
1. Load Circuit File
   ├── engine = CircuitEngine("circuit.sim")
   └── engine.parse()

2. Netlist Processing
   ├── Parse VACASK .sim file or SPICE netlist
   ├── Load device models (.osdi files via OpenVAF)
   ├── Build node map (node name → index)
   └── Flatten hierarchical instances

3. Device Compilation
   ├── Compile Verilog-A models with OpenVAF
   ├── Generate JAX-compatible device functions
   └── Batch devices by type for vmap evaluation

4. System Builder
   └── _make_full_mna_build_system_fn()
       ├── Creates JIT-compiled residual/Jacobian builder
       ├── Uses full MNA with branch currents as unknowns
       ├── Batches device evaluations via vmap
       └── Handles sparse vs dense matrix assembly
```
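The node-map and batching steps can be sketched as follows. The sketch assumes a hypothetical flattened netlist of `(name, type, nodes)` tuples and maps ground to index `-1`. Both the data structures and that convention were chosen for illustration and are likely different from VAJAX's internals:

```python
from collections import defaultdict

import jax.numpy as jnp

GROUND = "0"

# Hypothetical flattened netlist: (instance name, device type, terminal nodes)
instances = [
    ("r1", "resistor", ("in", "out")),
    ("r2", "resistor", ("out", GROUND)),
    ("d1", "diode", ("out", GROUND)),
]

def build_node_map(instances):
    # Assign a dense index to every non-ground node, in order of first appearance
    node_map = {}
    for _, _, nodes in instances:
        for node in nodes:
            if node != GROUND and node not in node_map:
                node_map[node] = len(node_map)
    return node_map

def group_by_type(instances, node_map):
    # One integer array of terminal indices per device type; ground maps to -1
    groups = defaultdict(list)
    for _, dev_type, nodes in instances:
        groups[dev_type].append([node_map.get(node, -1) for node in nodes])
    return {t: jnp.array(rows, dtype=jnp.int32) for t, rows in groups.items()}

node_map = build_node_map(instances)
terminals = group_by_type(instances, node_map)
```

Each per-type array has shape `[n_devices, n_terminals]`, which is the layout `vmap` maps over.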
Transient analysis flow:

```
1. Initial Conditions
   └── Run DC operating point via Newton-Raphson

2. Time Integration (lax.scan)
   │
   ├── For each time step t = dt, 2*dt, ..., t_stop:
   │   │
   │   ├── Update source waveforms (pulse, sine, PWL)
   │   │
   │   ├── Build system (residual f, Jacobian J)
   │   │   ├── Evaluate all devices via batched vmap
   │   │   ├── Stamp currents → residual vector
   │   │   └── Stamp conductances → Jacobian matrix
   │   │
   │   ├── Newton-Raphson iteration (lax.while_loop):
   │   │   ├── Solve: delta_V = solve(J, -f)
   │   │   │   ├── Dense: jax.scipy.linalg.solve()
   │   │   │   └── Sparse: jax.experimental.sparse.linalg.spsolve()
   │   │   ├── Update: V = V + delta_V
   │   │   └── Check: max(|f|) < abstol?
   │   │
   │   └── Store solution: V[t] appended to trajectory
   │
   └── Return TransientResult(times, voltages)

3. GPU Efficiency
   └── lax.scan enables full GPU execution without Python callbacks
```
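The structure of a scan over time with Newton inside each step can be sketched on a toy RC circuit. This is an illustration under stated assumptions, not VAJAX's transient code. A 1 V step source drives a node through R, with C to ground. Integration uses backward Euler. The names `simulate_rc`, `newton`, `R`, `C`, `VS`, `DT` and `N_STEPS` are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)   # float64 for tight Newton tolerances

# Hypothetical toy circuit: 1 V step source -> R -> node v -> C -> ground
R, C, VS = 1e3, 1e-9, 1.0
DT, N_STEPS = 1e-8, 500

def residual(v, v_prev):
    # KCL at node v, backward Euler for the capacitor: C*(v - v_prev)/DT + (v - VS)/R
    return C * (v - v_prev) / DT + (v - VS) / R

def newton(v_prev, abstol=1e-12, max_iter=20):
    # Newton-Raphson as a lax.while_loop, starting from the previous time point
    def cond(state):
        v, it = state
        return (jnp.max(jnp.abs(residual(v, v_prev))) > abstol) & (it < max_iter)

    def body(state):
        v, it = state
        J = jax.jacfwd(residual)(v, v_prev)          # Jacobian by autodiff
        return v + jnp.linalg.solve(J, -residual(v, v_prev)), it + 1

    v, _ = lax.while_loop(cond, body, (v_prev, jnp.int32(0)))
    return v

def step(v_prev, _):
    v = newton(v_prev)
    return v, v        # new carry, plus the value recorded in the trajectory

@jax.jit
def simulate_rc(v0):
    _, trajectory = lax.scan(step, v0, xs=None, length=N_STEPS)
    return trajectory  # shape (N_STEPS, 1): v at t = DT, 2*DT, ..., N_STEPS*DT

traj = simulate_rc(jnp.zeros(1))
```

Because the whole time loop sits inside one `jax.jit`-compiled `lax.scan`, no Python runs between time steps.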
`CircuitEngine` (`analysis/engine.py`) is the central class that manages circuit parsing, device compilation, and analysis:
```python
from __future__ import annotations

from typing import Callable, Dict


class CircuitEngine:
    # Core data
    circuit_file: str                    # Path to .sim or SPICE file
    node_map: Dict[str, int]             # Node name → index
    models: Dict[str, CompiledModel]     # OpenVAF compiled models
    device_data: Dict[str, DeviceInfo]   # Device instances and parameters

    # Parsing
    def parse(self) -> None: ...         # Parse netlist, compile models

    # Analysis methods (additional keyword options elided)
    def prepare(self, t_stop, dt, **kwargs) -> None: ...
    def run_transient(self) -> TransientResult: ...
    def run_ac(self, freq_start, freq_stop, **kwargs) -> ACResult: ...
    def run_noise(self, out, input_source, **kwargs) -> NoiseResult: ...
    def run_corners(self, corners) -> CornerSweepResult: ...
    def run_dcinc(self) -> DCIncResult: ...
    def run_dcxf(self, out) -> DCXFResult: ...
    def run_acxf(self, out, freq_start, freq_stop, **kwargs) -> ACXFResult: ...

    # Internal system building
    def _make_full_mna_build_system_fn(self) -> Callable: ...
```

`TransientResult` holds the output of transient analysis:
```python
from dataclasses import dataclass
from typing import Any, Dict, List

from jax import Array

@dataclass
class TransientResult:
    times: Array                  # Shape: (n_steps,) time points
    voltages: Dict[str, Array]    # node_name → voltage array
    currents: Dict[str, Array]    # source_name → current array
    stats: Dict[str, Any]         # Simulation statistics (wall_time, etc.)

    # Properties
    @property
    def num_steps(self) -> int:
        return len(self.times)
    @property
    def node_names(self) -> List[str]:
        return list(self.voltages.keys())
    @property
    def source_names(self) -> List[str]:
        return list(self.currents.keys())

    # Methods
    def voltage(self, node: str) -> Array: ...     # Case-insensitive node voltage lookup
    def current(self, source: str) -> Array: ...   # Case-insensitive source current lookup
```

Devices are compiled from Verilog-A and evaluated in batches:
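```python
from typing import Tuple

import jax
from jax import Array

# OpenVAF compiles Verilog-A; openvaf_jax translates the result into a JAX function.
# compile_va is a simplified stand-in for that pipeline.
model = compile_va("resistor.va")

# Per-device evaluation function (simplified signature)
def device_fn(
    voltages: Array,     # Terminal voltages of one device, shape [n_terminals]
    params: Array,       # Parameters of one device, shape [n_params]
    temperature: float,  # Operating temperature, shared by all devices
) -> Tuple[Array, Array]:
    # Returns (currents, conductances) for MNA stamping
    ...

# Batched evaluation: vmap adds the leading device axis, so V_terminals is
# [n_devices, n_terminals] and params_batch is [n_devices, n_params].
# in_axes=None passes the scalar temperature unmapped to every device.
currents, G = jax.vmap(device_fn, in_axes=(0, 0, None))(V_terminals, params_batch, temp)
```

All devices are compiled from Verilog-A source: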
resistor.va → OpenVAF Compiler → MIR/OSDI → JAX Function
Example Verilog-A source (resistor.va):
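```verilog
`include "disciplines.vams"

module resistor(p, n);
    inout p, n;
    electrical p, n;
    parameter real r = 1k;
    parameter real tc1 = 0.0;
    parameter real tc2 = 0.0;
    parameter real tnom = 27.0;   // Reference temperature in °C
    real dT;
    analog begin
        // $temperature is in kelvin, so convert tnom before taking the difference
        dT = $temperature - (tnom + 273.15);
        I(p, n) <+ V(p, n) / r * (1 + tc1*dT + tc2*dT*dT);
    end
endmodule
```

OpenVAF, together with the openvaf_jax translator, turns this into a pure JAX function that: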
- Takes terminal voltages and parameters as input
- Returns currents and conductance matrix
- Is automatically differentiable for Jacobian computation
- Can be batched with `jax.vmap` for parallel evaluation
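To make the "currents plus conductance matrix" output concrete, here is a hand translation of `resistor.va` into JAX. It is hypothetical and is not what openvaf_jax actually emits. The conductance matrix comes from `jax.jacfwd` instead of being written out by hand:

```python
import jax
import jax.numpy as jnp

def resistor_va(v_term, r, tc1, tc2, dT):
    # Hand translation of resistor.va (illustrative; not openvaf_jax output).
    # v_term = [V(p), V(n)]; returns the current leaving each terminal's node.
    i_pn = (v_term[0] - v_term[1]) / r * (1 + tc1 * dT + tc2 * dT * dT)
    return jnp.stack([i_pn, -i_pn])

def evaluate_resistor(v_term, r, tc1=0.0, tc2=0.0, dT=0.0):
    currents = resistor_va(v_term, r, tc1, tc2, dT)
    # Conductance matrix dI/dV by forward-mode autodiff w.r.t. the terminal voltages
    G = jax.jacfwd(resistor_va)(v_term, r, tc1, tc2, dT)
    return currents, G

currents, G = evaluate_resistor(jnp.array([1.0, 0.0]), 1e3)
# currents ≈ [1e-3, -1e-3]; G ≈ [[1e-3, -1e-3], [-1e-3, 1e-3]]
```

`G` has the familiar two-terminal conductance stamp pattern, derived automatically from the current expression.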
Using Verilog-A as the model source brings several benefits:
- PDK Compatibility: Use production models (PSP103, BSIM4) directly
- Standardization: Industry-standard compact model format
- Validation: Models tested against commercial simulators
- Maintainability: One source for all backends (JAX, VACASK, ngspice)
For large circuits (>1000 nodes), VAJAX uses JAX's native sparse formats:
```python
from jax.experimental.sparse import BCOO, BCSR
from jax.experimental.sparse.linalg import spsolve

# Build sparse Jacobian from COO triplets
def build_sparse_jacobian(rows, cols, values, shape):
    # build_csr_arrays is a helper (not shown) that performs the
    # COO→CSR conversion in pure JAX
    data, indices, indptr = build_csr_arrays(rows, cols, values, shape)
    return data, indices, indptr

# Solve sparse system
# spsolve uses cuSOLVER on GPU; on CPU it falls back to scipy.sparse.linalg.spsolve
delta_V = spsolve(data, indices, indptr, -residual, tol=0)
```

| Format | Usage | Notes |
|---|---|---|
| BCOO | Matrix construction | JAX native COO, efficient for building |
| BCSR | Linear solve | CSR required by spsolve |
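One way to perform the COO→CSR step in pure JAX is sketched below. The helper name `coo_to_csr` is hypothetical, and this is not VAJAX's `build_csr_arrays`. It sums duplicate stamps (several devices writing the same matrix entry), orders entries row-major, and passes the CSR arrays to `spsolve`. Because `jnp.unique` needs a static `size=` argument under `jit`, this version is meant to run outside `jit`:

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse.linalg import spsolve

def coo_to_csr(rows, cols, values, n):
    # Hypothetical helper: COO triplets (possibly with duplicates) -> CSR arrays
    keys = rows * n + cols                                   # row-major linear index
    unique_keys, inverse = jnp.unique(keys, return_inverse=True)
    # Duplicate stamps on the same (row, col) are summed
    data = jax.ops.segment_sum(values, inverse.ravel(),
                               num_segments=unique_keys.shape[0])
    indices = (unique_keys % n).astype(jnp.int32)            # column of each entry
    entry_rows = unique_keys // n                            # sorted, since keys are
    # indptr[i] = number of stored entries in rows before row i
    indptr = jnp.searchsorted(entry_rows, jnp.arange(n + 1)).astype(jnp.int32)
    return data, indices, indptr

# 3x3 example; the (0, 0) entry arrives as two stamps of 2.0
rows = jnp.array([0, 0, 0, 1, 1, 1, 2, 2])
cols = jnp.array([0, 0, 1, 0, 1, 2, 1, 2])
values = jnp.array([2.0, 2.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0])

data, indices, indptr = coo_to_csr(rows, cols, values, n=3)
b = jnp.array([2.0, 4.0, 10.0])                              # equals A @ [1, 2, 3]
x = spsolve(data, indices, indptr, b)
```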

Solver choice by circuit size:

| Circuit Size | Solver | Reason |
|---|---|---|
| < 1000 nodes | Dense | Lower overhead, jax.scipy.linalg.solve() |
| ≥ 1000 nodes | Sparse | Memory efficiency, spsolve() |
The switch is controlled by use_sparse=True in prepare():
```python
engine.prepare(t_stop=1e-6, dt=1e-9, use_sparse=True)
result = engine.run_transient()
```

Verilog-A models reach JAX through the following pipeline:

```
Verilog-A (.va)
      │
      ▼
OpenVAF Compiler
      │
      ▼
MIR (Mid-level IR)
      │
      ▼
openvaf_jax Translator
      │
      ▼
JAX Function
```
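A simplified view of the Verilog-A device wrapper (see `devices/verilog_a.py`):

```python
class VerilogADevice:
    def __init__(self, compiled_model, params):
        # Translate the OpenVAF-compiled model into a JAX function once, up front
        self.eval_fn = openvaf_jax.translate(compiled_model)
        self.params = params
        self.n_internal = compiled_model.n_internal_nodes

    def evaluate(self, V, params, context):
        # Call the JAX-translated function (context is unused in this simplified view)
        outputs = self.eval_fn(V, params)
        # Extract currents and conductances from outputs
        return DeviceStamps(
            currents=outputs.currents,
            conductances=outputs.jacobian,
        )
```

The Newton update is JIT-compiled as a single function:

```python
from functools import partial

import jax
import jax.numpy as jnp

# `system` is a Python object, so it is marked static (and must be hashable);
# V and context are traced arrays/pytrees.
@partial(jax.jit, static_argnames="system")
def newton_step(V, system, context):
    residual, J = system.build_jacobian_and_residual(V, context)
    delta_V = jnp.linalg.solve(J, -residual)
    return V + delta_V
```

First call: ~1-5 s (tracing and XLA compilation). Subsequent calls with the same array shapes: ~1-20 ms.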
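The gap between first-call and later-call times comes from JIT caching. JAX traces and compiles once per combination of input shapes and dtypes, then reuses the compiled executable. The sketch below counts traces with a Python side effect, which only runs during tracing. `newton_update` and `trace_count` are illustrative names:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def newton_update(V, J, f):
    global trace_count
    trace_count += 1   # Python side effect: runs only while JAX traces the function
    return V + jnp.linalg.solve(J, -f)

J = 2.0 * jnp.eye(3)
for _ in range(5):
    V = newton_update(jnp.zeros(3), J, jnp.ones(3))   # same shapes: compiled once

V4 = newton_update(jnp.zeros(4), 2.0 * jnp.eye(4), jnp.ones(4))   # new shape: retrace
```

After these six calls, `trace_count` is 2. Changing the circuit size therefore triggers a new compilation, but repeated solves at a fixed size do not.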
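Vectorization replaces per-device Python calls with one batched call per device type:

```python
# Without vectorization (slow): one Python-level call per device
currents = []
for i in range(n_devices):
    I = device_fn(V[nodes[i]], params[i])
    currents.append(I)

# With vectorization (fast): one batched call for all devices of a type
currents = jax.vmap(device_fn)(V[node_indices], batched_params)
```

Speedup: 10-100x, depending on device count.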
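The vectorized version still has to gather terminal voltages from the global solution vector and scatter device currents back into the residual. Below is a minimal sketch for resistors. It assumes an `[n_devices, 2]` terminal-index array with ground encoded as `-1`, a convention chosen for this example. The function names are hypothetical:

```python
import jax
import jax.numpy as jnp

def resistor_current(v_pos, v_neg, r):
    # Current through one resistor, flowing from its + terminal to its - terminal
    return (v_pos - v_neg) / r

def resistor_residual(V, terminals, r):
    # V: node voltages [n_nodes]; terminals: [n_devices, 2] node indices, ground = -1
    V_ext = jnp.concatenate([V, jnp.zeros(1, V.dtype)])   # V_ext[-1] is ground (0 V)
    I = jax.vmap(resistor_current)(V_ext[terminals[:, 0]], V_ext[terminals[:, 1]], r)
    # Scatter-add stamps: +I leaves the + node, -I at the - node; shared nodes accumulate
    f = jnp.zeros(V.shape[0] + 1, V.dtype)
    f = f.at[terminals[:, 0]].add(I).at[terminals[:, 1]].add(-I)
    return f[:-1]                                          # drop the ground row

V = jnp.array([1.0, 0.5])                  # nodes "in" and "out"
terminals = jnp.array([[0, 1], [1, -1]])   # r1: in-out, r2: out-ground
f = resistor_residual(V, terminals, jnp.array([1e3, 1e3]))
```

`.at[...].add` accumulates contributions when several devices share a node, which is what MNA stamping requires.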
For a detailed walkthrough of how all these mechanisms compose for a real circuit, see Parallelism Architecture: c6288 Case Study.
Pre-computing parameter arrays eliminates Python loops:
```python
# During build_device_groups()
params = {
    "r": jnp.array([d.params["r"] for d in resistors]),
    "tc1": jnp.array([d.params.get("tc1", 0.0) for d in resistors]),  # float default keeps a float dtype
}
# Now vmap can use these directly
```

Source stepping gradually ramps the supply voltage from 0 to its target value:
```
VDD steps: 0.0 → 0.12 → 0.24 → ... → 1.2 V
```

At each step:
1. Solve DC with the current VDD
2. Use the solution as the initial guess for the next step

This helps with digital circuits that have multiple stable states.
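Below is a minimal source-stepping sketch. It assumes a hypothetical resistor-diode circuit (VDD through R into a diode to ground) and uses a fixed Newton iteration budget instead of a convergence test. It is not the implementation in `analysis/homotopy.py`:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)   # float64 so the final residual check is meaningful

# Hypothetical circuit: VDD -> R -> node d -> diode -> ground
R, IS, VT = 1e3, 1e-14, 0.025851

def residual(v, vdd):
    # KCL at node d: current leaving through R plus current leaving through the diode
    return (v - vdd) / R + IS * (jnp.exp(v / VT) - 1.0)

@jax.jit
def newton_solve(v0, vdd):
    def body(_, v):
        J = jax.jacfwd(residual)(v, vdd)
        return v + jnp.linalg.solve(J, -residual(v, vdd))
    # Fixed iteration budget keeps the sketch short; a real solver checks convergence
    return jax.lax.fori_loop(0, 50, body, v0)

def source_stepping(vdd_target, n_steps=10):
    v = jnp.zeros(1)
    # VDD = 0.12, 0.24, ..., 1.2 for vdd_target=1.2 and n_steps=10
    for vdd in jnp.linspace(vdd_target / n_steps, vdd_target, n_steps):
        v = newton_solve(v, vdd)   # previous solution seeds the next step
    return v

v_op = source_stepping(1.2)
```

Each step's solution sits close to the next one, so Newton never has to cross a steep region of the diode exponential from a poor starting point.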
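GMIN stepping adds a decreasing conductance from each node to ground:

```
GMIN steps: 1e-3 → 1e-6 → 1e-9 → 1e-12 S
```

At each step:
1. Add GMIN*V to the residual (pulls nodes toward 0 V; equivalently, adds GMIN to every diagonal entry of the Jacobian)
2. Solve DC
3. Reduce GMIN and repeat

This prevents floating nodes from making the Jacobian singular and improves conditioning.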
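Below is a GMIN-stepping sketch under similar assumptions. Node a is a hypothetical resistor-diode node. Node b is connected only through a capacitor, so it has no DC path to ground and its Jacobian row is zero. The added `gmin * v` term keeps the matrix invertible while GMIN is stepped down. The circuit and function names are illustrative only:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical circuit: VDD -> R -> node a -> diode -> ground;
# node b is attached only through a capacitor, so it floats at DC.
R, IS, VT, VDD = 1e3, 1e-14, 0.025851, 1.2

def residual(v, gmin):
    va, vb = v[0], v[1]
    f_a = (va - VDD) / R + IS * (jnp.exp(va / VT) - 1.0)
    f_b = 0.0 * vb                              # no DC current at the floating node
    return jnp.stack([f_a, f_b]) + gmin * v     # GMIN: conductance from each node to ground

@jax.jit
def newton_solve(v0, gmin):
    def body(_, v):
        J = jax.jacfwd(residual)(v, gmin)       # J = J_circuit + gmin * I
        return v + jnp.linalg.solve(J, -residual(v, gmin))
    return jax.lax.fori_loop(0, 50, body, v0)

def gmin_stepping(gmins=(1e-3, 1e-6, 1e-9, 1e-12)):
    v = jnp.zeros(2)
    for gmin in gmins:
        v = newton_solve(v, gmin)   # each solution seeds the next, smaller GMIN
    return v

v_op = gmin_stepping()
```

A large initial GMIN also eases the diode node: at 1e-3 S, it pulls node a well below the turn-on region, and the later steps only make small corrections.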

Key source files:

| File | Purpose |
|---|---|
| `analysis/engine.py` | `CircuitEngine`: main simulation API, parsing, all analyses |
| `analysis/solver.py` | Newton-Raphson solver with `lax.while_loop` |
| `analysis/transient/` | Transient analysis (scan/loop strategies) |
| `analysis/ac.py` | AC small-signal analysis |
| `analysis/noise.py` | Noise analysis |
| `analysis/hb.py` | Harmonic balance analysis |
| `analysis/xfer.py` | Transfer function analyses (DCINC, DCXF, ACXF) |
| `analysis/corners.py` | PVT corner analysis |
| `analysis/homotopy.py` | Convergence aids (GMIN, source stepping) |
| `analysis/sparse.py` | JAX sparse utilities (BCOO/BCSR, spsolve) |
| `devices/vsource.py` | Voltage/current source waveforms |
| `devices/verilog_a.py` | OpenVAF Verilog-A device wrapper |
| `netlist/parser.py` | VACASK netlist parser |
| `benchmarks/runner.py` | VACASK benchmark runner |
| `benchmarks/registry.py` | Auto-discovery of benchmark circuits |
```python
from vajax import CircuitEngine

# Load circuit from VACASK .sim file
engine = CircuitEngine("vendor/VACASK/sim/ring.sim")
engine.parse()

# Prepare and run transient analysis
engine.prepare(t_stop=1e-6, dt=1e-9, use_sparse=True)
result = engine.run_transient()
```

Accessing results:

```python
# Get all node voltages at final time
for node_name, voltages in result.voltages.items():
    print(f"{node_name}: {voltages[-1]:.3f}V")

# Get specific node over time
vout = result.voltages["out"]  # Array of voltages

# Time array
times = result.times
```

Running other analyses:

```python
# Transient
engine.prepare(t_stop=1e-6, dt=1e-9)
tran_result = engine.run_transient()

# AC analysis
ac_result = engine.run_ac(freq_start=1e3, freq_stop=1e9, points=100)

# Noise analysis
noise_result = engine.run_noise(
    freq_start=1e3, freq_stop=1e9,
    input_source="vin", out="vout",
)

# PVT corners
from vajax.analysis.corners import create_pvt_corners

corners = create_pvt_corners(
    processes=['TT', 'FF', 'SS'],
    voltages=[0.9, 1.0, 1.1],
    temperatures=['cold', 'room', 'hot'],
)
corner_results = engine.run_corners(corners)
```

When investigating issues, start here:
- Parsing issues: Check `CircuitEngine.parse()` in `engine.py`
- Device compilation: Check OpenVAF model loading in `_compile_openvaf_models()`
- System building: Check `_make_full_mna_build_system_fn()` for J/f construction
- Convergence issues: Look at the Newton-Raphson loop in `solver.py`
- Sparse solver: Check `sparse.py` for BCOO/BCSR operations
- Source waveforms: Check `vsource.py` for pulse/sine/PWL evaluation