From 58f150b06e98f0a95124b33f5df2828b70643267 Mon Sep 17 00:00:00 2001 From: "brady.ouren" Date: Mon, 15 Sep 2025 07:55:31 -0700 Subject: [PATCH 1/6] initial simple static cost --- clarity/src/vm/costs/analysis.rs | 372 +++++++++++++++++++++++++++++++ clarity/src/vm/costs/mod.rs | 2 + 2 files changed, 374 insertions(+) create mode 100644 clarity/src/vm/costs/analysis.rs diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs new file mode 100644 index 0000000000..898eaee7d3 --- /dev/null +++ b/clarity/src/vm/costs/analysis.rs @@ -0,0 +1,372 @@ +// Static cost analysis for Clarity expressions + +use crate::vm::ast::parser::v2::parse; +use crate::vm::costs::cost_functions::CostValues; +use crate::vm::costs::costs_3::Costs3; +use crate::vm::costs::ExecutionCost; +use crate::vm::errors::InterpreterResult; +use crate::vm::representations::{ClarityName, PreSymbolicExpression, PreSymbolicExpressionType}; + +// TODO: +// variable traverse for +// - if, is-*, match, etc +// contract-call? - how to handle? +// type-checking +// lookups + +#[derive(Debug, Clone)] +pub struct StaticCostNode { + pub function: Vec, + pub cost: StaticCost, + pub children: Vec, +} + +impl StaticCostNode { + pub fn new( + function: Vec, + cost: StaticCost, + children: Vec, + ) -> Self { + Self { + function, + cost, + children, + } + } + + pub fn leaf(function: Vec, cost: StaticCost) -> Self { + Self { + function, + cost, + children: vec![], + } + } +} + +#[derive(Debug, Clone)] +pub struct StaticCost { + pub min: ExecutionCost, + pub max: ExecutionCost, +} + +impl StaticCost { + pub const ZERO: StaticCost = StaticCost { + min: ExecutionCost::ZERO, + max: ExecutionCost::ZERO, + }; +} + +/// Parse Clarity source code and calculate its static execution cost +/// +/// This function takes a Clarity expression as a string, parses it into symbolic +/// expressions, builds a cost tree, and returns the min and max execution cost. +/// theoretically you could inspect the tree at any node to get the spot cost +pub fn static_cost(source: &str) -> Result { + let pre_expressions = parse(source).map_err(|e| format!("Parse error: {:?}", e))?; + + if pre_expressions.is_empty() { + return Err("No expressions found".to_string()); + } + + let pre_expr = &pre_expressions[0]; + let cost_tree = build_cost_tree(pre_expr)?; + + Ok(calculate_total_cost(&cost_tree)) +} + +// TODO: Needs alternative traversals to get min/max +fn build_cost_tree(expr: &PreSymbolicExpression) -> Result { + match &expr.pre_expr { + PreSymbolicExpressionType::List(list) => { + if list.is_empty() { + return Err("Empty list expression".to_string()); + } + + let function_name = match &list[0].pre_expr { + PreSymbolicExpressionType::Atom(name) => name, + _ => { + return Err("First element of list must be an atom (function name)".to_string()) + } + }; + + // TODO this is wrong + let args = &list[1..]; + let mut children = Vec::new(); + + for arg in args { + children.push(build_cost_tree(arg)?); + } + + let cost = calculate_function_cost(function_name, args.len() as u64)?; + + Ok(StaticCostNode::new(list.clone(), cost, children)) + } + PreSymbolicExpressionType::AtomValue(_value) => { + Ok(StaticCostNode::leaf(vec![expr.clone()], StaticCost::ZERO)) + } + PreSymbolicExpressionType::Atom(_name) => { + Ok(StaticCostNode::leaf(vec![expr.clone()], StaticCost::ZERO)) + } + PreSymbolicExpressionType::Tuple(tuple) => { + let function_name = match &tuple[0].pre_expr { + PreSymbolicExpressionType::Atom(name) => name, + _ => { + return Err("First element of tuple must be an atom (function name)".to_string()) + } + }; + + let args = &tuple[1..]; + let mut children = Vec::new(); + + for arg in args { + children.push(build_cost_tree(arg)?); + } + + let cost = calculate_function_cost(function_name, args.len() as u64)?; + + Ok(StaticCostNode::new(tuple.clone(), cost, children)) + } + _ => Err("Unsupported expression type for cost analysis".to_string()), + } +} + +fn calculate_function_cost( + function_name: &ClarityName, + arg_count: u64, +) -> Result { + let cost_function = match get_cost_function_for_name(function_name) { + Some(cost_fn) => cost_fn, + None => { + // TODO: zero cost for now + return Ok(StaticCost::ZERO); + } + }; + + let cost = get_costs(cost_function, arg_count)?; + Ok(StaticCost { + min: cost.clone(), + max: cost, + }) +} + +/// Convert a function name to its corresponding cost function +fn get_cost_function_for_name( + name: &ClarityName, +) -> Option InterpreterResult> { + let name_str = name.as_str(); + + // Map function names to their cost functions using the existing enum structure + match name_str { + "+" | "add" => Some(Costs3::cost_add), + "-" | "sub" => Some(Costs3::cost_sub), + "*" | "mul" => Some(Costs3::cost_mul), + "/" | "div" => Some(Costs3::cost_div), + "mod" => Some(Costs3::cost_mod), + "pow" => Some(Costs3::cost_pow), + "sqrti" => Some(Costs3::cost_sqrti), + "log2" => Some(Costs3::cost_log2), + "to-int" | "to-uint" | "int-cast" => Some(Costs3::cost_int_cast), + "is-eq" | "=" | "eq" => Some(Costs3::cost_eq), + ">=" | "geq" => Some(Costs3::cost_geq), + "<=" | "leq" => Some(Costs3::cost_leq), + ">" | "ge" => Some(Costs3::cost_ge), + "<" | "le" => Some(Costs3::cost_le), + "xor" => Some(Costs3::cost_xor), + "not" => Some(Costs3::cost_not), + "and" => Some(Costs3::cost_and), + "or" => Some(Costs3::cost_or), + "concat" => Some(Costs3::cost_concat), + "len" => Some(Costs3::cost_len), + "as-max-len?" => Some(Costs3::cost_as_max_len), + "list" => Some(Costs3::cost_list_cons), + "element-at" | "element-at?" => Some(Costs3::cost_element_at), + "index-of" | "index-of?" => Some(Costs3::cost_index_of), + "fold" => Some(Costs3::cost_fold), + "map" => Some(Costs3::cost_map), + "filter" => Some(Costs3::cost_filter), + "append" => Some(Costs3::cost_append), + "tuple-get" => Some(Costs3::cost_tuple_get), + "tuple-merge" => Some(Costs3::cost_tuple_merge), + "tuple" => Some(Costs3::cost_tuple_cons), + "some" => Some(Costs3::cost_some_cons), + "ok" => Some(Costs3::cost_ok_cons), + "err" => Some(Costs3::cost_err_cons), + "default-to" => Some(Costs3::cost_default_to), + "unwrap!" => Some(Costs3::cost_unwrap_ret), + "unwrap-err!" => Some(Costs3::cost_unwrap_err_or_ret), + "is-ok" => Some(Costs3::cost_is_okay), + "is-none" => Some(Costs3::cost_is_none), + "is-err" => Some(Costs3::cost_is_err), + "is-some" => Some(Costs3::cost_is_some), + "unwrap-panic" => Some(Costs3::cost_unwrap), + "unwrap-err-panic" => Some(Costs3::cost_unwrap_err), + "try!" => Some(Costs3::cost_try_ret), + "if" => Some(Costs3::cost_if), + "match" => Some(Costs3::cost_match), + "begin" => Some(Costs3::cost_begin), + "let" => Some(Costs3::cost_let), + "asserts!" => Some(Costs3::cost_asserts), + "hash160" => Some(Costs3::cost_hash160), + "sha256" => Some(Costs3::cost_sha256), + "sha512" => Some(Costs3::cost_sha512), + "sha512/256" => Some(Costs3::cost_sha512t256), + "keccak256" => Some(Costs3::cost_keccak256), + "secp256k1-recover?" => Some(Costs3::cost_secp256k1recover), + "secp256k1-verify" => Some(Costs3::cost_secp256k1verify), + "print" => Some(Costs3::cost_print), + "contract-call?" => Some(Costs3::cost_contract_call), + "contract-of" => Some(Costs3::cost_contract_of), + "principal-of?" => Some(Costs3::cost_principal_of), + "at-block" => Some(Costs3::cost_at_block), + "load-contract" => Some(Costs3::cost_load_contract), + "create-map" => Some(Costs3::cost_create_map), + "create-var" => Some(Costs3::cost_create_var), + "create-non-fungible-token" => Some(Costs3::cost_create_nft), + "create-fungible-token" => Some(Costs3::cost_create_ft), + "map-get?" => Some(Costs3::cost_fetch_entry), + "map-set!" => Some(Costs3::cost_set_entry), + "var-get" => Some(Costs3::cost_fetch_var), + "var-set!" => Some(Costs3::cost_set_var), + "contract-storage" => Some(Costs3::cost_contract_storage), + "get-block-info?" => Some(Costs3::cost_block_info), + "get-burn-block-info?" => Some(Costs3::cost_burn_block_info), + "stx-get-balance" => Some(Costs3::cost_stx_balance), + "stx-transfer?" => Some(Costs3::cost_stx_transfer), + "stx-transfer-memo?" => Some(Costs3::cost_stx_transfer_memo), + "stx-account" => Some(Costs3::cost_stx_account), + "ft-mint?" => Some(Costs3::cost_ft_mint), + "ft-transfer?" => Some(Costs3::cost_ft_transfer), + "ft-get-balance" => Some(Costs3::cost_ft_balance), + "ft-get-supply" => Some(Costs3::cost_ft_get_supply), + "ft-burn?" => Some(Costs3::cost_ft_burn), + "nft-mint?" => Some(Costs3::cost_nft_mint), + "nft-transfer?" => Some(Costs3::cost_nft_transfer), + "nft-get-owner?" => Some(Costs3::cost_nft_owner), + "nft-burn?" => Some(Costs3::cost_nft_burn), + "buff-to-int-le?" => Some(Costs3::cost_buff_to_int_le), + "buff-to-uint-le?" => Some(Costs3::cost_buff_to_uint_le), + "buff-to-int-be?" => Some(Costs3::cost_buff_to_int_be), + "buff-to-uint-be?" => Some(Costs3::cost_buff_to_uint_be), + "to-consensus-buff?" => Some(Costs3::cost_to_consensus_buff), + "from-consensus-buff?" => Some(Costs3::cost_from_consensus_buff), + "is-standard?" => Some(Costs3::cost_is_standard), + "principal-destruct" => Some(Costs3::cost_principal_destruct), + "principal-construct?" => Some(Costs3::cost_principal_construct), + "as-contract" => Some(Costs3::cost_as_contract), + "string-to-int?" => Some(Costs3::cost_string_to_int), + "string-to-uint?" => Some(Costs3::cost_string_to_uint), + "int-to-ascii" => Some(Costs3::cost_int_to_ascii), + "int-to-utf8?" => Some(Costs3::cost_int_to_utf8), + _ => None, // Unknown function name + } +} + +fn get_max_input_size_for_function_name(function_name: &ClarityName, arg_count: u64) -> u64 { + let name_str = function_name.as_str(); + + match name_str { + "concat" => { + // For string concatenation, max size is the sum of max string lengths + // Each string can be up to MAX_VALUE_SIZE (1MB), so for n strings it's n * MAX_VALUE_SIZE + arg_count * 1024 * 1024 + } + "len" => { + // For length, maximum string length + 1024 * 1024 // MAX_VALUE_SIZE + } + _ => { + // Default case - use a fixed max size to match original behavior + // The original code used 2000 as the max input size for arithmetic operations + 2000 + } + } +} + +fn calculate_total_cost(node: &StaticCostNode) -> StaticCost { + let mut min_total = node.cost.min.clone(); + let mut max_total = node.cost.max.clone(); + + // Add costs from all children + // TODO: this should traverse different paths to get min and max costs + for child in &node.children { + let child_cost = calculate_total_cost(child); + let _ = min_total.add(&child_cost.min); + let _ = max_total.add(&child_cost.max); + } + + StaticCost { + min: min_total, + max: max_total, + } +} + +/// Helper: calculate min & max costs for a given cost function +/// This is likely tooo simplistic but for now it'll do +fn get_costs( + cost_fn: fn(u64) -> InterpreterResult, + arg_count: u64, +) -> Result { + let cost = cost_fn(arg_count).map_err(|e| format!("Cost calculation error: {:?}", e))?; + Ok(cost) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_constant() { + let source = "u2"; + let cost = static_cost(source).unwrap(); + assert_eq!(cost.min.runtime, 0); + assert_eq!(cost.max.runtime, 0); + } + + #[test] + fn test_simple_addition() { + let source = "(+ u1 u2)"; + let cost = static_cost(source).unwrap(); + + // Min: linear(2, 11, 125) = 11*2 + 125 = 147 + assert_eq!(cost.min.runtime, 147); + assert_eq!(cost.max.runtime, 147); + } + + #[test] + fn test_arithmetic() { + let source = "(- u4 (+ u1 u2))"; + let cost = static_cost(source).unwrap(); + assert_eq!(cost.min.runtime, 147 + 147); + assert_eq!(cost.max.runtime, 147 + 147); + } + + #[test] + fn test_nested_operations() { + let source = "(* (+ u1 u2) (- u3 u4))"; + let cost = static_cost(source).unwrap(); + // multiplication: 13*2 + 125 = 151 + assert_eq!(cost.min.runtime, 151 + 147 + 147); + assert_eq!(cost.max.runtime, 151 + 147 + 147); + } + + #[test] + fn test_string_concat_min_max() { + let source = "(concat \"hello\" \"world\")"; + let cost = static_cost(source).unwrap(); + + // For concat with 2 arguments: + // linear(2, 37, 220) = 37*2 + 220 = 294 + assert_eq!(cost.min.runtime, 294); + assert_eq!(cost.max.runtime, 294); + } + + #[test] + fn test_string_len_min_max() { + let source = "(len \"hello\")"; + let cost = static_cost(source).unwrap(); + + // cost: 429 (constant) - len doesn't depend on string size + assert_eq!(cost.min.runtime, 429); + assert_eq!(cost.max.runtime, 429); + } +} diff --git a/clarity/src/vm/costs/mod.rs b/clarity/src/vm/costs/mod.rs index 3eee14e96e..edce399481 100644 --- a/clarity/src/vm/costs/mod.rs +++ b/clarity/src/vm/costs/mod.rs @@ -42,6 +42,8 @@ use crate::vm::types::{ FunctionType, PrincipalData, QualifiedContractIdentifier, TupleData, TypeSignature, }; use crate::vm::{CallStack, ClarityName, Environment, LocalContext, SymbolicExpression, Value}; + +pub mod analysis; pub mod constants; pub mod cost_functions; #[allow(unused_variables)] From c6a3c33b29d42ec68994fbf7996f1a5f865b2c73 Mon Sep 17 00:00:00 2001 From: "brady.ouren" Date: Mon, 15 Sep 2025 11:42:33 -0700 Subject: [PATCH 2/6] add ExprTree building with branching flag --- clarity/src/vm/costs/analysis.rs | 244 ++++++++++++++++++++++++++++++- 1 file changed, 240 insertions(+), 4 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 898eaee7d3..0664399b89 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -1,5 +1,9 @@ // Static cost analysis for Clarity expressions +use clarity_serialization::representations::ContractName; +use clarity_serialization::types::TraitIdentifier; +use clarity_serialization::Value; + use crate::vm::ast::parser::v2::parse; use crate::vm::costs::cost_functions::CostValues; use crate::vm::costs::costs_3::Costs3; @@ -13,6 +17,7 @@ use crate::vm::representations::{ClarityName, PreSymbolicExpression, PreSymbolic // contract-call? - how to handle? // type-checking // lookups +// unwrap evaluates both branches (https://github.com/clarity-lang/reference/issues/59) #[derive(Debug, Clone)] pub struct StaticCostNode { @@ -68,20 +73,175 @@ pub fn static_cost(source: &str) -> Result { return Err("No expressions found".to_string()); } + // TODO what happens if multiple expressions are selected? let pre_expr = &pre_expressions[0]; + let _expr_tree = build_expr_tree(pre_expr)?; let cost_tree = build_cost_tree(pre_expr)?; Ok(calculate_total_cost(&cost_tree)) } +#[derive(Debug, Clone)] +pub enum ExprNode { + If, + Match, + Unwrap, + Ok, + Err, + GT, + LT, + GE, + LE, + EQ, + Add, + Sub, + Mul, + Div, + // Other functions + Function(ClarityName), + // Values + AtomValue(Value), + Atom(ClarityName), + // Placeholder for sugared identifiers + SugaredContractIdentifier(ContractName), + SugaredFieldIdentifier(ContractName, ClarityName), + FieldIdentifier(TraitIdentifier), + TraitReference(ClarityName), +} + +#[derive(Debug, Clone)] +pub struct ExprTree { + pub expr: ExprNode, + pub children: Vec, + pub branching: bool, +} + +/// Build an expression tree, skipping comments and placeholders +fn build_expr_tree(expr: &PreSymbolicExpression) -> Result { + match &expr.pre_expr { + PreSymbolicExpressionType::List(list) => build_listlike_expr_tree(list, "list"), + PreSymbolicExpressionType::AtomValue(value) => Ok(ExprTree { + expr: ExprNode::AtomValue(value.clone()), + children: vec![], + branching: false, + }), + PreSymbolicExpressionType::Atom(name) => Ok(ExprTree { + expr: ExprNode::Atom(name.clone()), + children: vec![], + branching: false, + }), + PreSymbolicExpressionType::Tuple(tuple) => build_listlike_expr_tree(tuple, "tuple"), + PreSymbolicExpressionType::SugaredContractIdentifier(contract_name) => { + // TODO: Look up the source for this contract identifier + Ok(ExprTree { + expr: ExprNode::SugaredContractIdentifier(contract_name.clone()), + children: vec![], + branching: false, + }) + } + PreSymbolicExpressionType::SugaredFieldIdentifier(contract_name, field_name) => { + // TODO: Look up the source for this field identifier + Ok(ExprTree { + expr: ExprNode::SugaredFieldIdentifier(contract_name.clone(), field_name.clone()), + children: vec![], + branching: false, + }) + } + PreSymbolicExpressionType::FieldIdentifier(field_name) => Ok(ExprTree { + expr: ExprNode::FieldIdentifier(field_name.clone()), + children: vec![], + branching: false, + }), + PreSymbolicExpressionType::TraitReference(trait_name) => { + // TODO: Look up the source for this trait reference + Ok(ExprTree { + expr: ExprNode::TraitReference(trait_name.clone()), + children: vec![], + branching: false, + }) + } + // Comments and placeholders should be filtered out during traversal + PreSymbolicExpressionType::Comment(_comment) => { + Err("hit an irrelevant comment expr type".to_string()) + } + PreSymbolicExpressionType::Placeholder(_placeholder) => { + Err("hit an irrelevant placeholder expr type".to_string()) + } + } +} + +/// Helper function to build expression trees for both lists and tuples +fn build_listlike_expr_tree( + items: &[PreSymbolicExpression], + container_type: &str, +) -> Result { + let function_name = match &items[0].pre_expr { + PreSymbolicExpressionType::Atom(name) => name, + _ => { + return Err(format!( + "First element of {} must be an atom (function name)", + container_type + )); + } + }; + + let args = &items[1..]; + let mut children = Vec::new(); + + // Build children for all arguments, skipping comments and placeholders + for arg in args { + match &arg.pre_expr { + PreSymbolicExpressionType::Comment(_) | PreSymbolicExpressionType::Placeholder(_) => { + // Skip comments and placeholders + continue; + } + _ => { + children.push(build_expr_tree(arg)?); + } + } + } + + // Determine if this is a branching function + let branching = is_branching_function(function_name); + + // Create the appropriate ExprNode + let expr_node = match function_name.as_str() { + "if" => ExprNode::If, + "match" => ExprNode::Match, + "unwrap!" | "unwrap-err!" | "unwrap-panic" | "unwrap-err-panic" => ExprNode::Unwrap, + "ok" => ExprNode::Ok, + "err" => ExprNode::Err, + ">" => ExprNode::GT, + "<" => ExprNode::LT, + ">=" => ExprNode::GE, + "<=" => ExprNode::LE, + "=" | "is-eq" | "eq" => ExprNode::EQ, + "+" | "add" => ExprNode::Add, + "-" | "sub" => ExprNode::Sub, + "*" | "mul" => ExprNode::Mul, + "/" | "div" => ExprNode::Div, + _ => ExprNode::Function(function_name.clone()), + }; + + Ok(ExprTree { + expr: expr_node, + children, + branching, + }) +} + +/// Determine if a function name represents a branching function +fn is_branching_function(function_name: &ClarityName) -> bool { + match function_name.as_str() { + "if" | "match" | "unwrap!" | "unwrap-err!" => true, + _ => false, + } +} + // TODO: Needs alternative traversals to get min/max fn build_cost_tree(expr: &PreSymbolicExpression) -> Result { match &expr.pre_expr { PreSymbolicExpressionType::List(list) => { - if list.is_empty() { - return Err("Empty list expression".to_string()); - } - let function_name = match &list[0].pre_expr { PreSymbolicExpressionType::Atom(name) => name, _ => { @@ -369,4 +529,80 @@ mod tests { assert_eq!(cost.min.runtime, 429); assert_eq!(cost.max.runtime, 429); } + + #[test] + fn test_build_expr_tree_if_expression() { + let source = "(if (> 3 0) (ok true) (ok false))"; + let pre_expressions = parse(source).unwrap(); + let pre_expr = &pre_expressions[0]; + let expr_tree = build_expr_tree(pre_expr).unwrap(); + + // Root should be an If node with branching=true + assert!(matches!(expr_tree.expr, ExprNode::If)); + assert!(expr_tree.branching); + assert_eq!(expr_tree.children.len(), 3); // condition, then, else + + // First child should be GT comparison + let gt_node = &expr_tree.children[0]; + assert!(matches!(gt_node.expr, ExprNode::GT)); + assert!(!gt_node.branching); + assert_eq!(gt_node.children.len(), 2); // 3 and 0 + + // GT children should be AtomValue(3) and AtomValue(0) + let left_val = >_node.children[0]; + let right_val = >_node.children[1]; + assert!(matches!(left_val.expr, ExprNode::AtomValue(_))); + assert!(matches!(right_val.expr, ExprNode::AtomValue(_))); + + // Second child should be Ok(true) + let ok_true_node = &expr_tree.children[1]; + assert!(matches!(ok_true_node.expr, ExprNode::Ok)); + assert!(!ok_true_node.branching); + assert_eq!(ok_true_node.children.len(), 1); + + // Third child should be Ok(false) + let ok_false_node = &expr_tree.children[2]; + assert!(matches!(ok_false_node.expr, ExprNode::Ok)); + assert!(!ok_false_node.branching); + assert_eq!(ok_false_node.children.len(), 1); + } + + #[test] + fn test_build_expr_tree_arithmetic() { + let source = "(+ (* 2 3) (- 5 1))"; + let pre_expressions = parse(source).unwrap(); + let pre_expr = &pre_expressions[0]; + let expr_tree = build_expr_tree(pre_expr).unwrap(); + + // Root should be Add node + assert!(matches!(expr_tree.expr, ExprNode::Add)); + assert!(!expr_tree.branching); + assert_eq!(expr_tree.children.len(), 2); + + // First child should be Mul + let mul_node = &expr_tree.children[0]; + assert!(matches!(mul_node.expr, ExprNode::Mul)); + assert_eq!(mul_node.children.len(), 2); + + // Second child should be Sub + let sub_node = &expr_tree.children[1]; + assert!(matches!(sub_node.expr, ExprNode::Sub)); + assert_eq!(sub_node.children.len(), 2); + } + + #[test] + fn test_build_expr_tree_with_comments() { + let source = "(+ 1 ;; this is a comment\n 2)"; + let pre_expressions = parse(source).unwrap(); + let pre_expr = &pre_expressions[0]; + let expr_tree = build_expr_tree(pre_expr).unwrap(); + + assert!(matches!(expr_tree.expr, ExprNode::Add)); + assert!(!expr_tree.branching); + assert_eq!(expr_tree.children.len(), 2); + + for child in &expr_tree.children { + assert!(matches!(child.expr, ExprNode::AtomValue(_))); + } + } } From 380b15a72216d4e9e8726598e11ddb03c196615a Mon Sep 17 00:00:00 2001 From: "brady.ouren" Date: Mon, 15 Sep 2025 13:21:37 -0700 Subject: [PATCH 3/6] add branching sums --- clarity/src/vm/costs/analysis.rs | 419 ++++++++++++++++++++++++++----- 1 file changed, 356 insertions(+), 63 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 0664399b89..523bf526c0 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -1,11 +1,11 @@ // Static cost analysis for Clarity expressions use clarity_serialization::representations::ContractName; -use clarity_serialization::types::TraitIdentifier; +use clarity_serialization::types::{CharType, SequenceData, TraitIdentifier}; use clarity_serialization::Value; use crate::vm::ast::parser::v2::parse; -use crate::vm::costs::cost_functions::CostValues; +use crate::vm::costs::cost_functions::{linear, CostValues}; use crate::vm::costs::costs_3::Costs3; use crate::vm::costs::ExecutionCost; use crate::vm::errors::InterpreterResult; @@ -13,8 +13,8 @@ use crate::vm::representations::{ClarityName, PreSymbolicExpression, PreSymbolic // TODO: // variable traverse for -// - if, is-*, match, etc -// contract-call? - how to handle? +// - if, unwrap-*, match, etc +// contract-call? - get source from database // type-checking // lookups // unwrap evaluates both branches (https://github.com/clarity-lang/reference/issues/59) @@ -61,6 +61,75 @@ impl StaticCost { }; } +/// A type to track summed execution costs for different paths +/// This allows us to compute min and max costs across different execution paths +#[derive(Debug, Clone)] +pub struct SummingExecutionCost { + pub costs: Vec, +} + +impl SummingExecutionCost { + pub fn new() -> Self { + Self { costs: Vec::new() } + } + + pub fn from_single(cost: ExecutionCost) -> Self { + Self { costs: vec![cost] } + } + + pub fn add_cost(&mut self, cost: ExecutionCost) { + self.costs.push(cost); + } + + pub fn add_summing(&mut self, other: &SummingExecutionCost) { + self.costs.extend(other.costs.clone()); + } + + /// Get the minimum cost across all paths + pub fn min(&self) -> ExecutionCost { + if self.costs.is_empty() { + ExecutionCost::ZERO + } else { + self.costs + .iter() + .fold(self.costs[0].clone(), |acc, cost| ExecutionCost { + runtime: acc.runtime.min(cost.runtime), + write_length: acc.write_length.min(cost.write_length), + write_count: acc.write_count.min(cost.write_count), + read_length: acc.read_length.min(cost.read_length), + read_count: acc.read_count.min(cost.read_count), + }) + } + } + + /// Get the maximum cost across all paths + pub fn max(&self) -> ExecutionCost { + if self.costs.is_empty() { + ExecutionCost::ZERO + } else { + self.costs + .iter() + .fold(self.costs[0].clone(), |acc, cost| ExecutionCost { + runtime: acc.runtime.max(cost.runtime), + write_length: acc.write_length.max(cost.write_length), + write_count: acc.write_count.max(cost.write_count), + read_length: acc.read_length.max(cost.read_length), + read_count: acc.read_count.max(cost.read_count), + }) + } + } + + /// Combine costs by adding them (for non-branching operations) + pub fn add_all(&self) -> ExecutionCost { + self.costs + .iter() + .fold(ExecutionCost::ZERO, |mut acc, cost| { + let _ = acc.add(cost); + acc + }) + } +} + /// Parse Clarity source code and calculate its static execution cost /// /// This function takes a Clarity expression as a string, parses it into symbolic @@ -75,10 +144,12 @@ pub fn static_cost(source: &str) -> Result { // TODO what happens if multiple expressions are selected? let pre_expr = &pre_expressions[0]; - let _expr_tree = build_expr_tree(pre_expr)?; - let cost_tree = build_cost_tree(pre_expr)?; + let expr_tree = build_expr_tree(pre_expr)?; + let cost_tree = build_cost_tree(&expr_tree)?; - Ok(calculate_total_cost(&cost_tree)) + // Use branching-aware cost calculation + let summing_cost = calculate_total_cost_with_branching(&expr_tree, &cost_tree); + Ok(summing_cost.into()) } #[derive(Debug, Clone)] @@ -238,63 +309,102 @@ fn is_branching_function(function_name: &ClarityName) -> bool { } } -// TODO: Needs alternative traversals to get min/max -fn build_cost_tree(expr: &PreSymbolicExpression) -> Result { - match &expr.pre_expr { - PreSymbolicExpressionType::List(list) => { - let function_name = match &list[0].pre_expr { - PreSymbolicExpressionType::Atom(name) => name, - _ => { - return Err("First element of list must be an atom (function name)".to_string()) - } - }; - - // TODO this is wrong - let args = &list[1..]; - let mut children = Vec::new(); - - for arg in args { - children.push(build_cost_tree(arg)?); +/// Build a cost tree from an expression tree, using branching logic for min/max calculation +fn build_cost_tree(expr_tree: &ExprTree) -> Result { + let function_name = match &expr_tree.expr { + ExprNode::If => "if", + ExprNode::Match => "match", + ExprNode::Unwrap => "unwrap!", + ExprNode::Ok => "ok", + ExprNode::Err => "err", + ExprNode::GT => ">", + ExprNode::LT => "<", + ExprNode::GE => ">=", + ExprNode::LE => "<=", + ExprNode::EQ => "=", + ExprNode::Add => "+", + ExprNode::Sub => "-", + ExprNode::Mul => "*", + ExprNode::Div => "/", + ExprNode::Function(name) => name.as_str(), + ExprNode::AtomValue(value) => { + // String literals have cost based on length only when they're standalone (not function arguments) + if let Value::Sequence(SequenceData::String(CharType::UTF8(data))) = value { + let length = data.data.len() as u64; + let cost = linear(length, 36, 3); + let execution_cost = ExecutionCost::runtime(cost); + return Ok(StaticCostNode::leaf( + vec![], + StaticCost { + min: execution_cost.clone(), + max: execution_cost, + }, + )); + } else if let Value::Sequence(SequenceData::String(CharType::ASCII(data))) = value { + let length = data.data.len() as u64; + let cost = linear(length, 36, 3); + let execution_cost = ExecutionCost::runtime(cost); + return Ok(StaticCostNode::leaf( + vec![], + StaticCost { + min: execution_cost.clone(), + max: execution_cost, + }, + )); } - - let cost = calculate_function_cost(function_name, args.len() as u64)?; - - Ok(StaticCostNode::new(list.clone(), cost, children)) + // Other atom values have zero cost + return Ok(StaticCostNode::leaf(vec![], StaticCost::ZERO)); } - PreSymbolicExpressionType::AtomValue(_value) => { - Ok(StaticCostNode::leaf(vec![expr.clone()], StaticCost::ZERO)) + ExprNode::Atom(_) + | ExprNode::SugaredContractIdentifier(_) + | ExprNode::SugaredFieldIdentifier(_, _) + | ExprNode::FieldIdentifier(_) + | ExprNode::TraitReference(_) => { + // Leaf nodes have zero cost + return Ok(StaticCostNode::leaf(vec![], StaticCost::ZERO)); } - PreSymbolicExpressionType::Atom(_name) => { - Ok(StaticCostNode::leaf(vec![expr.clone()], StaticCost::ZERO)) - } - PreSymbolicExpressionType::Tuple(tuple) => { - let function_name = match &tuple[0].pre_expr { - PreSymbolicExpressionType::Atom(name) => name, - _ => { - return Err("First element of tuple must be an atom (function name)".to_string()) - } - }; - - let args = &tuple[1..]; - let mut children = Vec::new(); - - for arg in args { - children.push(build_cost_tree(arg)?); + }; + + let mut children = Vec::new(); + for child_expr in &expr_tree.children { + // For certain functions like concat, string arguments should have zero cost + // since the function cost includes their processing + if function_name == "concat" { + if let ExprNode::AtomValue(Value::Sequence(SequenceData::String(_))) = &child_expr.expr + { + // String arguments to concat have zero cost + children.push(StaticCostNode::leaf(vec![], StaticCost::ZERO)); + continue; } + } + children.push(build_cost_tree(child_expr)?); + } - let cost = calculate_function_cost(function_name, args.len() as u64)?; + let cost = calculate_function_cost_from_name(function_name, expr_tree.children.len() as u64)?; - Ok(StaticCostNode::new(tuple.clone(), cost, children)) - } - _ => Err("Unsupported expression type for cost analysis".to_string()), + // Create a representative PreSymbolicExpression for the node + let function_expr = PreSymbolicExpression { + pre_expr: PreSymbolicExpressionType::Atom(ClarityName::from(function_name)), + id: 0, // We don't need accurate IDs for cost analysis + }; + let mut expr_list = vec![function_expr]; + + // Add placeholder expressions for children (we don't need the actual child expressions) + for _ in &expr_tree.children { + expr_list.push(PreSymbolicExpression { + pre_expr: PreSymbolicExpressionType::Atom(ClarityName::from("placeholder")), + id: 0, + }); } + + Ok(StaticCostNode::new(expr_list, cost, children)) } -fn calculate_function_cost( - function_name: &ClarityName, +fn calculate_function_cost_from_name( + function_name: &str, arg_count: u64, ) -> Result { - let cost_function = match get_cost_function_for_name(function_name) { + let cost_function = match get_cost_function_for_name_str(function_name) { Some(cost_fn) => cost_fn, None => { // TODO: zero cost for now @@ -309,6 +419,123 @@ fn calculate_function_cost( }) } +fn calculate_function_cost( + function_name: &ClarityName, + arg_count: u64, +) -> Result { + calculate_function_cost_from_name(function_name.as_str(), arg_count) +} + +/// Convert a function name string to its corresponding cost function +fn get_cost_function_for_name_str( + name: &str, +) -> Option InterpreterResult> { + // Map function names to their cost functions using the existing enum structure + match name { + "+" | "add" => Some(Costs3::cost_add), + "-" | "sub" => Some(Costs3::cost_sub), + "*" | "mul" => Some(Costs3::cost_mul), + "/" | "div" => Some(Costs3::cost_div), + "mod" => Some(Costs3::cost_mod), + "pow" => Some(Costs3::cost_pow), + "sqrti" => Some(Costs3::cost_sqrti), + "log2" => Some(Costs3::cost_log2), + "to-int" | "to-uint" | "int-cast" => Some(Costs3::cost_int_cast), + "is-eq" | "=" | "eq" => Some(Costs3::cost_eq), + ">=" | "geq" => Some(Costs3::cost_geq), + "<=" | "leq" => Some(Costs3::cost_leq), + ">" | "ge" => Some(Costs3::cost_ge), + "<" | "le" => Some(Costs3::cost_le), + "xor" => Some(Costs3::cost_xor), + "not" => Some(Costs3::cost_not), + "and" => Some(Costs3::cost_and), + "or" => Some(Costs3::cost_or), + "concat" => Some(Costs3::cost_concat), + "len" => Some(Costs3::cost_len), + "as-max-len?" => Some(Costs3::cost_as_max_len), + "list" => Some(Costs3::cost_list_cons), + "element-at" | "element-at?" => Some(Costs3::cost_element_at), + "index-of" | "index-of?" => Some(Costs3::cost_index_of), + "fold" => Some(Costs3::cost_fold), + "map" => Some(Costs3::cost_map), + "filter" => Some(Costs3::cost_filter), + "append" => Some(Costs3::cost_append), + "tuple-get" => Some(Costs3::cost_tuple_get), + "tuple-merge" => Some(Costs3::cost_tuple_merge), + "tuple" => Some(Costs3::cost_tuple_cons), + "some" => Some(Costs3::cost_some_cons), + "ok" => Some(Costs3::cost_ok_cons), + "err" => Some(Costs3::cost_err_cons), + "default-to" => Some(Costs3::cost_default_to), + "unwrap!" => Some(Costs3::cost_unwrap_ret), + "unwrap-err!" => Some(Costs3::cost_unwrap_err_or_ret), + "is-ok" => Some(Costs3::cost_is_okay), + "is-none" => Some(Costs3::cost_is_none), + "is-err" => Some(Costs3::cost_is_err), + "is-some" => Some(Costs3::cost_is_some), + "unwrap-panic" => Some(Costs3::cost_unwrap), + "unwrap-err-panic" => Some(Costs3::cost_unwrap_err), + "try!" => Some(Costs3::cost_try_ret), + "if" => Some(Costs3::cost_if), + "match" => Some(Costs3::cost_match), + "begin" => Some(Costs3::cost_begin), + "let" => Some(Costs3::cost_let), + "asserts!" => Some(Costs3::cost_asserts), + "hash160" => Some(Costs3::cost_hash160), + "sha256" => Some(Costs3::cost_sha256), + "sha512" => Some(Costs3::cost_sha512), + "sha512/256" => Some(Costs3::cost_sha512t256), + "keccak256" => Some(Costs3::cost_keccak256), + "secp256k1-recover?" => Some(Costs3::cost_secp256k1recover), + "secp256k1-verify" => Some(Costs3::cost_secp256k1verify), + "print" => Some(Costs3::cost_print), + "contract-call?" => Some(Costs3::cost_contract_call), + "contract-of" => Some(Costs3::cost_contract_of), + "principal-of?" => Some(Costs3::cost_principal_of), + "at-block" => Some(Costs3::cost_at_block), + "load-contract" => Some(Costs3::cost_load_contract), + "create-map" => Some(Costs3::cost_create_map), + "create-var" => Some(Costs3::cost_create_var), + "create-non-fungible-token" => Some(Costs3::cost_create_nft), + "create-fungible-token" => Some(Costs3::cost_create_ft), + "map-get?" => Some(Costs3::cost_fetch_entry), + "map-set!" => Some(Costs3::cost_set_entry), + "var-get" => Some(Costs3::cost_fetch_var), + "var-set!" => Some(Costs3::cost_set_var), + "contract-storage" => Some(Costs3::cost_contract_storage), + "get-block-info?" => Some(Costs3::cost_block_info), + "get-burn-block-info?" => Some(Costs3::cost_burn_block_info), + "stx-get-balance" => Some(Costs3::cost_stx_balance), + "stx-transfer?" => Some(Costs3::cost_stx_transfer), + "stx-transfer-memo?" => Some(Costs3::cost_stx_transfer_memo), + "stx-account" => Some(Costs3::cost_stx_account), + "ft-mint?" => Some(Costs3::cost_ft_mint), + "ft-transfer?" => Some(Costs3::cost_ft_transfer), + "ft-get-balance" => Some(Costs3::cost_ft_balance), + "ft-get-supply" => Some(Costs3::cost_ft_get_supply), + "ft-burn?" => Some(Costs3::cost_ft_burn), + "nft-mint?" => Some(Costs3::cost_nft_mint), + "nft-transfer?" => Some(Costs3::cost_nft_transfer), + "nft-get-owner?" => Some(Costs3::cost_nft_owner), + "nft-burn?" => Some(Costs3::cost_nft_burn), + "buff-to-int-le?" => Some(Costs3::cost_buff_to_int_le), + "buff-to-uint-le?" => Some(Costs3::cost_buff_to_uint_le), + "buff-to-int-be?" => Some(Costs3::cost_buff_to_int_be), + "buff-to-uint-be?" => Some(Costs3::cost_buff_to_uint_be), + "to-consensus-buff?" => Some(Costs3::cost_to_consensus_buff), + "from-consensus-buff?" => Some(Costs3::cost_from_consensus_buff), + "is-standard?" => Some(Costs3::cost_is_standard), + "principal-destruct" => Some(Costs3::cost_principal_destruct), + "principal-construct?" => Some(Costs3::cost_principal_construct), + "as-contract" => Some(Costs3::cost_as_contract), + "string-to-int?" => Some(Costs3::cost_string_to_int), + "string-to-uint?" => Some(Costs3::cost_string_to_uint), + "int-to-ascii" => Some(Costs3::cost_int_to_ascii), + "int-to-utf8?" => Some(Costs3::cost_int_to_utf8), + _ => None, // Unknown function name + } +} + /// Convert a function name to its corresponding cost function fn get_cost_function_for_name( name: &ClarityName, @@ -443,20 +670,71 @@ fn get_max_input_size_for_function_name(function_name: &ClarityName, arg_count: } fn calculate_total_cost(node: &StaticCostNode) -> StaticCost { - let mut min_total = node.cost.min.clone(); - let mut max_total = node.cost.max.clone(); + calculate_total_cost_with_summing(node).into() +} - // Add costs from all children - // TODO: this should traverse different paths to get min and max costs +/// Calculate total cost using SummingExecutionCost to handle branching properly +fn calculate_total_cost_with_summing(node: &StaticCostNode) -> SummingExecutionCost { + let mut summing_cost = SummingExecutionCost::from_single(node.cost.min.clone()); + + // For each child, calculate its cost and combine appropriately for child in &node.children { - let child_cost = calculate_total_cost(child); - let _ = min_total.add(&child_cost.min); - let _ = max_total.add(&child_cost.max); + let child_summing = calculate_total_cost_with_summing(child); + summing_cost.add_summing(&child_summing); + } + + summing_cost +} + +/// Calculate total cost using branching logic from ExprTree +fn calculate_total_cost_with_branching( + expr_tree: &ExprTree, + cost_node: &StaticCostNode, +) -> SummingExecutionCost { + let mut summing_cost = SummingExecutionCost::new(); + + if expr_tree.branching { + // For branching, we need to create separate execution paths + // The first child is the condition, the rest are the branches + if cost_node.children.len() >= 2 { + let condition_cost = calculate_total_cost_with_summing(&cost_node.children[0]); + let condition_total = condition_cost.add_all(); + + // Add the root cost + condition cost to each branch + let mut root_and_condition = cost_node.cost.min.clone(); + let _ = root_and_condition.add(&condition_total); + + // For each branch (children 1+), create a complete path + for child_cost_node in cost_node.children.iter().skip(1) { + let branch_cost = calculate_total_cost_with_summing(child_cost_node); + let branch_total = branch_cost.add_all(); + + let mut path_cost = root_and_condition.clone(); + let _ = path_cost.add(&branch_total); + + summing_cost.add_cost(path_cost); + } + } + } else { + // For non-branching, add all costs sequentially + let mut total_cost = cost_node.cost.min.clone(); + for child_cost_node in &cost_node.children { + let child_summing = calculate_total_cost_with_summing(child_cost_node); + let combined_cost = child_summing.add_all(); + let _ = total_cost.add(&combined_cost); + } + summing_cost.add_cost(total_cost); } - StaticCost { - min: min_total, - max: max_total, + summing_cost +} + +impl From for StaticCost { + fn from(summing: SummingExecutionCost) -> Self { + StaticCost { + min: summing.min(), + max: summing.max(), + } } } @@ -530,6 +808,21 @@ mod tests { assert_eq!(cost.max.runtime, 429); } + #[test] + fn test_branching() { + let source = "(if (> 3 0) (ok (concat \"hello\" \"world\")) (ok \"asdf\"))"; + let cost = static_cost(source).unwrap(); + // min: 147 raw string + // max: 294 (concat) + + // ok = 199 + // if = 168 + // ge = (linear(n, 7, 128))) + let base_cost = 168 + ((2 * 7) + 128) + 199; + assert_eq!(cost.min.runtime, base_cost + 147); + assert_eq!(cost.max.runtime, base_cost + 294); + } + #[test] fn test_build_expr_tree_if_expression() { let source = "(if (> 3 0) (ok true) (ok false))"; From 48ea467eab557638f04901dc7f59e3b45374878d Mon Sep 17 00:00:00 2001 From: "brady.ouren" Date: Tue, 16 Sep 2025 01:33:22 -0700 Subject: [PATCH 4/6] cleanup unused and fix len test --- clarity/src/vm/costs/analysis.rs | 178 ++++--------------------------- 1 file changed, 18 insertions(+), 160 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 523bf526c0..77964c2eb4 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -19,6 +19,16 @@ use crate::vm::representations::{ClarityName, PreSymbolicExpression, PreSymbolic // lookups // unwrap evaluates both branches (https://github.com/clarity-lang/reference/issues/59) +/// Calculate the cost for a string based on its length +fn string_cost(length: usize) -> StaticCost { + let cost = linear(length as u64, 36, 3); + let execution_cost = ExecutionCost::runtime(cost); + StaticCost { + min: execution_cost.clone(), + max: execution_cost, + } +} + #[derive(Debug, Clone)] pub struct StaticCostNode { pub function: Vec, @@ -329,28 +339,11 @@ fn build_cost_tree(expr_tree: &ExprTree) -> Result { ExprNode::Function(name) => name.as_str(), ExprNode::AtomValue(value) => { // String literals have cost based on length only when they're standalone (not function arguments) + // TODO: not sure if /utf8 and ascii are treated the same cost-wise.. if let Value::Sequence(SequenceData::String(CharType::UTF8(data))) = value { - let length = data.data.len() as u64; - let cost = linear(length, 36, 3); - let execution_cost = ExecutionCost::runtime(cost); - return Ok(StaticCostNode::leaf( - vec![], - StaticCost { - min: execution_cost.clone(), - max: execution_cost, - }, - )); + return Ok(StaticCostNode::leaf(vec![], string_cost(data.data.len()))); } else if let Value::Sequence(SequenceData::String(CharType::ASCII(data))) = value { - let length = data.data.len() as u64; - let cost = linear(length, 36, 3); - let execution_cost = ExecutionCost::runtime(cost); - return Ok(StaticCostNode::leaf( - vec![], - StaticCost { - min: execution_cost.clone(), - max: execution_cost, - }, - )); + return Ok(StaticCostNode::leaf(vec![], string_cost(data.data.len()))); } // Other atom values have zero cost return Ok(StaticCostNode::leaf(vec![], StaticCost::ZERO)); @@ -367,12 +360,12 @@ fn build_cost_tree(expr_tree: &ExprTree) -> Result { let mut children = Vec::new(); for child_expr in &expr_tree.children { - // For certain functions like concat, string arguments should have zero cost + // For certain functions like concat and len, string arguments should have zero cost // since the function cost includes their processing - if function_name == "concat" { + if function_name == "concat" || function_name == "len" { if let ExprNode::AtomValue(Value::Sequence(SequenceData::String(_))) = &child_expr.expr { - // String arguments to concat have zero cost + // String arguments to concat and len have zero cost children.push(StaticCostNode::leaf(vec![], StaticCost::ZERO)); continue; } @@ -404,7 +397,7 @@ fn calculate_function_cost_from_name( function_name: &str, arg_count: u64, ) -> Result { - let cost_function = match get_cost_function_for_name_str(function_name) { + let cost_function = match get_cost_function_for_name(function_name) { Some(cost_fn) => cost_fn, None => { // TODO: zero cost for now @@ -427,9 +420,7 @@ fn calculate_function_cost( } /// Convert a function name string to its corresponding cost function -fn get_cost_function_for_name_str( - name: &str, -) -> Option InterpreterResult> { +fn get_cost_function_for_name(name: &str) -> Option InterpreterResult> { // Map function names to their cost functions using the existing enum structure match name { "+" | "add" => Some(Costs3::cost_add), @@ -536,139 +527,6 @@ fn get_cost_function_for_name_str( } } -/// Convert a function name to its corresponding cost function -fn get_cost_function_for_name( - name: &ClarityName, -) -> Option InterpreterResult> { - let name_str = name.as_str(); - - // Map function names to their cost functions using the existing enum structure - match name_str { - "+" | "add" => Some(Costs3::cost_add), - "-" | "sub" => Some(Costs3::cost_sub), - "*" | "mul" => Some(Costs3::cost_mul), - "/" | "div" => Some(Costs3::cost_div), - "mod" => Some(Costs3::cost_mod), - "pow" => Some(Costs3::cost_pow), - "sqrti" => Some(Costs3::cost_sqrti), - "log2" => Some(Costs3::cost_log2), - "to-int" | "to-uint" | "int-cast" => Some(Costs3::cost_int_cast), - "is-eq" | "=" | "eq" => Some(Costs3::cost_eq), - ">=" | "geq" => Some(Costs3::cost_geq), - "<=" | "leq" => Some(Costs3::cost_leq), - ">" | "ge" => Some(Costs3::cost_ge), - "<" | "le" => Some(Costs3::cost_le), - "xor" => Some(Costs3::cost_xor), - "not" => Some(Costs3::cost_not), - "and" => Some(Costs3::cost_and), - "or" => Some(Costs3::cost_or), - "concat" => Some(Costs3::cost_concat), - "len" => Some(Costs3::cost_len), - "as-max-len?" => Some(Costs3::cost_as_max_len), - "list" => Some(Costs3::cost_list_cons), - "element-at" | "element-at?" => Some(Costs3::cost_element_at), - "index-of" | "index-of?" => Some(Costs3::cost_index_of), - "fold" => Some(Costs3::cost_fold), - "map" => Some(Costs3::cost_map), - "filter" => Some(Costs3::cost_filter), - "append" => Some(Costs3::cost_append), - "tuple-get" => Some(Costs3::cost_tuple_get), - "tuple-merge" => Some(Costs3::cost_tuple_merge), - "tuple" => Some(Costs3::cost_tuple_cons), - "some" => Some(Costs3::cost_some_cons), - "ok" => Some(Costs3::cost_ok_cons), - "err" => Some(Costs3::cost_err_cons), - "default-to" => Some(Costs3::cost_default_to), - "unwrap!" => Some(Costs3::cost_unwrap_ret), - "unwrap-err!" => Some(Costs3::cost_unwrap_err_or_ret), - "is-ok" => Some(Costs3::cost_is_okay), - "is-none" => Some(Costs3::cost_is_none), - "is-err" => Some(Costs3::cost_is_err), - "is-some" => Some(Costs3::cost_is_some), - "unwrap-panic" => Some(Costs3::cost_unwrap), - "unwrap-err-panic" => Some(Costs3::cost_unwrap_err), - "try!" => Some(Costs3::cost_try_ret), - "if" => Some(Costs3::cost_if), - "match" => Some(Costs3::cost_match), - "begin" => Some(Costs3::cost_begin), - "let" => Some(Costs3::cost_let), - "asserts!" => Some(Costs3::cost_asserts), - "hash160" => Some(Costs3::cost_hash160), - "sha256" => Some(Costs3::cost_sha256), - "sha512" => Some(Costs3::cost_sha512), - "sha512/256" => Some(Costs3::cost_sha512t256), - "keccak256" => Some(Costs3::cost_keccak256), - "secp256k1-recover?" => Some(Costs3::cost_secp256k1recover), - "secp256k1-verify" => Some(Costs3::cost_secp256k1verify), - "print" => Some(Costs3::cost_print), - "contract-call?" => Some(Costs3::cost_contract_call), - "contract-of" => Some(Costs3::cost_contract_of), - "principal-of?" => Some(Costs3::cost_principal_of), - "at-block" => Some(Costs3::cost_at_block), - "load-contract" => Some(Costs3::cost_load_contract), - "create-map" => Some(Costs3::cost_create_map), - "create-var" => Some(Costs3::cost_create_var), - "create-non-fungible-token" => Some(Costs3::cost_create_nft), - "create-fungible-token" => Some(Costs3::cost_create_ft), - "map-get?" => Some(Costs3::cost_fetch_entry), - "map-set!" => Some(Costs3::cost_set_entry), - "var-get" => Some(Costs3::cost_fetch_var), - "var-set!" => Some(Costs3::cost_set_var), - "contract-storage" => Some(Costs3::cost_contract_storage), - "get-block-info?" => Some(Costs3::cost_block_info), - "get-burn-block-info?" => Some(Costs3::cost_burn_block_info), - "stx-get-balance" => Some(Costs3::cost_stx_balance), - "stx-transfer?" => Some(Costs3::cost_stx_transfer), - "stx-transfer-memo?" => Some(Costs3::cost_stx_transfer_memo), - "stx-account" => Some(Costs3::cost_stx_account), - "ft-mint?" => Some(Costs3::cost_ft_mint), - "ft-transfer?" => Some(Costs3::cost_ft_transfer), - "ft-get-balance" => Some(Costs3::cost_ft_balance), - "ft-get-supply" => Some(Costs3::cost_ft_get_supply), - "ft-burn?" => Some(Costs3::cost_ft_burn), - "nft-mint?" => Some(Costs3::cost_nft_mint), - "nft-transfer?" => Some(Costs3::cost_nft_transfer), - "nft-get-owner?" => Some(Costs3::cost_nft_owner), - "nft-burn?" => Some(Costs3::cost_nft_burn), - "buff-to-int-le?" => Some(Costs3::cost_buff_to_int_le), - "buff-to-uint-le?" => Some(Costs3::cost_buff_to_uint_le), - "buff-to-int-be?" => Some(Costs3::cost_buff_to_int_be), - "buff-to-uint-be?" => Some(Costs3::cost_buff_to_uint_be), - "to-consensus-buff?" => Some(Costs3::cost_to_consensus_buff), - "from-consensus-buff?" => Some(Costs3::cost_from_consensus_buff), - "is-standard?" => Some(Costs3::cost_is_standard), - "principal-destruct" => Some(Costs3::cost_principal_destruct), - "principal-construct?" => Some(Costs3::cost_principal_construct), - "as-contract" => Some(Costs3::cost_as_contract), - "string-to-int?" => Some(Costs3::cost_string_to_int), - "string-to-uint?" => Some(Costs3::cost_string_to_uint), - "int-to-ascii" => Some(Costs3::cost_int_to_ascii), - "int-to-utf8?" => Some(Costs3::cost_int_to_utf8), - _ => None, // Unknown function name - } -} - -fn get_max_input_size_for_function_name(function_name: &ClarityName, arg_count: u64) -> u64 { - let name_str = function_name.as_str(); - - match name_str { - "concat" => { - // For string concatenation, max size is the sum of max string lengths - // Each string can be up to MAX_VALUE_SIZE (1MB), so for n strings it's n * MAX_VALUE_SIZE - arg_count * 1024 * 1024 - } - "len" => { - // For length, maximum string length - 1024 * 1024 // MAX_VALUE_SIZE - } - _ => { - // Default case - use a fixed max size to match original behavior - // The original code used 2000 as the max input size for arithmetic operations - 2000 - } - } -} - fn calculate_total_cost(node: &StaticCostNode) -> StaticCost { calculate_total_cost_with_summing(node).into() } From 6d90b99d7fbfb53df9fb7884e33642662789825e Mon Sep 17 00:00:00 2001 From: "brady.ouren" Date: Mon, 22 Sep 2025 15:59:52 -0700 Subject: [PATCH 5/6] add UserArgument tracking --- clarity/src/vm/costs/analysis.rs | 191 ++++++++++++++++++++++++++----- 1 file changed, 162 insertions(+), 29 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 77964c2eb4..580140d355 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -188,6 +188,8 @@ pub enum ExprNode { SugaredFieldIdentifier(ContractName, ClarityName), FieldIdentifier(TraitIdentifier), TraitReference(ClarityName), + // User function arguments + UserArgument(ClarityName, ClarityName), // (argument_name, argument_type) } #[derive(Debug, Clone)] @@ -200,7 +202,18 @@ pub struct ExprTree { /// Build an expression tree, skipping comments and placeholders fn build_expr_tree(expr: &PreSymbolicExpression) -> Result { match &expr.pre_expr { - PreSymbolicExpressionType::List(list) => build_listlike_expr_tree(list, "list"), + PreSymbolicExpressionType::List(list) => { + // Check if this is a function definition + if let Some(function_name) = list.first().and_then(|first| first.match_atom()) { + if function_name.as_str() == "define-public" + || function_name.as_str() == "define-private" + || function_name.as_str() == "define-read-only" + { + return build_function_definition_expr_tree(list); + } + } + build_listlike_expr_tree(list, "list") + } PreSymbolicExpressionType::AtomValue(value) => Ok(ExprTree { expr: ExprNode::AtomValue(value.clone()), children: vec![], @@ -251,6 +264,79 @@ fn build_expr_tree(expr: &PreSymbolicExpression) -> Result { } } +/// Build an expression tree for function definitions like (define-public (foo (a u64)) (ok a)) +fn build_function_definition_expr_tree(list: &[PreSymbolicExpression]) -> Result { + if list.len() < 3 { + return Err( + "Function definition must have at least 3 elements: define type, signature, and body" + .to_string(), + ); + } + + let define_type = list[0] + .match_atom() + .ok_or("First element must be define type")?; + let signature = list[1] + .match_list() + .ok_or("Second element must be function signature")?; + let body = &list[2]; + + // Parse the function signature: (foo (a u64)) + if signature.is_empty() { + return Err("Function signature cannot be empty".to_string()); + } + + let _function_name = signature[0] + .match_atom() + .ok_or("Function name must be an atom")?; + let mut children = Vec::new(); + + // Process function arguments: (a u64) + for arg_expr in signature.iter().skip(1) { + if let Some(arg_list) = arg_expr.match_list() { + if arg_list.len() == 2 { + let arg_name = arg_list[0] + .match_atom() + .ok_or("Argument name must be an atom")?; + + // Handle both atom types and atom values for the type + let arg_type = match &arg_list[1].pre_expr { + PreSymbolicExpressionType::Atom(type_name) => type_name.clone(), + PreSymbolicExpressionType::AtomValue(value) => { + // Convert the value to a string representation + ClarityName::from(value.to_string().as_str()) + } + _ => return Err("Argument type must be an atom or atom value".to_string()), + }; + + // Create UserArgument node + children.push(ExprTree { + expr: ExprNode::UserArgument(arg_name.clone(), arg_type), + children: vec![], + branching: false, + }); + } else { + return Err( + "Function argument must have exactly 2 elements: name and type".to_string(), + ); + } + } else { + return Err("Function argument must be a list".to_string()); + } + } + + // Process the function body + let body_tree = build_expr_tree(body)?; + children.push(body_tree); + + // Create the function definition node + Ok(ExprTree { + expr: ExprNode::Function(define_type.clone()), + children, + branching: false, + }) +} + /// Helper function to build expression trees for both lists and tuples fn build_listlike_expr_tree( items: &[PreSymbolicExpression], @@ -314,7 +400,10 @@ fn build_listlike_expr_tree( /// Determine if a function name represents a branching function fn is_branching_function(function_name: &ClarityName) -> bool { match function_name.as_str() { - "if" | "match" | "unwrap!" | "unwrap-err!" => true, + "if" | "match" => true, + "unwrap!" | "unwrap-err!" => false, // XXX: currently unwrap and + // unwrap-err traverse both branches regardless of result, so until this is + // fixed in clarity we'll set this to false _ => false, } } @@ -352,7 +441,8 @@ fn build_cost_tree(expr_tree: &ExprTree) -> Result { | ExprNode::SugaredContractIdentifier(_) | ExprNode::SugaredFieldIdentifier(_, _) | ExprNode::FieldIdentifier(_) - | ExprNode::TraitReference(_) => { + | ExprNode::TraitReference(_) + | ExprNode::UserArgument(_, _) => { // Leaf nodes have zero cost return Ok(StaticCostNode::leaf(vec![], StaticCost::ZERO)); } @@ -412,13 +502,6 @@ fn calculate_function_cost_from_name( }) } -fn calculate_function_cost( - function_name: &ClarityName, - arg_count: u64, -) -> Result { - calculate_function_cost_from_name(function_name.as_str(), arg_count) -} - /// Convert a function name string to its corresponding cost function fn get_cost_function_for_name(name: &str) -> Option InterpreterResult> { // Map function names to their cost functions using the existing enum structure @@ -552,25 +635,40 @@ fn calculate_total_cost_with_branching( let mut summing_cost = SummingExecutionCost::new(); if expr_tree.branching { - // For branching, we need to create separate execution paths - // The first child is the condition, the rest are the branches - if cost_node.children.len() >= 2 { - let condition_cost = calculate_total_cost_with_summing(&cost_node.children[0]); - let condition_total = condition_cost.add_all(); - - // Add the root cost + condition cost to each branch - let mut root_and_condition = cost_node.cost.min.clone(); - let _ = root_and_condition.add(&condition_total); - - // For each branch (children 1+), create a complete path - for child_cost_node in cost_node.children.iter().skip(1) { - let branch_cost = calculate_total_cost_with_summing(child_cost_node); - let branch_total = branch_cost.add_all(); - - let mut path_cost = root_and_condition.clone(); - let _ = path_cost.add(&branch_total); - - summing_cost.add_cost(path_cost); + // Handle different types of branching functions + match &expr_tree.expr { + ExprNode::If | ExprNode::Match => { + // For if and match, we need to create separate execution paths + // The first child is the condition, the rest are the branches + if cost_node.children.len() >= 2 { + let condition_cost = calculate_total_cost_with_summing(&cost_node.children[0]); + let condition_total = condition_cost.add_all(); + + // Add the root cost + condition cost to each branch + let mut root_and_condition = cost_node.cost.min.clone(); + let _ = root_and_condition.add(&condition_total); + + // For each branch (children 1+), create a complete path + for child_cost_node in cost_node.children.iter().skip(1) { + let branch_cost = calculate_total_cost_with_summing(child_cost_node); + let branch_total = branch_cost.add_all(); + + let mut path_cost = root_and_condition.clone(); + let _ = path_cost.add(&branch_total); + + summing_cost.add_cost(path_cost); + } + } + } + _ => { + // For other branching functions, fall back to sequential processing + let mut total_cost = cost_node.cost.min.clone(); + for child_cost_node in &cost_node.children { + let child_summing = calculate_total_cost_with_summing(child_cost_node); + let combined_cost = child_summing.add_all(); + let _ = total_cost.add(&combined_cost); + } + summing_cost.add_cost(total_cost); } } } else { @@ -680,7 +778,42 @@ mod tests { assert_eq!(cost.min.runtime, base_cost + 147); assert_eq!(cost.max.runtime, base_cost + 294); } + #[test] + fn test_function_arguments() { + let src = r#"(define-public (foo (a u64)) (ok a))"#; + let pre_expressions = parse(src).unwrap(); + let pre_expr = &pre_expressions[0]; + let expr_tree = build_expr_tree(pre_expr).unwrap(); + + // The root should be a Function node with "define-public" + assert!(matches!(expr_tree.expr, ExprNode::Function(_))); + if let ExprNode::Function(name) = &expr_tree.expr { + assert_eq!(name.as_str(), "define-public"); + } + + // Should have 2 children: UserArgument for (a u64) and the body (ok a) + assert_eq!(expr_tree.children.len(), 2); + // First child should be UserArgument for (a u64) + let user_arg = &expr_tree.children[0]; + assert!(matches!(user_arg.expr, ExprNode::UserArgument(_, _))); + if let ExprNode::UserArgument(arg_name, arg_type) = &user_arg.expr { + assert_eq!(arg_name.as_str(), "a"); + assert_eq!(arg_type.as_str(), "u64"); + } + + // Second child should be the function body (ok a) + let body = &expr_tree.children[1]; + assert!(matches!(body.expr, ExprNode::Ok)); + assert_eq!(body.children.len(), 1); + + // The body should reference the argument 'a' + let arg_ref = &body.children[0]; + assert!(matches!(arg_ref.expr, ExprNode::Atom(_))); + if let ExprNode::Atom(name) = &arg_ref.expr { + assert_eq!(name.as_str(), "a"); + } + } #[test] fn test_build_expr_tree_if_expression() { let source = "(if (> 3 0) (ok true) (ok false))"; From 6cbf44aa224bee13f1fa4f917660cf6abb6ea084 Mon Sep 17 00:00:00 2001 From: "brady.ouren" Date: Tue, 23 Sep 2025 00:06:15 -0700 Subject: [PATCH 6/6] simplify to build_cost_analysis_tree --- clarity/src/vm/costs/analysis.rs | 607 +++++++++++++++---------------- 1 file changed, 286 insertions(+), 321 deletions(-) diff --git a/clarity/src/vm/costs/analysis.rs b/clarity/src/vm/costs/analysis.rs index 580140d355..1deef22d49 100644 --- a/clarity/src/vm/costs/analysis.rs +++ b/clarity/src/vm/costs/analysis.rs @@ -1,5 +1,7 @@ // Static cost analysis for Clarity expressions +use std::collections::HashMap; + use clarity_serialization::representations::ContractName; use clarity_serialization::types::{CharType, SequenceData, TraitIdentifier}; use clarity_serialization::Value; @@ -12,48 +14,74 @@ use crate::vm::errors::InterpreterResult; use crate::vm::representations::{ClarityName, PreSymbolicExpression, PreSymbolicExpressionType}; // TODO: -// variable traverse for -// - if, unwrap-*, match, etc // contract-call? - get source from database // type-checking // lookups // unwrap evaluates both branches (https://github.com/clarity-lang/reference/issues/59) -/// Calculate the cost for a string based on its length -fn string_cost(length: usize) -> StaticCost { - let cost = linear(length as u64, 36, 3); - let execution_cost = ExecutionCost::runtime(cost); - StaticCost { - min: execution_cost.clone(), - max: execution_cost, - } +const STRING_COST_BASE: u64 = 36; +const STRING_COST_MULTIPLIER: u64 = 3; + +/// Functions where string arguments have zero cost because the function +/// cost includes their processing +const FUNCTIONS_WITH_ZERO_STRING_ARG_COST: &[&str] = &["concat", "len"]; + +#[derive(Debug, Clone)] +pub enum ExprNode { + If, + Match, + Unwrap, + Ok, + Err, + GT, + LT, + GE, + LE, + EQ, + Add, + Sub, + Mul, + Div, + Function(ClarityName), + AtomValue(Value), + Atom(ClarityName), + SugaredContractIdentifier(ContractName), + SugaredFieldIdentifier(ContractName, ClarityName), + FieldIdentifier(TraitIdentifier), + TraitReference(ClarityName), + // User function arguments + UserArgument(ClarityName, ClarityName), // (argument_name, argument_type) } #[derive(Debug, Clone)] -pub struct StaticCostNode { - pub function: Vec, +pub struct CostAnalysisNode { + pub expr: ExprNode, pub cost: StaticCost, - pub children: Vec, + pub children: Vec, + pub branching: bool, } -impl StaticCostNode { +impl CostAnalysisNode { pub fn new( - function: Vec, + expr: ExprNode, cost: StaticCost, - children: Vec, + children: Vec, + branching: bool, ) -> Self { Self { - function, + expr, cost, children, + branching, } } - pub fn leaf(function: Vec, cost: StaticCost) -> Self { + pub fn leaf(expr: ExprNode, cost: StaticCost) -> Self { Self { - function, + expr, cost, children: vec![], + branching: false, } } } @@ -71,6 +99,32 @@ impl StaticCost { }; } +#[derive(Debug, Clone)] +pub struct UserArgumentsContext { + /// Map from argument name to argument type + pub arguments: HashMap, +} + +impl UserArgumentsContext { + pub fn new() -> Self { + Self { + arguments: HashMap::new(), + } + } + + pub fn add_argument(&mut self, name: ClarityName, arg_type: ClarityName) { + self.arguments.insert(name, arg_type); + } + + pub fn is_user_argument(&self, name: &ClarityName) -> bool { + self.arguments.contains_key(name) + } + + pub fn get_argument_type(&self, name: &ClarityName) -> Option<&ClarityName> { + self.arguments.get(name) + } +} + /// A type to track summed execution costs for different paths /// This allows us to compute min and max costs across different execution paths #[derive(Debug, Clone)] @@ -95,7 +149,7 @@ impl SummingExecutionCost { self.costs.extend(other.costs.clone()); } - /// Get the minimum cost across all paths + /// minimum cost across all paths pub fn min(&self) -> ExecutionCost { if self.costs.is_empty() { ExecutionCost::ZERO @@ -112,7 +166,7 @@ impl SummingExecutionCost { } } - /// Get the maximum cost across all paths + /// maximum cost across all paths pub fn max(&self) -> ExecutionCost { if self.costs.is_empty() { ExecutionCost::ZERO @@ -129,7 +183,6 @@ impl SummingExecutionCost { } } - /// Combine costs by adding them (for non-branching operations) pub fn add_all(&self) -> ExecutionCost { self.costs .iter() @@ -142,8 +195,6 @@ impl SummingExecutionCost { /// Parse Clarity source code and calculate its static execution cost /// -/// This function takes a Clarity expression as a string, parses it into symbolic -/// expressions, builds a cost tree, and returns the min and max execution cost. /// theoretically you could inspect the tree at any node to get the spot cost pub fn static_cost(source: &str) -> Result { let pre_expressions = parse(source).map_err(|e| format!("Parse error: {:?}", e))?; @@ -154,106 +205,67 @@ pub fn static_cost(source: &str) -> Result { // TODO what happens if multiple expressions are selected? let pre_expr = &pre_expressions[0]; - let expr_tree = build_expr_tree(pre_expr)?; - let cost_tree = build_cost_tree(&expr_tree)?; + let user_args = UserArgumentsContext::new(); + let cost_analysis_tree = build_cost_analysis_tree(pre_expr, &user_args)?; - // Use branching-aware cost calculation - let summing_cost = calculate_total_cost_with_branching(&expr_tree, &cost_tree); + let summing_cost = calculate_total_cost_with_branching(&cost_analysis_tree); Ok(summing_cost.into()) } -#[derive(Debug, Clone)] -pub enum ExprNode { - If, - Match, - Unwrap, - Ok, - Err, - GT, - LT, - GE, - LE, - EQ, - Add, - Sub, - Mul, - Div, - // Other functions - Function(ClarityName), - // Values - AtomValue(Value), - Atom(ClarityName), - // Placeholder for sugared identifiers - SugaredContractIdentifier(ContractName), - SugaredFieldIdentifier(ContractName, ClarityName), - FieldIdentifier(TraitIdentifier), - TraitReference(ClarityName), - // User function arguments - UserArgument(ClarityName, ClarityName), // (argument_name, argument_type) -} - -#[derive(Debug, Clone)] -pub struct ExprTree { - pub expr: ExprNode, - pub children: Vec, - pub branching: bool, -} - -/// Build an expression tree, skipping comments and placeholders -fn build_expr_tree(expr: &PreSymbolicExpression) -> Result { +fn build_cost_analysis_tree( + expr: &PreSymbolicExpression, + user_args: &UserArgumentsContext, +) -> Result { match &expr.pre_expr { PreSymbolicExpressionType::List(list) => { - // Check if this is a function definition if let Some(function_name) = list.first().and_then(|first| first.match_atom()) { if function_name.as_str() == "define-public" || function_name.as_str() == "define-private" || function_name.as_str() == "define-read-only" { - return build_function_definition_expr_tree(list); + return build_function_definition_cost_analysis_tree(list, user_args); } } - build_listlike_expr_tree(list, "list") + build_listlike_cost_analysis_tree(list, "list", user_args) + } + PreSymbolicExpressionType::AtomValue(value) => { + let cost = calculate_value_cost(value)?; + Ok(CostAnalysisNode::leaf( + ExprNode::AtomValue(value.clone()), + cost, + )) + } + PreSymbolicExpressionType::Atom(name) => { + let expr_node = parse_atom_expression(name, user_args)?; + Ok(CostAnalysisNode::leaf(expr_node, StaticCost::ZERO)) + } + PreSymbolicExpressionType::Tuple(tuple) => { + build_listlike_cost_analysis_tree(tuple, "tuple", user_args) } - PreSymbolicExpressionType::AtomValue(value) => Ok(ExprTree { - expr: ExprNode::AtomValue(value.clone()), - children: vec![], - branching: false, - }), - PreSymbolicExpressionType::Atom(name) => Ok(ExprTree { - expr: ExprNode::Atom(name.clone()), - children: vec![], - branching: false, - }), - PreSymbolicExpressionType::Tuple(tuple) => build_listlike_expr_tree(tuple, "tuple"), PreSymbolicExpressionType::SugaredContractIdentifier(contract_name) => { - // TODO: Look up the source for this contract identifier - Ok(ExprTree { - expr: ExprNode::SugaredContractIdentifier(contract_name.clone()), - children: vec![], - branching: false, - }) + Ok(CostAnalysisNode::leaf( + ExprNode::SugaredContractIdentifier(contract_name.clone()), + // TODO: Look up source + StaticCost::ZERO, + )) } PreSymbolicExpressionType::SugaredFieldIdentifier(contract_name, field_name) => { - // TODO: Look up the source for this field identifier - Ok(ExprTree { - expr: ExprNode::SugaredFieldIdentifier(contract_name.clone(), field_name.clone()), - children: vec![], - branching: false, - }) - } - PreSymbolicExpressionType::FieldIdentifier(field_name) => Ok(ExprTree { - expr: ExprNode::FieldIdentifier(field_name.clone()), - children: vec![], - branching: false, - }), - PreSymbolicExpressionType::TraitReference(trait_name) => { - // TODO: Look up the source for this trait reference - Ok(ExprTree { - expr: ExprNode::TraitReference(trait_name.clone()), - children: vec![], - branching: false, - }) + Ok(CostAnalysisNode::leaf( + ExprNode::SugaredFieldIdentifier(contract_name.clone(), field_name.clone()), + // TODO: Look up source + StaticCost::ZERO, + )) } + PreSymbolicExpressionType::FieldIdentifier(field_name) => Ok(CostAnalysisNode::leaf( + ExprNode::FieldIdentifier(field_name.clone()), + // TODO: Look up source + StaticCost::ZERO, + )), + PreSymbolicExpressionType::TraitReference(trait_name) => Ok(CostAnalysisNode::leaf( + ExprNode::TraitReference(trait_name.clone()), + // TODO: Look up source + StaticCost::ZERO, + )), // Comments and placeholders should be filtered out during traversal PreSymbolicExpressionType::Comment(_comment) => { Err("hit an irrelevant comment expr type".to_string()) @@ -264,32 +276,38 @@ fn build_expr_tree(expr: &PreSymbolicExpression) -> Result { } } -/// Build an expression tree for function definitions like (define-public (foo (a u64)) (ok a)) -fn build_function_definition_expr_tree(list: &[PreSymbolicExpression]) -> Result { - if list.len() < 3 { - return Err( - "Function definition must have at least 3 elements: define type, signature, and body" - .to_string(), - ); +/// Parse an atom expression into an ExprNode +fn parse_atom_expression( + name: &ClarityName, + user_args: &UserArgumentsContext, +) -> Result { + // Check if this atom is a user-defined function argument + if user_args.is_user_argument(name) { + if let Some(arg_type) = user_args.get_argument_type(name) { + Ok(ExprNode::UserArgument(name.clone(), arg_type.clone())) + } else { + Ok(ExprNode::Atom(name.clone())) + } + } else { + Ok(ExprNode::Atom(name.clone())) } +} +/// Build an expression tree for function definitions like (define-public (foo (a u64)) (ok a)) +fn build_function_definition_cost_analysis_tree( + list: &[PreSymbolicExpression], + _user_args: &UserArgumentsContext, +) -> Result { let define_type = list[0] .match_atom() - .ok_or("First element must be define type")?; + .ok_or("Expected atom for define type")?; let signature = list[1] .match_list() - .ok_or("Second element must be function signature")?; + .ok_or("Expected list for function signature")?; let body = &list[2]; - // Parse the function signature: (foo (a u64)) - if signature.is_empty() { - return Err("Function signature cannot be empty".to_string()); - } - - let _function_name = signature[0] - .match_atom() - .ok_or("Function name must be an atom")?; let mut children = Vec::new(); + let mut function_user_args = UserArgumentsContext::new(); // Process function arguments: (a u64) for arg_expr in signature.iter().skip(1) { @@ -297,51 +315,47 @@ fn build_function_definition_expr_tree(list: &[PreSymbolicExpression]) -> Result if arg_list.len() == 2 { let arg_name = arg_list[0] .match_atom() - .ok_or("Argument name must be an atom")?; + .ok_or("Expected atom for argument name")?; - // Handle both atom types and atom values for the type let arg_type = match &arg_list[1].pre_expr { PreSymbolicExpressionType::Atom(type_name) => type_name.clone(), PreSymbolicExpressionType::AtomValue(value) => { - // Convert the value to a string representation ClarityName::from(value.to_string().as_str()) } _ => return Err("Argument type must be an atom or atom value".to_string()), }; + // Add to function's user arguments context + function_user_args.add_argument(arg_name.clone(), arg_type.clone()); + // Create UserArgument node - children.push(ExprTree { - expr: ExprNode::UserArgument(arg_name.clone(), arg_type), - children: vec![], - branching: false, - }); - } else { - return Err( - "Function argument must have exactly 2 elements: name and type".to_string(), - ); + children.push(CostAnalysisNode::leaf( + ExprNode::UserArgument(arg_name.clone(), arg_type), + StaticCost::ZERO, + )); } - } else { - return Err("Function argument must be a list".to_string()); } } - // Process the function body - let body_tree = build_expr_tree(body)?; + // Process the function body with the function's user arguments context + let body_tree = build_cost_analysis_tree(body, &function_user_args)?; children.push(body_tree); - // Create the function definition node - Ok(ExprTree { - expr: ExprNode::Function(define_type.clone()), + // Create the function definition node with zero cost (function definitions themselves don't have execution cost) + Ok(CostAnalysisNode::new( + ExprNode::Function(define_type.clone()), + StaticCost::ZERO, children, - branching: false, - }) + false, + )) } /// Helper function to build expression trees for both lists and tuples -fn build_listlike_expr_tree( +fn build_listlike_cost_analysis_tree( items: &[PreSymbolicExpression], container_type: &str, -) -> Result { + user_args: &UserArgumentsContext, +) -> Result { let function_name = match &items[0].pre_expr { PreSymbolicExpressionType::Atom(name) => name, _ => { @@ -363,16 +377,30 @@ fn build_listlike_expr_tree( continue; } _ => { - children.push(build_expr_tree(arg)?); + children.push(build_cost_analysis_tree(arg, user_args)?); } } } - // Determine if this is a branching function let branching = is_branching_function(function_name); + let expr_node = map_function_to_expr_node(function_name.as_str()); + let cost = calculate_function_cost_from_name(function_name.as_str(), children.len() as u64)?; + + // Handle special cases for string arguments to functions that include their processing cost + if FUNCTIONS_WITH_ZERO_STRING_ARG_COST.contains(&function_name.as_str()) { + for child in &mut children { + if let ExprNode::AtomValue(Value::Sequence(SequenceData::String(_))) = &child.expr { + child.cost = StaticCost::ZERO; + } + } + } - // Create the appropriate ExprNode - let expr_node = match function_name.as_str() { + Ok(CostAnalysisNode::new(expr_node, cost, children, branching)) +} + +/// Maps function names to their corresponding ExprNode variants +fn map_function_to_expr_node(function_name: &str) -> ExprNode { + match function_name { "if" => ExprNode::If, "match" => ExprNode::Match, "unwrap!" | "unwrap-err!" | "unwrap-panic" | "unwrap-err-panic" => ExprNode::Unwrap, @@ -387,14 +415,8 @@ fn build_listlike_expr_tree( "-" | "sub" => ExprNode::Sub, "*" | "mul" => ExprNode::Mul, "/" | "div" => ExprNode::Div, - _ => ExprNode::Function(function_name.clone()), - }; - - Ok(ExprTree { - expr: expr_node, - children, - branching, - }) + _ => ExprNode::Function(ClarityName::from(function_name)), + } } /// Determine if a function name represents a branching function @@ -408,79 +430,27 @@ fn is_branching_function(function_name: &ClarityName) -> bool { } } -/// Build a cost tree from an expression tree, using branching logic for min/max calculation -fn build_cost_tree(expr_tree: &ExprTree) -> Result { - let function_name = match &expr_tree.expr { - ExprNode::If => "if", - ExprNode::Match => "match", - ExprNode::Unwrap => "unwrap!", - ExprNode::Ok => "ok", - ExprNode::Err => "err", - ExprNode::GT => ">", - ExprNode::LT => "<", - ExprNode::GE => ">=", - ExprNode::LE => "<=", - ExprNode::EQ => "=", - ExprNode::Add => "+", - ExprNode::Sub => "-", - ExprNode::Mul => "*", - ExprNode::Div => "/", - ExprNode::Function(name) => name.as_str(), - ExprNode::AtomValue(value) => { - // String literals have cost based on length only when they're standalone (not function arguments) - // TODO: not sure if /utf8 and ascii are treated the same cost-wise.. - if let Value::Sequence(SequenceData::String(CharType::UTF8(data))) = value { - return Ok(StaticCostNode::leaf(vec![], string_cost(data.data.len()))); - } else if let Value::Sequence(SequenceData::String(CharType::ASCII(data))) = value { - return Ok(StaticCostNode::leaf(vec![], string_cost(data.data.len()))); - } - // Other atom values have zero cost - return Ok(StaticCostNode::leaf(vec![], StaticCost::ZERO)); - } - ExprNode::Atom(_) - | ExprNode::SugaredContractIdentifier(_) - | ExprNode::SugaredFieldIdentifier(_, _) - | ExprNode::FieldIdentifier(_) - | ExprNode::TraitReference(_) - | ExprNode::UserArgument(_, _) => { - // Leaf nodes have zero cost - return Ok(StaticCostNode::leaf(vec![], StaticCost::ZERO)); - } - }; - - let mut children = Vec::new(); - for child_expr in &expr_tree.children { - // For certain functions like concat and len, string arguments should have zero cost - // since the function cost includes their processing - if function_name == "concat" || function_name == "len" { - if let ExprNode::AtomValue(Value::Sequence(SequenceData::String(_))) = &child_expr.expr - { - // String arguments to concat and len have zero cost - children.push(StaticCostNode::leaf(vec![], StaticCost::ZERO)); - continue; - } - } - children.push(build_cost_tree(child_expr)?); +/// Calculate the cost for a string based on its length +fn string_cost(length: usize) -> StaticCost { + let cost = linear(length as u64, STRING_COST_BASE, STRING_COST_MULTIPLIER); + let execution_cost = ExecutionCost::runtime(cost); + StaticCost { + min: execution_cost.clone(), + max: execution_cost, } +} - let cost = calculate_function_cost_from_name(function_name, expr_tree.children.len() as u64)?; - - // Create a representative PreSymbolicExpression for the node - let function_expr = PreSymbolicExpression { - pre_expr: PreSymbolicExpressionType::Atom(ClarityName::from(function_name)), - id: 0, // We don't need accurate IDs for cost analysis - }; - let mut expr_list = vec![function_expr]; - - // Add placeholder expressions for children (we don't need the actual child expressions) - for _ in &expr_tree.children { - expr_list.push(PreSymbolicExpression { - pre_expr: PreSymbolicExpressionType::Atom(ClarityName::from("placeholder")), - id: 0, - }); +/// Calculate cost for a value (used for literal values) +fn calculate_value_cost(value: &Value) -> Result { + match value { + Value::Sequence(SequenceData::String(CharType::UTF8(data))) => { + Ok(string_cost(data.data.len())) + } + Value::Sequence(SequenceData::String(CharType::ASCII(data))) => { + Ok(string_cost(data.data.len())) + } + _ => Ok(StaticCost::ZERO), } - - Ok(StaticCostNode::new(expr_list, cost, children)) } fn calculate_function_cost_from_name( @@ -606,19 +576,14 @@ fn get_cost_function_for_name(name: &str) -> Option InterpreterResult "string-to-uint?" => Some(Costs3::cost_string_to_uint), "int-to-ascii" => Some(Costs3::cost_int_to_ascii), "int-to-utf8?" => Some(Costs3::cost_int_to_utf8), - _ => None, // Unknown function name + _ => None, // TODO } } -fn calculate_total_cost(node: &StaticCostNode) -> StaticCost { - calculate_total_cost_with_summing(node).into() -} - /// Calculate total cost using SummingExecutionCost to handle branching properly -fn calculate_total_cost_with_summing(node: &StaticCostNode) -> SummingExecutionCost { +fn calculate_total_cost_with_summing(node: &CostAnalysisNode) -> SummingExecutionCost { let mut summing_cost = SummingExecutionCost::from_single(node.cost.min.clone()); - // For each child, calculate its cost and combine appropriately for child in &node.children { let child_summing = calculate_total_cost_with_summing(child); summing_cost.add_summing(&child_summing); @@ -627,29 +592,22 @@ fn calculate_total_cost_with_summing(node: &StaticCostNode) -> SummingExecutionC summing_cost } -/// Calculate total cost using branching logic from ExprTree -fn calculate_total_cost_with_branching( - expr_tree: &ExprTree, - cost_node: &StaticCostNode, -) -> SummingExecutionCost { +fn calculate_total_cost_with_branching(node: &CostAnalysisNode) -> SummingExecutionCost { let mut summing_cost = SummingExecutionCost::new(); - if expr_tree.branching { - // Handle different types of branching functions - match &expr_tree.expr { + if node.branching { + match &node.expr { ExprNode::If | ExprNode::Match => { - // For if and match, we need to create separate execution paths - // The first child is the condition, the rest are the branches - if cost_node.children.len() >= 2 { - let condition_cost = calculate_total_cost_with_summing(&cost_node.children[0]); + // TODO match? + if node.children.len() >= 2 { + let condition_cost = calculate_total_cost_with_summing(&node.children[0]); let condition_total = condition_cost.add_all(); // Add the root cost + condition cost to each branch - let mut root_and_condition = cost_node.cost.min.clone(); + let mut root_and_condition = node.cost.min.clone(); let _ = root_and_condition.add(&condition_total); - // For each branch (children 1+), create a complete path - for child_cost_node in cost_node.children.iter().skip(1) { + for child_cost_node in node.children.iter().skip(1) { let branch_cost = calculate_total_cost_with_summing(child_cost_node); let branch_total = branch_cost.add_all(); @@ -662,8 +620,8 @@ fn calculate_total_cost_with_branching( } _ => { // For other branching functions, fall back to sequential processing - let mut total_cost = cost_node.cost.min.clone(); - for child_cost_node in &cost_node.children { + let mut total_cost = node.cost.min.clone(); + for child_cost_node in &node.children { let child_summing = calculate_total_cost_with_summing(child_cost_node); let combined_cost = child_summing.add_all(); let _ = total_cost.add(&combined_cost); @@ -673,8 +631,8 @@ fn calculate_total_cost_with_branching( } } else { // For non-branching, add all costs sequentially - let mut total_cost = cost_node.cost.min.clone(); - for child_cost_node in &cost_node.children { + let mut total_cost = node.cost.min.clone(); + for child_cost_node in &node.children { let child_summing = calculate_total_cost_with_summing(child_cost_node); let combined_cost = child_summing.add_all(); let _ = total_cost.add(&combined_cost); @@ -778,115 +736,122 @@ mod tests { assert_eq!(cost.min.runtime, base_cost + 147); assert_eq!(cost.max.runtime, base_cost + 294); } - #[test] - fn test_function_arguments() { - let src = r#"(define-public (foo (a u64)) (ok a))"#; - let pre_expressions = parse(src).unwrap(); - let pre_expr = &pre_expressions[0]; - let expr_tree = build_expr_tree(pre_expr).unwrap(); - - // The root should be a Function node with "define-public" - assert!(matches!(expr_tree.expr, ExprNode::Function(_))); - if let ExprNode::Function(name) = &expr_tree.expr { - assert_eq!(name.as_str(), "define-public"); - } - - // Should have 2 children: UserArgument for (a u64) and the body (ok a) - assert_eq!(expr_tree.children.len(), 2); - - // First child should be UserArgument for (a u64) - let user_arg = &expr_tree.children[0]; - assert!(matches!(user_arg.expr, ExprNode::UserArgument(_, _))); - if let ExprNode::UserArgument(arg_name, arg_type) = &user_arg.expr { - assert_eq!(arg_name.as_str(), "a"); - assert_eq!(arg_type.as_str(), "u64"); - } - - // Second child should be the function body (ok a) - let body = &expr_tree.children[1]; - assert!(matches!(body.expr, ExprNode::Ok)); - assert_eq!(body.children.len(), 1); - // The body should reference the argument 'a' - let arg_ref = &body.children[0]; - assert!(matches!(arg_ref.expr, ExprNode::Atom(_))); - if let ExprNode::Atom(name) = &arg_ref.expr { - assert_eq!(name.as_str(), "a"); - } - } + // ---- ExprTreee building specific tests #[test] - fn test_build_expr_tree_if_expression() { + fn test_build_cost_analysis_tree_if_expression() { let source = "(if (> 3 0) (ok true) (ok false))"; let pre_expressions = parse(source).unwrap(); let pre_expr = &pre_expressions[0]; - let expr_tree = build_expr_tree(pre_expr).unwrap(); + let user_args = UserArgumentsContext::new(); + let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); // Root should be an If node with branching=true - assert!(matches!(expr_tree.expr, ExprNode::If)); - assert!(expr_tree.branching); - assert_eq!(expr_tree.children.len(), 3); // condition, then, else + assert!(matches!(cost_tree.expr, ExprNode::If)); + assert!(cost_tree.branching); + assert_eq!(cost_tree.children.len(), 3); - // First child should be GT comparison - let gt_node = &expr_tree.children[0]; + let gt_node = &cost_tree.children[0]; assert!(matches!(gt_node.expr, ExprNode::GT)); - assert!(!gt_node.branching); - assert_eq!(gt_node.children.len(), 2); // 3 and 0 + assert_eq!(gt_node.children.len(), 2); - // GT children should be AtomValue(3) and AtomValue(0) let left_val = >_node.children[0]; let right_val = >_node.children[1]; assert!(matches!(left_val.expr, ExprNode::AtomValue(_))); assert!(matches!(right_val.expr, ExprNode::AtomValue(_))); - // Second child should be Ok(true) - let ok_true_node = &expr_tree.children[1]; + let ok_true_node = &cost_tree.children[1]; assert!(matches!(ok_true_node.expr, ExprNode::Ok)); - assert!(!ok_true_node.branching); assert_eq!(ok_true_node.children.len(), 1); - // Third child should be Ok(false) - let ok_false_node = &expr_tree.children[2]; + let ok_false_node = &cost_tree.children[2]; assert!(matches!(ok_false_node.expr, ExprNode::Ok)); - assert!(!ok_false_node.branching); assert_eq!(ok_false_node.children.len(), 1); } #[test] - fn test_build_expr_tree_arithmetic() { + fn test_build_cost_analysis_tree_arithmetic() { let source = "(+ (* 2 3) (- 5 1))"; let pre_expressions = parse(source).unwrap(); let pre_expr = &pre_expressions[0]; - let expr_tree = build_expr_tree(pre_expr).unwrap(); + let user_args = UserArgumentsContext::new(); + let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); - // Root should be Add node - assert!(matches!(expr_tree.expr, ExprNode::Add)); - assert!(!expr_tree.branching); - assert_eq!(expr_tree.children.len(), 2); + assert!(matches!(cost_tree.expr, ExprNode::Add)); + assert!(!cost_tree.branching); + assert_eq!(cost_tree.children.len(), 2); - // First child should be Mul - let mul_node = &expr_tree.children[0]; + let mul_node = &cost_tree.children[0]; assert!(matches!(mul_node.expr, ExprNode::Mul)); assert_eq!(mul_node.children.len(), 2); - // Second child should be Sub - let sub_node = &expr_tree.children[1]; + let sub_node = &cost_tree.children[1]; assert!(matches!(sub_node.expr, ExprNode::Sub)); assert_eq!(sub_node.children.len(), 2); } #[test] - fn test_build_expr_tree_with_comments() { + fn test_build_cost_analysis_tree_with_comments() { let source = "(+ 1 ;; this is a comment\n 2)"; let pre_expressions = parse(source).unwrap(); let pre_expr = &pre_expressions[0]; - let expr_tree = build_expr_tree(pre_expr).unwrap(); + let user_args = UserArgumentsContext::new(); + let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); - assert!(matches!(expr_tree.expr, ExprNode::Add)); - assert!(!expr_tree.branching); - assert_eq!(expr_tree.children.len(), 2); + assert!(matches!(cost_tree.expr, ExprNode::Add)); + assert!(!cost_tree.branching); + assert_eq!(cost_tree.children.len(), 2); - for child in &expr_tree.children { + for child in &cost_tree.children { assert!(matches!(child.expr, ExprNode::AtomValue(_))); } } + + #[test] + fn test_function_with_multiple_arguments() { + let src = r#"(define-public (add-two (x u64) (y u64)) (+ x y))"#; + let pre_expressions = parse(src).unwrap(); + let pre_expr = &pre_expressions[0]; + let user_args = UserArgumentsContext::new(); + let cost_tree = build_cost_analysis_tree(pre_expr, &user_args).unwrap(); + + // Should have 3 children: UserArgument for (x u64), UserArgument for (y u64), and the body (+ x y) + assert_eq!(cost_tree.children.len(), 3); + + // First child should be UserArgument for (x u64) + let user_arg_x = &cost_tree.children[0]; + assert!(matches!(user_arg_x.expr, ExprNode::UserArgument(_, _))); + if let ExprNode::UserArgument(arg_name, arg_type) = &user_arg_x.expr { + assert_eq!(arg_name.as_str(), "x"); + assert_eq!(arg_type.as_str(), "u64"); + } + + // Second child should be UserArgument for (y u64) + let user_arg_y = &cost_tree.children[1]; + assert!(matches!(user_arg_y.expr, ExprNode::UserArgument(_, _))); + if let ExprNode::UserArgument(arg_name, arg_type) = &user_arg_y.expr { + assert_eq!(arg_name.as_str(), "y"); + assert_eq!(arg_type.as_str(), "u64"); + } + + // Third child should be the function body (+ x y) + let body = &cost_tree.children[2]; + assert!(matches!(body.expr, ExprNode::Add)); + assert_eq!(body.children.len(), 2); + + // Both arguments in the body should be UserArguments + let arg_x_ref = &body.children[0]; + let arg_y_ref = &body.children[1]; + assert!(matches!(arg_x_ref.expr, ExprNode::UserArgument(_, _))); + assert!(matches!(arg_y_ref.expr, ExprNode::UserArgument(_, _))); + + if let ExprNode::UserArgument(name, arg_type) = &arg_x_ref.expr { + assert_eq!(name.as_str(), "x"); + assert_eq!(arg_type.as_str(), "u64"); + } + if let ExprNode::UserArgument(name, arg_type) = &arg_y_ref.expr { + assert_eq!(name.as_str(), "y"); + assert_eq!(arg_type.as_str(), "u64"); + } + } }