diff --git a/src/env.rs b/src/env.rs index 704de73..649d75c 100644 --- a/src/env.rs +++ b/src/env.rs @@ -1,20 +1,22 @@ use crate::object::Object; -use std::cell::RefCell; use std::collections::HashMap; -use std::rc::Rc; #[derive(Debug, PartialEq, Default)] -pub struct Env { - parent: Option>>, +pub struct Env<'a> { + parent: Option<&'a Env<'a>>, vars: HashMap, } -impl Env { - pub fn new() -> Self { - Default::default() +impl<'a> Env<'a> { + + pub fn new() -> Env<'a> { + Env { + vars: HashMap::new(), + parent: None, + } } - pub fn extend(parent: Rc>) -> Env { + pub fn extend(parent: &'a Self) -> Env<'a> { Env { vars: HashMap::new(), parent: Some(parent), @@ -24,10 +26,7 @@ impl Env { pub fn get(&self, name: &str) -> Option { match self.vars.get(name) { Some(value) => Some(value.clone()), - None => self - .parent - .as_ref() - .and_then(|o| o.borrow().get(name).clone()), + None => self.parent.and_then(|parent| parent.get(name)), } } diff --git a/src/eval.rs b/src/eval.rs index e865f64..81a8637 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,10 +1,8 @@ use crate::env::*; use crate::object::*; use crate::parser::*; -use std::cell::RefCell; -use std::rc::Rc; -fn eval_binary_op(list: &Vec, env: &mut Rc>) -> Result { +fn eval_binary_op(list: &Vec, env: &mut Env) -> Result { if list.len() != 3 { return Err(format!("Invalid number of arguments for infix operator")); } @@ -35,7 +33,7 @@ fn eval_binary_op(list: &Vec, env: &mut Rc>) -> Result, env: &mut Rc>) -> Result { +fn eval_define(list: &Vec, env: &mut Env) -> Result { if list.len() != 3 { return Err(format!("Invalid number of arguments for define")); } @@ -45,11 +43,11 @@ fn eval_define(list: &Vec, env: &mut Rc>) -> Result return Err(format!("Invalid define")), }; let val = eval_obj(&list[2], env)?; - env.borrow_mut().set(&sym, val); + env.set(&sym, val); Ok(Object::Void) } -fn eval_if(list: &Vec, env: &mut Rc>) -> Result { +fn eval_if(list: &Vec, env: &mut Env) -> Result { if list.len() != 4 { return Err(format!("Invalid number of arguments for if statement")); } @@ -92,9 +90,9 @@ fn eval_function_definition(list: &Vec) -> Result { fn eval_function_call( s: &str, list: &Vec, - env: &mut Rc>, + env: &mut Env, ) -> Result { - let lamdba = env.borrow_mut().get(s); + let lamdba = env.get(s); if lamdba.is_none() { return Err(format!("Unbound symbol: {}", s)); } @@ -102,10 +100,10 @@ fn eval_function_call( let func = lamdba.unwrap(); match func { Object::Lambda(params, body) => { - let mut new_env = Rc::new(RefCell::new(Env::extend(env.clone()))); + let mut new_env = Env::extend(env); for (i, param) in params.iter().enumerate() { - let val = eval_obj(&list[i + 1], env)?; - new_env.borrow_mut().set(param, val); + let val = eval_obj(&list[i + 1], &mut new_env)?; + new_env.set(param, val); } return eval_obj(&Object::List(body), &mut new_env); } @@ -113,15 +111,15 @@ fn eval_function_call( } } -fn eval_symbol(s: &str, env: &mut Rc>) -> Result { - let val = env.borrow_mut().get(s); +fn eval_symbol(s: &str, env: &mut Env) -> Result { + let val = env.get(s); if val.is_none() { return Err(format!("Unbound symbol: {}", s)); } Ok(val.unwrap().clone()) } -fn eval_list(list: &Vec, env: &mut Rc>) -> Result { +fn eval_list(list: &Vec, env: &mut Env) -> Result { let head = &list[0]; match head { Object::Symbol(s) => match s.as_str() { @@ -147,7 +145,7 @@ fn eval_list(list: &Vec, env: &mut Rc>) -> Result>) -> Result { +fn eval_obj(obj: &Object, env: &mut Env) -> Result { match obj { Object::List(list) => eval_list(list, env), Object::Void => Ok(Object::Void), @@ -158,7 +156,7 @@ fn eval_obj(obj: &Object, env: &mut Rc>) -> Result } } -pub fn eval(program: &str, env: &mut Rc>) -> Result { +pub fn eval(program: &str, env: &mut Env) -> Result { let parsed_list = parse(program); if parsed_list.is_err() { return Err(format!("{}", parsed_list.err().unwrap())); @@ -172,14 +170,14 @@ mod tests { #[test] fn test_simple_add() { - let mut env = Rc::new(RefCell::new(Env::new())); + let mut env = Box::new(Env::new()); let result = eval("(+ 1 2)", &mut env).unwrap(); assert_eq!(result, Object::Integer(3)); } #[test] fn test_area_of_a_circle() { - let mut env = Rc::new(RefCell::new(Env::new())); + let mut env = Box::new(Env::new()); let program = "( (define r 10) (define pi 314) @@ -194,7 +192,7 @@ mod tests { #[test] fn test_sqr_function() { - let mut env = Rc::new(RefCell::new(Env::new())); + let mut env = Box::new(Env::new()); let program = "( (define sqr (lambda (r) (* r r))) (sqr 10) @@ -208,7 +206,7 @@ mod tests { #[test] fn test_fibonaci() { - let mut env = Rc::new(RefCell::new(Env::new())); + let mut env = Box::new(Env::new()); let program = " ( (define fib (lambda (n) (if (< n 2) 1 (+ (fib (- n 1)) (fib (- n 2)))))) @@ -222,7 +220,7 @@ mod tests { #[test] fn test_factorial() { - let mut env = Rc::new(RefCell::new(Env::new())); + let mut env = Box::new(Env::new()); let program = " ( (define fact (lambda (n) (if (< n 1) 1 (* n (fact (- n 1)))))) @@ -236,7 +234,7 @@ mod tests { #[test] fn test_circle_area_function() { - let mut env = Rc::new(RefCell::new(Env::new())); + let mut env = Box::new(Env::new()); let program = " ( (define pi 314) diff --git a/src/main.rs b/src/main.rs index adf7b7b..fb697c0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,14 +6,12 @@ mod parser; use linefeed::{Interface, ReadResult}; use object::Object; -use std::cell::RefCell; -use std::rc::Rc; const PROMPT: &str = "lisp-rs> "; fn main() -> Result<(), Box> { let reader = Interface::new(PROMPT).unwrap(); - let mut env = Rc::new(RefCell::new(env::Env::new())); + let mut env = Box::new(env::Env::new()); reader.set_prompt(format!("{}", PROMPT).as_ref()).unwrap();