From 92aa59d2c7f150a3f755d6e68b2f00c707515645 Mon Sep 17 00:00:00 2001 From: Mukunda Rao Katta Date: Sat, 25 Apr 2026 23:30:00 +0000 Subject: [PATCH] fix: accept start as keyword arg in enumerate builtin Matches CPython's `enumerate(iterable, start=0)` signature so that `start` works as both a positional and a keyword argument. Closes #394. --- crates/monty/src/builtins/enumerate.rs | 103 +++++++++++++++++- .../test_cases/builtin__more_iter_funcs.py | 14 +++ 2 files changed, 113 insertions(+), 4 deletions(-) diff --git a/crates/monty/src/builtins/enumerate.rs b/crates/monty/src/builtins/enumerate.rs index 96fade42b..06b4df7e1 100644 --- a/crates/monty/src/builtins/enumerate.rs +++ b/crates/monty/src/builtins/enumerate.rs @@ -3,11 +3,11 @@ use smallvec::smallvec; use crate::{ - args::ArgValues, + args::{ArgValues, KwargsValues}, bytecode::VM, defer_drop, defer_drop_mut, - exception_private::{ExcType, RunResult, SimpleException}, - heap::HeapData, + exception_private::{ExcType, RunError, RunResult, SimpleException}, + heap::{DropWithHeap, HeapData}, resource::ResourceTracker, types::{List, MontyIter, PyTrait, allocate_tuple}, value::Value, @@ -17,8 +17,11 @@ use crate::{ /// /// Returns a list of (index, value) tuples. /// Note: In Python this returns an iterator, but we return a list for simplicity. +/// +/// Matches CPython's `enumerate(iterable, start=0)` signature: `start` may be +/// passed either positionally or as the `start=` keyword argument. pub fn builtin_enumerate(vm: &mut VM<'_, '_, impl ResourceTracker>, args: ArgValues) -> RunResult { - let (iterable, start) = args.get_one_two_args("enumerate", vm.heap)?; + let (iterable, start) = parse_enumerate_args(args, vm)?; let iter = MontyIter::new(iterable, vm)?; defer_drop_mut!(iter, vm); defer_drop!(start, vm); @@ -50,3 +53,95 @@ pub fn builtin_enumerate(vm: &mut VM<'_, '_, impl ResourceTracker>, args: ArgVal let heap_id = vm.heap.allocate(HeapData::List(List::new(result)))?; Ok(Value::Ref(heap_id)) } + +/// Parses arguments for `enumerate(iterable, start=0)`. +/// +/// `iterable` is required and positional. `start` is optional and may be passed +/// either positionally or as the `start=` keyword argument, but not both. +fn parse_enumerate_args( + args: ArgValues, + vm: &mut VM<'_, '_, impl ResourceTracker>, +) -> RunResult<(Value, Option)> { + let (mut positional, kwargs) = args.into_parts(); + let positional_count = positional.len(); + + // Need at least one positional argument (the iterable). + let Some(iterable) = positional.next() else { + kwargs.drop_with_heap(vm.heap); + return Err(ExcType::type_error_at_least("enumerate", 1, 0)); + }; + + // Optional positional `start`; reject any extra positional args. + let positional_start: Option = positional.next(); + if positional.len() > 0 { + kwargs.drop_with_heap(vm.heap); + iterable.drop_with_heap(vm.heap); + positional_start.drop_with_heap(vm.heap); + positional.drop_with_heap(vm.heap); + return Err(ExcType::type_error_at_most("enumerate", 2, positional_count)); + } + + // Pull out the `start=` keyword argument if any. + let saw_positional_start = positional_start.is_some(); + match extract_start_kwarg(kwargs, saw_positional_start, vm) { + Ok(kw_start) => Ok((iterable, kw_start.or(positional_start))), + Err(err) => { + iterable.drop_with_heap(vm.heap); + positional_start.drop_with_heap(vm.heap); + Err(err) + } + } +} + +/// Walks `kwargs` and returns the value passed for `start=`, if any. +/// +/// All kwargs are fully consumed before returning, so reference counts stay correct +/// even when some kwargs come after a bad one. `saw_positional_start` indicates that +/// `start` was already supplied positionally, in which case any `start=` kwarg is +/// rejected as a duplicate. +fn extract_start_kwarg( + kwargs: KwargsValues, + saw_positional_start: bool, + vm: &mut VM<'_, '_, impl ResourceTracker>, +) -> Result, RunError> { + let mut start: Option = None; + let mut error: Option = None; + + for (key, value) in kwargs { + defer_drop!(key, vm); + + // If we already hit an error, drop remaining values and continue. + if error.is_some() { + value.drop_with_heap(vm.heap); + continue; + } + + let Some(keyword_name) = key.as_either_str(vm.heap) else { + value.drop_with_heap(vm.heap); + error = Some(ExcType::type_error_kwargs_nonstring_key()); + continue; + }; + + let key_str = keyword_name.as_str(vm.interns); + if key_str == "start" { + if saw_positional_start || start.is_some() { + value.drop_with_heap(vm.heap); + error = Some(ExcType::type_error_multiple_values("enumerate", "start")); + } else { + start = Some(value); + } + } else { + // Build the error before touching `vm` mutably (the format borrows `key_str`). + let err = ExcType::type_error_unexpected_keyword("enumerate", key_str); + value.drop_with_heap(vm.heap); + error = Some(err); + } + } + + if let Some(err) = error { + start.drop_with_heap(vm.heap); + Err(err) + } else { + Ok(start) + } +} diff --git a/crates/monty/test_cases/builtin__more_iter_funcs.py b/crates/monty/test_cases/builtin__more_iter_funcs.py index e027ccbb9..3a7ef34c3 100644 --- a/crates/monty/test_cases/builtin__more_iter_funcs.py +++ b/crates/monty/test_cases/builtin__more_iter_funcs.py @@ -331,6 +331,20 @@ def negate(x): assert list(enumerate(['a', 'b'], 1)) == [(1, 'a'), (2, 'b')], 'enumerate with start' assert list(enumerate(['a', 'b'], 10)) == [(10, 'a'), (11, 'b')], 'enumerate with start 10' +# enumerate with start as keyword argument +assert list(enumerate(['a', 'b', 'c'], start=1)) == [(1, 'a'), (2, 'b'), (3, 'c')], 'enumerate start=1 kwarg' +assert list(enumerate(['a', 'b'], start=10)) == [(10, 'a'), (11, 'b')], 'enumerate start=10 kwarg' +assert list(enumerate([], start=5)) == [], 'enumerate empty with start kwarg' + +# enumerate rejects unknown keyword arguments +try: + list(enumerate(['a'], nope=1)) + assert False, 'enumerate with invalid keyword should raise TypeError' +except TypeError as e: + assert e.args == ("enumerate() got an unexpected keyword argument 'nope'",), ( + 'enumerate invalid keyword error matches CPython' + ) + # enumerate string assert list(enumerate('ab')) == [(0, 'a'), (1, 'b')], 'enumerate string'