diff --git a/examples/codegen/mutability.an b/examples/codegen/mutability.an index d6b6e073..a2119fd8 100644 --- a/examples/codegen/mutability.an +++ b/examples/codegen/mutability.an @@ -6,7 +6,7 @@ print num mutate num print num -mutate (n: Ref I32) = +mutate (n: &mut I32) = x = double @n n := x diff --git a/examples/codegen/string_builder.an b/examples/codegen/string_builder.an index df9ce468..28e56781 100644 --- a/examples/codegen/string_builder.an +++ b/examples/codegen/string_builder.an @@ -1,11 +1,11 @@ import StringBuilder -sb: Ref StringBuilder = mut empty () +sb: &mut StringBuilder = mut empty () reserve sb 10 append sb "你好," append sb " World!" -print (to_string sb) +print (to_string @sb) // args: --delete-binary // expected stdout: 你好, World! diff --git a/examples/nameresolution/type_decl.an b/examples/nameresolution/type_decl.an index 5abe8e1f..2c4db879 100644 --- a/examples/nameresolution/type_decl.an +++ b/examples/nameresolution/type_decl.an @@ -1,7 +1,7 @@ type Struct1 = a:I32, b:F64, c:String -type Thingy is Struct1 +type Thingy = Struct1 type Generic a b = first: a, second: b @@ -17,7 +17,7 @@ type Option a = t = Just 1 -type UniquePtr a is Ref a +type MyRef a = &a // args: --check // expected stdout: diff --git a/examples/parsing/type_decl.an b/examples/parsing/type_decl.an index 46c79007..a24a9a67 100644 --- a/examples/parsing/type_decl.an +++ b/examples/parsing/type_decl.an @@ -11,10 +11,10 @@ type Maybe a = | Some a | None -type List a = | Nil | Cons a (ref (List a)) +type List a = | Nil | Cons a (&List a) -type UniquePtr a is ref a +type UniquePtr a = &a t = 3 : I32 @@ -24,6 +24,6 @@ t = 3 : I32 // (type Struct2 t = a: Thingy, b: (Generic t Thingy)); // (type Union1 ab = | Variant1 | Variant2 ); // (type Maybe a = | Some a| None ); -// (type List a = | Nil | Cons a (ref (List a))); -// (type UniquePtr a = (ref a)); +// (type List a = | Nil | Cons a (& (List a))); +// (type UniquePtr a = (& a)); // (t = (: 3 I32)) diff --git a/src/hir/definitions.rs b/src/hir/definitions.rs index 49293745..199ea6bf 100644 --- a/src/hir/definitions.rs +++ b/src/hir/definitions.rs @@ -63,7 +63,10 @@ impl std::hash::Hash for DefinitionType { types::Type::TypeVariable(_) => (), // Do nothing types::Type::Function(_) => (), types::Type::TypeApplication(_, _) => (), - types::Type::Ref(_) => (), + types::Type::Ref(shared, mutable, _) => { + shared.hash(state); + mutable.hash(state); + }, types::Type::Struct(field_names, _) => { for name in field_names { name.hash(state); @@ -91,7 +94,13 @@ fn definition_type_eq(a: &types::Type, b: &types::Type) -> bool { match (a, b) { (Type::Primitive(primitive1), Type::Primitive(primitive2)) => primitive1 == primitive2, (Type::UserDefined(id1), Type::UserDefined(id2)) => id1 == id2, - (Type::TypeVariable(_), Type::TypeVariable(_)) | (Type::Ref(_), Type::Ref(_)) => true, // Do nothing + (Type::TypeVariable(_), Type::TypeVariable(_)) => true, // Do nothing + // This will monomorphize separate definitions for polymorphically-owned references + // which is undesired. Defaulting them to shared/owned though can change behavior + // if traits are involved. + (Type::Ref(shared1, mutable1, _), Type::Ref(shared2, mutable2, _)) => { + shared1 == shared2 && mutable1 == mutable2 + }, (Type::Function(f1), Type::Function(f2)) => { if f1.parameters.len() != f2.parameters.len() { return false; diff --git a/src/hir/monomorphisation.rs b/src/hir/monomorphisation.rs index 0cebbd4a..fd2f9874 100644 --- a/src/hir/monomorphisation.rs +++ b/src/hir/monomorphisation.rs @@ -142,12 +142,12 @@ impl<'c> Context<'c> { let fuel = fuel - 1; match &self.cache.type_bindings[id.0] { - Bound(TypeVariable(id2) | Ref(id2)) => self.find_binding(*id2, fuel), + Bound(TypeVariable(id2) | Ref(_, _, id2)) => self.find_binding(*id2, fuel), Bound(binding) => Ok(binding), Unbound(..) => { for bindings in self.monomorphisation_bindings.iter().rev() { match bindings.get(&id) { - Some(TypeVariable(id2) | Ref(id2)) => return self.find_binding(*id2, fuel), + Some(TypeVariable(id2) | Ref(_, _, id2)) => return self.find_binding(*id2, fuel), Some(binding) => return Ok(binding), None => (), } @@ -204,7 +204,7 @@ impl<'c> Context<'c> { let args = fmap(args, |arg| self.follow_all_bindings_inner(arg, fuel)); TypeApplication(Box::new(con), args) }, - Ref(_) => typ.clone(), + Ref(..) => typ.clone(), Struct(fields, id) => match self.find_binding(*id, fuel) { Ok(binding) => self.follow_all_bindings_inner(binding, fuel), Err(_) => { @@ -346,7 +346,7 @@ impl<'c> Context<'c> { _ => unreachable!("Kind error inside size_of_type"), }, - Ref(_) => Self::ptr_size(), + Ref(..) => Self::ptr_size(), Struct(fields, rest) => { if let Ok(binding) = self.find_binding(*rest, RECURSION_LIMIT) { let binding = binding.clone(); @@ -519,7 +519,7 @@ impl<'c> Context<'c> { let typ = self.follow_bindings_shallow(typ); match typ { - Ok(Primitive(PrimitiveType::Ptr) | Ref(_)) => Type::Primitive(hir::PrimitiveType::Pointer), + Ok(Primitive(PrimitiveType::Ptr) | Ref(..)) => Type::Primitive(hir::PrimitiveType::Pointer), Ok(Primitive(PrimitiveType::IntegerType)) => { if self.is_type_variable(&args[0]) { // Default to i32 @@ -553,7 +553,7 @@ impl<'c> Context<'c> { } }, - Ref(_) => { + Ref(..) => { unreachable!( "Kind error during monomorphisation. Attempted to translate a `ref` without a type argument" ) @@ -1517,7 +1517,7 @@ impl<'c> Context<'c> { TypeApplication(typ, args) => { match typ.as_ref() { // Pass through ref types transparently - types::Type::Ref(_) => self.get_field_index(field_name, &args[0]), + types::Type::Ref(..) => self.get_field_index(field_name, &args[0]), // These last 2 cases are the same. They're duplicated to avoid another follow_bindings_shallow call. typ => self.get_field_index(field_name, typ), } @@ -1551,7 +1551,7 @@ impl<'c> Context<'c> { let ref_type = match lhs_type { types::Type::TypeApplication(constructor, args) => match self.follow_bindings_shallow(constructor.as_ref()) { - Ok(types::Type::Ref(_)) => Some(self.convert_type(&args[0])), + Ok(types::Type::Ref(..)) => Some(self.convert_type(&args[0])), _ => None, }, _ => None, diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index ab042eb2..0ecbb8ce 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -121,20 +121,16 @@ impl<'cache, 'contents> Lexer<'cache, 'contents> { ("Ptr", Token::PointerType), ("Bool", Token::BooleanType), ("Unit", Token::UnitType), - ("Ref", Token::Ref), ("mut", Token::Mut), ("true", Token::BooleanLiteral(true)), ("false", Token::BooleanLiteral(false)), ("and", Token::And), ("as", Token::As), ("block", Token::Block), - ("break", Token::Break), - ("continue", Token::Continue), ("do", Token::Do), ("effect", Token::Effect), ("else", Token::Else), ("extern", Token::Extern), - ("for", Token::For), ("fn", Token::Fn), ("given", Token::Given), ("handle", Token::Handle), @@ -142,14 +138,16 @@ impl<'cache, 'contents> Lexer<'cache, 'contents> { ("impl", Token::Impl), ("import", Token::Import), ("in", Token::In), - ("is", Token::Is), - ("isnt", Token::Isnt), ("loop", Token::Loop), ("match", Token::Match), ("module", Token::Module), ("not", Token::Not), ("or", Token::Or), + ("owned", Token::Owned), ("return", Token::Return), + ("ref", Token::Ref), + ("return", Token::Return), + ("shared", Token::Shared), ("then", Token::Then), ("trait", Token::Trait), ("type", Token::Type), @@ -590,7 +588,7 @@ impl<'cache, 'contents> Iterator for Lexer<'cache, 'contents> { // This will overflow if there are mismatched parenthesis, // should we handle this inside the lexer, // or leave that to the parsing stage? - self.open_braces.parenthesis -= 1; + self.open_braces.parenthesis = self.open_braces.parenthesis.saturating_sub(1); self.advance_with(Token::ParenthesisRight) }, ('+', _) => self.advance_with(Token::Add), @@ -599,7 +597,7 @@ impl<'cache, 'contents> Iterator for Lexer<'cache, 'contents> { self.advance_with(Token::BracketLeft) }, (']', _) => { - self.open_braces.square -= 1; + self.open_braces.square = self.open_braces.square.saturating_sub(1); self.advance_with(Token::BracketRight) }, ('|', _) => self.advance_with(Token::Pipe), diff --git a/src/lexer/token.rs b/src/lexer/token.rs index 307f520d..84cfd51a 100644 --- a/src/lexer/token.rs +++ b/src/lexer/token.rs @@ -85,20 +85,16 @@ pub enum Token { PointerType, BooleanType, UnitType, - Ref, Mut, // Keywords And, As, Block, - Break, - Continue, Do, Effect, Else, Extern, - For, Fn, Given, Handle, @@ -106,14 +102,15 @@ pub enum Token { Impl, Import, In, - Is, - Isnt, Loop, Match, Module, Not, Or, + Owned, + Ref, Return, + Shared, Then, Trait, Type, @@ -166,8 +163,6 @@ impl Token { And | As | At | In - | Is - | Isnt | Not | Or | EqualEqual @@ -186,7 +181,7 @@ impl Token { | LessThanOrEqual | GreaterThanOrEqual | Divide - | Ampersand + | Range ) } } @@ -271,20 +266,16 @@ impl Display for Token { Token::PointerType => write!(f, "'Ptr'"), Token::BooleanType => write!(f, "'bool'"), Token::UnitType => write!(f, "'unit'"), - Token::Ref => write!(f, "'ref'"), Token::Mut => write!(f, "'mut'"), // Keywords Token::And => write!(f, "'and'"), Token::As => write!(f, "'as'"), Token::Block => write!(f, "'block'"), - Token::Break => write!(f, "'break'"), - Token::Continue => write!(f, "'continue'"), Token::Do => write!(f, "'do'"), Token::Effect => write!(f, "'effect'"), Token::Else => write!(f, "'else'"), Token::Extern => write!(f, "'extern'"), - Token::For => write!(f, "'for'"), Token::Fn => write!(f, "'fn'"), Token::Given => write!(f, "'given'"), Token::Handle => write!(f, "'handle'"), @@ -292,14 +283,15 @@ impl Display for Token { Token::Impl => write!(f, "'impl'"), Token::Import => write!(f, "'import'"), Token::In => write!(f, "'in'"), - Token::Is => write!(f, "'is'"), - Token::Isnt => write!(f, "'isnt'"), Token::Loop => write!(f, "'loop'"), Token::Match => write!(f, "'match'"), Token::Module => write!(f, "'module'"), Token::Not => write!(f, "'not'"), Token::Or => write!(f, "'or'"), + Token::Owned => write!(f, "'owned'"), Token::Return => write!(f, "'return'"), + Token::Ref => write!(f, "'ref'"), + Token::Shared => write!(f, "'shared'"), Token::Then => write!(f, "'then'"), Token::Trait => write!(f, "'trait'"), Token::Type => write!(f, "'type'"), diff --git a/src/nameresolution/mod.rs b/src/nameresolution/mod.rs index af6a9519..ac67ec9d 100644 --- a/src/nameresolution/mod.rs +++ b/src/nameresolution/mod.rs @@ -589,7 +589,7 @@ impl<'c> NameResolver { Type::TypeVariable(_) => 0, Type::UserDefined(id) => cache[*id].args.len(), Type::TypeApplication(_, _) => 0, - Type::Ref(_) => 1, + Type::Ref(..) => 1, Type::Struct(_, _) => 0, Type::Effects(_) => 0, } @@ -739,14 +739,14 @@ impl<'c> NameResolver { Type::TypeApplication(Box::new(pair), args) }, - ast::Type::Reference(_) => { + ast::Type::Reference(sharednes, mutability, _) => { // When translating ref types, all have a hidden lifetime variable that is unified // under the hood by the compiler to determine the reference's stack lifetime. // This is never able to be manually specified by the programmer, so we use // next_type_variable_id on the cache rather than the NameResolver's version which // would add a name into scope. let lifetime_variable = cache.next_type_variable_id(self.let_binding_level); - Type::Ref(lifetime_variable) + Type::Ref(*sharednes, *mutability, lifetime_variable) }, } } diff --git a/src/parser/ast.rs b/src/parser/ast.rs index 23b51271..2aa8e402 100644 --- a/src/parser/ast.rs +++ b/src/parser/ast.rs @@ -33,6 +33,7 @@ use crate::types::typechecker::TypeBindings; use crate::types::{self, LetBindingLevel, TypeInfoId}; use std::borrow::Cow; use std::collections::{BTreeMap, HashMap, HashSet}; +use std::fmt::Display; use std::rc::Rc; #[derive(Clone, Debug, Eq, PartialOrd, Ord)] @@ -205,7 +206,7 @@ pub enum Type<'a> { Pointer(Location<'a>), Boolean(Location<'a>), Unit(Location<'a>), - Reference(Location<'a>), + Reference(Sharedness, Mutability, Location<'a>), Function(Vec>, Box>, /*varargs:*/ bool, /*closure*/ bool, Location<'a>), TypeVariable(String, Location<'a>), UserDefined(String, Location<'a>), @@ -213,6 +214,40 @@ pub enum Type<'a> { Pair(Box>, Box>, Location<'a>), } +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Sharedness { + Polymorphic, + Shared, + Owned, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Mutability { + Polymorphic, + Immutable, + Mutable, +} + +impl Display for Sharedness { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Sharedness::Polymorphic => Ok(()), + Sharedness::Shared => write!(f, "shared"), + Sharedness::Owned => write!(f, "owned"), + } + } +} + +impl Display for Mutability { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Mutability::Polymorphic => write!(f, "mut?"), + Mutability::Immutable => Ok(()), + Mutability::Mutable => write!(f, "mut"), + } + } +} + /// The AST representation of a trait usage. /// A trait's definition would be a TraitDefinition node. /// This struct is used in e.g. `given` to list the required traits. @@ -801,7 +836,7 @@ impl<'a> Locatable<'a> for Type<'a> { Type::Pointer(location) => *location, Type::Boolean(location) => *location, Type::Unit(location) => *location, - Type::Reference(location) => *location, + Type::Reference(_, _, location) => *location, Type::Function(_, _, _, _, location) => *location, Type::TypeVariable(_, location) => *location, Type::UserDefined(_, location) => *location, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 3fd28112..8b98530d 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -25,12 +25,14 @@ pub mod pretty_printer; use std::{collections::HashSet, iter::FromIterator}; -use crate::error::location::Location; use crate::lexer::token::Token; +use crate::{error::location::Location, parser::ast::Mutability}; use ast::{Ast, Trait, Type, TypeDefinitionBody}; use combinators::*; use error::{ParseError, ParseResult}; +use self::ast::Sharedness; + type AstResult<'a, 'b> = ParseResult<'a, 'b, Ast<'b>>; /// The entry point to parsing. Parses an entire file, printing any @@ -170,7 +172,7 @@ parser!(type_definition loc = name <- typename; args <- many0(identifier); _ <- expect(Token::Equal); - body !<- type_definition_body; + body <- type_definition_body; Ast::type_definition(name, args, body, loc) ); @@ -178,8 +180,8 @@ parser!(type_alias loc = _ <- expect(Token::Type); name <- typename; args <- many0(identifier); - _ <- expect(Token::Is); - body !<- parse_type; + _ <- expect(Token::Equal); + body <- parse_type; Ast::type_definition(name, args, TypeDefinitionBody::Alias(body), loc) ); @@ -370,8 +372,6 @@ fn precedence(token: &Token) -> Option<(i8, bool)> { Token::Or => Some((4, false)), Token::And => Some((5, false)), Token::EqualEqual - | Token::Is - | Token::Isnt | Token::NotEqual | Token::GreaterThan | Token::LessThan @@ -589,7 +589,7 @@ parser!(type_annotation loc = ); fn parse_type<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, Type<'b>> { - or(&[function_type, pair_type, type_application, basic_type], "type")(input) + or(&[function_type, pair_type, reference_type, type_application, basic_type], "type")(input) } fn function_arg_type<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, Type<'b>> { @@ -597,7 +597,7 @@ fn function_arg_type<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, Type<' } fn parse_type_no_pair<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, Type<'b>> { - or(&[function_type, type_application, basic_type], "type")(input) + or(&[function_type, reference_type, type_application, basic_type], "type")(input) } fn basic_type<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, Type<'b>> { @@ -611,7 +611,7 @@ fn basic_type<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, Type<'b>> { Token::PointerType => pointer_type(input), Token::BooleanType => boolean_type(input), Token::UnitType => unit_type(input), - Token::Ref => reference_type(input), + Token::Ampersand => basic_reference_type(input), Token::Identifier(_) => type_variable(input), Token::TypeName(_) => user_defined_type(input), Token::ParenthesisLeft => parenthesized_type(input), @@ -861,10 +861,44 @@ parser!(unit_type loc -> 'b Type<'b> = ); parser!(reference_type loc -> 'b Type<'b> = - _ <- expect(Token::Ref); - Type::Reference(loc) + _ <- expect(Token::Ampersand); + sharedness <- sharedness; + mutability <- maybe(expect(Token::Mut)); + element <- maybe(reference_element_type); + { + let mutability = if mutability.is_some() { Mutability::Mutable } else { Mutability::Immutable }; + make_reference_type(Type::Reference(sharedness, mutability, loc), element, loc) + } +); + +// The basic reference type `&t` can be used without parenthesis in a type application +parser!(basic_reference_type loc -> 'b Type<'b> = + _ <- expect(Token::Ampersand); + element <- maybe(basic_type); + make_reference_type(Type::Reference(Sharedness::Polymorphic, Mutability::Immutable, loc), element, loc) +); + +parser!(reference_element_type loc -> 'b Type<'b> = + typ <- or(&[type_application, basic_type], "type"); + typ ); +fn make_reference_type<'b>(reference: Type<'b>, element: Option>, loc: Location<'b>) -> Type<'b> { + match element { + Some(element) => Type::TypeApplication(Box::new(reference), vec![element], loc), + None => reference, + } +} + +// Parses 'owned' or 'shared' on a reference type +fn sharedness<'a, 'b>(input: Input<'a, 'b>) -> ParseResult<'a, 'b, Sharedness> { + match input[0].0 { + Token::Shared => Ok((&input[1..], Sharedness::Shared, input[0].1)), + Token::Owned => Ok((&input[1..], Sharedness::Owned, input[0].1)), + _ => Ok((input, Sharedness::Polymorphic, input[0].1)), + } +} + parser!(type_variable loc -> 'b Type<'b> = name <- identifier; Type::TypeVariable(name, loc) diff --git a/src/parser/pretty_printer.rs b/src/parser/pretty_printer.rs index 097f7c54..251eff91 100644 --- a/src/parser/pretty_printer.rs +++ b/src/parser/pretty_printer.rs @@ -1,7 +1,7 @@ //! Defines a simple pretty printer to print the Ast to stdout. //! Used for the golden tests testing parsing to ensure there //! are no parsing regressions. -use crate::parser::ast::{self, Ast}; +use crate::parser::ast::{self, Ast, Sharedness}; use crate::util::{fmap, join_with}; use std::fmt::{self, Display, Formatter}; use std::sync::atomic::AtomicUsize; @@ -98,7 +98,10 @@ impl<'a> Display for ast::Type<'a> { Pointer(_) => write!(f, "Ptr"), Boolean(_) => write!(f, "Bool"), Unit(_) => write!(f, "Unit"), - Reference(_) => write!(f, "Ref"), + Reference(shared, mutable, _) => { + let space = if *shared == Sharedness::Polymorphic { "" } else { " " }; + write!(f, "&{shared}{space}{mutable}") + }, TypeVariable(name, _) => write!(f, "{}", name), UserDefined(name, _) => write!(f, "{}", name), Function(params, return_type, varargs, is_closure, _) => { diff --git a/src/types/mod.rs b/src/types/mod.rs index b0c442aa..05f7422c 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -9,6 +9,7 @@ use std::collections::BTreeMap; use crate::cache::{DefinitionInfoId, ModuleCache}; use crate::error::location::{Locatable, Location}; use crate::lexer::token::{FloatKind, IntegerKind}; +use crate::parser::ast::{Mutability, Sharedness}; use crate::util::fmap; use crate::{lifetimes, util}; @@ -102,7 +103,7 @@ pub enum Type { /// A region-allocated reference to some data. /// Contains a region variable that is unified with other refs during type /// inference. All these refs will be allocated in the same region. - Ref(lifetimes::LifetimeVariableId), + Ref(Sharedness, Mutability, lifetimes::LifetimeVariableId), /// A (row-polymorphic) struct type. Unlike normal rho variables, /// the type variable used here replaces the entire type if bound. @@ -190,7 +191,7 @@ impl Type { use Type::*; match self { Primitive(_) => None, - Ref(_) => None, + Ref(..) => None, Function(function) => function.return_type.union_constructor_variants(cache), TypeApplication(typ, _) => typ.union_constructor_variants(cache), UserDefined(id) => cache.type_infos[id.0].union_variants(), @@ -231,7 +232,7 @@ impl Type { function.environment.traverse_rec(cache, f); function.return_type.traverse_rec(cache, f); }, - Type::TypeVariable(id) | Type::Ref(id) => match &cache.type_bindings[id.0] { + Type::TypeVariable(id) | Type::Ref(_, _, id) => match &cache.type_bindings[id.0] { TypeBinding::Bound(binding) => binding.traverse_rec(cache, f), TypeBinding::Unbound(_, _) => (), }, @@ -273,7 +274,7 @@ impl Type { Type::Primitive(_) => (), Type::UserDefined(_) => (), Type::TypeVariable(_) => (), - Type::Ref(_) => (), + Type::Ref(..) => (), Type::Function(function) => { for parameter in &function.parameters { @@ -324,7 +325,7 @@ impl Type { let args = fmap(args, |arg| arg.approx_to_string()); format!("({} {})", constructor, args.join(" ")) }, - Type::Ref(id) => format!("(ref tv{})", id.0), + Type::Ref(shared, mutable, id) => format!("&'{} {}{}", id.0, shared, mutable), Type::Struct(fields, id) => { let fields = fmap(fields, |(name, typ)| format!("{}: {}", name, typ.approx_to_string())); format!("{{ {}, ..tv{} }}", fields.join(", "), id.0) diff --git a/src/types/typechecker.rs b/src/types/typechecker.rs index 04e8fa7c..6d6853e1 100644 --- a/src/types/typechecker.rs +++ b/src/types/typechecker.rs @@ -30,7 +30,7 @@ use crate::cache::{DefinitionInfoId, DefinitionKind, EffectInfoId, ModuleCache, use crate::cache::{ImplScopeId, VariableId}; use crate::error::location::{Locatable, Location}; use crate::error::{Diagnostic, DiagnosticKind as D, TypeErrorKind, TypeErrorKind as TE}; -use crate::parser::ast::{self, ClosureEnvironment}; +use crate::parser::ast::{self, ClosureEnvironment, Mutability, Sharedness}; use crate::types::traits::{RequiredTrait, TraitConstraint, TraitConstraints}; use crate::types::typed::Typed; use crate::types::EffectSet; @@ -154,7 +154,7 @@ pub fn type_application_bindings(info: &TypeInfo<'_>, typeargs: &[Type], cache: /// Given `a` returns `ref a` fn ref_of(typ: Type, cache: &mut ModuleCache) -> Type { let new_var = next_type_variable_id(cache); - let constructor = Box::new(Type::Ref(new_var)); + let constructor = Box::new(Type::Ref(Sharedness::Polymorphic, Mutability::Polymorphic, new_var)); TypeApplication(constructor, vec![typ]) } @@ -204,10 +204,13 @@ pub fn replace_all_typevars_with_bindings( UserDefined(id) => UserDefined(*id), // We must recurse on the lifetime variable since they are unified as normal type variables - Ref(lifetime) => match replace_typevar_with_binding(*lifetime, new_bindings, Ref, cache) { - TypeVariable(new_lifetime) => Ref(new_lifetime), - Ref(new_lifetime) => Ref(new_lifetime), - _ => unreachable!("Bound Ref lifetime to non-lifetime type"), + Ref(sharedness, mutability, lifetime) => { + let make_ref = |new_lifetime| Ref(*sharedness, *mutability, new_lifetime); + match replace_typevar_with_binding(*lifetime, new_bindings, make_ref, cache) { + TypeVariable(new_lifetime) => make_ref(new_lifetime), + new_ref @ Ref(..) => new_ref, + _ => unreachable!("Bound Ref lifetime to non-lifetime type"), + } }, TypeApplication(typ, args) => { @@ -240,7 +243,7 @@ pub fn replace_all_typevars_with_bindings( /// `default` should be either TypeVariable or Ref and controls which kind of type gets /// created that wraps the newly-instantiated TypeVariableId if one is made. fn replace_typevar_with_binding( - id: TypeVariableId, new_bindings: &mut TypeBindings, default: fn(TypeVariableId) -> Type, + id: TypeVariableId, new_bindings: &mut TypeBindings, default: impl FnOnce(TypeVariableId) -> Type, cache: &mut ModuleCache<'_>, ) -> Type { if let Bound(typ) = &cache.type_bindings[id.0] { @@ -249,8 +252,9 @@ fn replace_typevar_with_binding( var.clone() } else { let new_typevar = next_type_variable_id(cache); - new_bindings.insert(id, default(new_typevar)); - default(new_typevar) + let typ = default(new_typevar); + new_bindings.insert(id, typ.clone()); + typ } } @@ -275,10 +279,13 @@ pub fn bind_typevars(typ: &Type, type_bindings: &TypeBindings, cache: &ModuleCac }, UserDefined(id) => UserDefined(*id), - Ref(lifetime) => match bind_typevar(*lifetime, type_bindings, Ref, cache) { - TypeVariable(new_lifetime) => Ref(new_lifetime), - Ref(new_lifetime) => Ref(new_lifetime), - _ => unreachable!("Bound Ref lifetime to non-lifetime type"), + Ref(sharedness, mutability, lifetime) => { + let make_ref = |lifetime| Ref(*sharedness, *mutability, lifetime); + match bind_typevar(*lifetime, type_bindings, make_ref, cache) { + TypeVariable(new_lifetime) => make_ref(new_lifetime), + new_ref @ Ref(..) => new_ref, + _ => unreachable!("Bound Ref lifetime to non-lifetime type"), + } }, TypeApplication(typ, args) => { @@ -322,7 +329,8 @@ pub fn bind_typevars(typ: &Type, type_bindings: &TypeBindings, cache: &ModuleCac /// and it is found in the type_bindings. If a type_binding wasn't found, a /// default TypeVariable or Ref is constructed by passing the relevant constructor to `default`. fn bind_typevar( - id: TypeVariableId, type_bindings: &TypeBindings, default: fn(TypeVariableId) -> Type, cache: &ModuleCache<'_>, + id: TypeVariableId, type_bindings: &TypeBindings, default: impl FnOnce(TypeVariableId) -> Type, + cache: &ModuleCache<'_>, ) -> Type { // TODO: This ordering of checking type_bindings first is important. // There seems to be an issue currently where forall-bound variables @@ -356,7 +364,7 @@ pub fn contains_any_typevars_from_list(typ: &Type, list: &[TypeVariableId], cach || contains_any_typevars_from_list(&function.effects, list, cache) }, - Ref(lifetime) => type_variable_contains_any_typevars_from_list(*lifetime, list, cache), + Ref(_, _, lifetime) => type_variable_contains_any_typevars_from_list(*lifetime, list, cache), TypeApplication(typ, args) => { contains_any_typevars_from_list(typ, list, cache) @@ -555,7 +563,7 @@ pub(super) fn occurs( .then_all(&function.parameters, |param| occurs(id, level, param, bindings, fuel, cache)), TypeApplication(typ, args) => occurs(id, level, typ, bindings, fuel, cache) .then_all(args, |arg| occurs(id, level, arg, bindings, fuel, cache)), - Ref(lifetime) => typevars_match(id, level, *lifetime, bindings, fuel, cache), + Ref(_, _, lifetime) => typevars_match(id, level, *lifetime, bindings, fuel, cache), Struct(fields, var_id) => typevars_match(id, level, *var_id, bindings, fuel, cache) .then_all(fields.iter().map(|(_, typ)| typ), |field| occurs(id, level, field, bindings, fuel, cache)), Effects(effects) => effects.occurs(id, level, bindings, fuel, cache), @@ -582,7 +590,7 @@ pub(super) fn typevars_match( /// Returns what a given type is bound to, following all typevar links until it reaches an Unbound one. pub fn follow_bindings_in_cache_and_map(typ: &Type, bindings: &UnificationBindings, cache: &ModuleCache<'_>) -> Type { match typ { - TypeVariable(id) | Ref(id) => match find_binding(*id, bindings, cache) { + TypeVariable(id) | Ref(_, _, id) => match find_binding(*id, bindings, cache) { Bound(typ) => follow_bindings_in_cache_and_map(&typ, bindings, cache), Unbound(..) => typ.clone(), }, @@ -592,7 +600,7 @@ pub fn follow_bindings_in_cache_and_map(typ: &Type, bindings: &UnificationBindin pub fn follow_bindings_in_cache(typ: &Type, cache: &ModuleCache<'_>) -> Type { match typ { - TypeVariable(id) | Ref(id) => match &cache.type_bindings[id.0] { + TypeVariable(id) | Ref(_, _, id) => match &cache.type_bindings[id.0] { Bound(typ) => follow_bindings_in_cache(typ, cache), Unbound(..) => typ.clone(), }, @@ -664,7 +672,15 @@ pub fn try_unify_with_bindings_inner<'b>( }, // Refs have a hidden lifetime variable we need to unify here - (Ref(a_lifetime), Ref(_)) => { + (Ref(shared1, mut1, a_lifetime), Ref(shared2, mut2, _)) => { + if shared1 != shared2 || mut1 != mut2 { + if *shared1 != Sharedness::Polymorphic && *shared2 != Sharedness::Polymorphic { + if *mut1 != Mutability::Polymorphic && *mut2 != Mutability::Polymorphic { + return Err(()); + } + } + } + try_unify_type_variable_with_bindings(*a_lifetime, t1, t2, bindings, location, cache) }, @@ -816,7 +832,7 @@ fn get_fields( } }, TypeApplication(constructor, args) => match follow_bindings_in_cache_and_map(constructor, bindings, cache) { - Ref(_) => get_fields(&args[0], &[], bindings, cache), + Ref(..) => get_fields(&args[0], &[], bindings, cache), other => get_fields(&other, args, bindings, cache), }, Struct(fields, rest) => match &cache.type_bindings[rest.0] { @@ -960,7 +976,7 @@ pub fn find_all_typevars(typ: &Type, polymorphic_only: bool, cache: &ModuleCache } type_variables }, - Ref(lifetime) => find_typevars_in_typevar_binding(*lifetime, polymorphic_only, cache), + Ref(_, _, lifetime) => find_typevars_in_typevar_binding(*lifetime, polymorphic_only, cache), Struct(fields, id) => match &cache.type_bindings[id.0] { Bound(t) => find_all_typevars(t, polymorphic_only, cache), Unbound(..) => { @@ -1675,7 +1691,9 @@ impl<'a> Inferable<'a> for ast::Definition<'a> { let mut result = infer(self.expr.as_mut(), cache); if self.mutable { let lifetime = next_type_variable_id(cache); - result.typ = Type::TypeApplication(Box::new(Type::Ref(lifetime)), vec![result.typ]); + let shared = Sharedness::Polymorphic; + let mutability = Mutability::Mutable; + result.typ = Type::TypeApplication(Box::new(Type::Ref(shared, mutability, lifetime)), vec![result.typ]); } // The rhs of a Definition must be inferred at a greater LetBindingLevel than @@ -1942,7 +1960,8 @@ impl<'a> Inferable<'a> for ast::Assignment<'a> { result.combine(&mut rhs, cache); let lifetime = next_type_variable_id(cache); - let mutref = Type::TypeApplication(Box::new(Type::Ref(lifetime)), vec![rhs.typ.clone()]); + let mut_ref = Type::Ref(Sharedness::Polymorphic, Mutability::Mutable, lifetime); + let mutref = Type::TypeApplication(Box::new(mut_ref), vec![rhs.typ.clone()]); match try_unify(&result.typ, &mutref, self.location, cache, TE::NeverShown) { Ok(bindings) => bindings.perform(cache), @@ -1959,7 +1978,8 @@ fn issue_assignment_error<'c>( // Try to offer a more specific error message let lifetime = next_type_variable_id(cache); let var = next_type_variable(cache); - let mutref = Type::TypeApplication(Box::new(Type::Ref(lifetime)), vec![var]); + let mutref = Type::Ref(Sharedness::Polymorphic, Mutability::Mutable, lifetime); + let mutref = Type::TypeApplication(Box::new(mutref), vec![var]); if let Err(msg) = try_unify(&mutref, lhs, lhs_loc, cache, TE::AssignToNonMutRef) { cache.push_full_diagnostic(msg); diff --git a/src/types/typeprinter.rs b/src/types/typeprinter.rs index 4d50606d..0b97a484 100644 --- a/src/types/typeprinter.rs +++ b/src/types/typeprinter.rs @@ -4,6 +4,7 @@ //! types/traits are displayed via `type.display(cache)` rather than directly having //! a Display impl. use crate::cache::{ModuleCache, TraitInfoId}; +use crate::parser::ast::{Mutability, Sharedness}; use crate::types::traits::{ConstraintSignature, ConstraintSignaturePrinter, RequiredTrait, TraitConstraintId}; use crate::types::typechecker::find_all_typevars; use crate::types::{FunctionType, PrimitiveType, Type, TypeBinding, TypeInfoId, TypeVariableId}; @@ -165,7 +166,7 @@ impl<'a, 'b> TypePrinter<'a, 'b> { Type::TypeVariable(id) => self.fmt_type_variable(*id, f), Type::UserDefined(id) => self.fmt_user_defined_type(*id, f), Type::TypeApplication(constructor, args) => self.fmt_type_application(constructor, args, f), - Type::Ref(lifetime) => self.fmt_ref(*lifetime, f), + Type::Ref(shared, mutable, lifetime) => self.fmt_ref(*shared, *mutable, *lifetime, f), Type::Struct(fields, rest) => self.fmt_struct(fields, *rest, f), Type::Effects(effects) => self.fmt_effects(effects, f), } @@ -277,11 +278,17 @@ impl<'a, 'b> TypePrinter<'a, 'b> { } } - fn fmt_ref(&self, lifetime: TypeVariableId, f: &mut Formatter) -> std::fmt::Result { + fn fmt_ref( + &self, shared: Sharedness, mutable: Mutability, lifetime: TypeVariableId, f: &mut Formatter, + ) -> std::fmt::Result { match &self.cache.type_bindings[lifetime.0] { TypeBinding::Bound(typ) => self.fmt_type(typ, f), TypeBinding::Unbound(..) => { - write!(f, "{}", "ref".blue())?; + let shared = shared.to_string(); + let mutable = mutable.to_string(); + let space = if shared.is_empty() { "" } else { " " }; + + write!(f, "{}{}{}{}", "&".blue(), shared.blue(), space, mutable.blue())?; if self.debug { match self.typevar_names.get(&lifetime) { diff --git a/stdlib/HashMap.an b/stdlib/HashMap.an index 9a94f4d4..7f208db8 100644 --- a/stdlib/HashMap.an +++ b/stdlib/HashMap.an @@ -15,14 +15,14 @@ trait Hash t with empty () = HashMap 0 0 (null ()) -clear (map: Ref (HashMap k v)) : Unit = +clear (map: &mut HashMap k v) : Unit = if map.capacity != 0 then repeat map.capacity fn i -> entry = mut deref_ptr <| offset map.entries i entry.&occupied := false entry.&tombstone := false -resize (map: Ref (HashMap k v)) (new_capacity: Usz) : Unit = +resize (map: &mut HashMap k v) (new_capacity: Usz) : Unit = if new_capacity > map.capacity then new_memory = calloc new_capacity (size_of (MkType : Type (Entry k v))) @@ -44,16 +44,16 @@ should_resize (map: HashMap k v) : Bool = scale_factor = 2 (map.len + 1) * scale_factor > map.capacity -insert (map: Ref (HashMap k v)) (key: k) (value: v) : Unit = +insert (map: &mut HashMap k v) (key: k) (value: v) : Unit = if should_resize @map then resize map ((map.capacity + 1) * 2) - iter_until (map: Ref (HashMap k v)) (key: k) (value: v) (start: Usz) (end: Usz) : Bool = + iter_until (map: &mut HashMap k v) (key: k) (value: v) (start: Usz) (end: Usz) : Bool = if start >= end then false else entry_ptr = offset (map.entries) start - entry = transmute entry_ptr : Ref (Entry k v) + entry = transmute entry_ptr : &mut Entry k v if entry.occupied then iter_until map key value (start + 1) end else @@ -105,7 +105,7 @@ get (map: HashMap k v) (key: k) : Maybe v = | None -> None -remove (map: ref (HashMap k v)) (key: k) : Maybe v = +remove (map: &mut HashMap k v) (key: k) : Maybe v = match get_entry (@map) key | None -> None | Some e2 -> diff --git a/stdlib/StringBuilder.an b/stdlib/StringBuilder.an index 5bb49b7f..d21fbfa1 100644 --- a/stdlib/StringBuilder.an +++ b/stdlib/StringBuilder.an @@ -6,7 +6,7 @@ type StringBuilder = empty () = StringBuilder (null ()) 0 0 // reserve space for at least n additional characters -reserve (s: Ref StringBuilder) (n: Usz) : Unit = +reserve (s: &mut StringBuilder) (n: Usz) : Unit = if s.length + n > s.cap then new_size = s.cap + n ptr = realloc s.data new_size @@ -19,15 +19,15 @@ reserve (s: Ref StringBuilder) (n: Usz) : Unit = s.&cap := new_size // append a string -append (s: Ref StringBuilder) (new_str: String) : Unit = +append (s: &mut StringBuilder) (new_str: String) : Unit = reserve s new_str.length memcpy (cast (cast s.data + s.length)) new_str.c_string (cast (new_str.length+1)) s.&length := s.length + new_str.length // convert to string -to_string (s: Ref StringBuilder) : String = +to_string (s: StringBuilder) : String = String (s.data) (s.length) impl Print StringBuilder with - printne (s: Ref StringBuilder) : Unit = + printne (s: StringBuilder) : Unit = print (to_string s) diff --git a/stdlib/Vec.an b/stdlib/Vec.an index f50a9af7..583b77d9 100644 --- a/stdlib/Vec.an +++ b/stdlib/Vec.an @@ -18,12 +18,12 @@ len v = v.len capacity v = v.cap //Fill Vec with items from the iterable -fill (v: Ref (Vec t)) iterable : Ref (Vec t) = +fill (v: &mut Vec t) iterable : &mut Vec t = iter iterable (push v _) v //reserve numElements in Vec v, elements will be uninitialized -reserve (v: Ref (Vec t)) (numElems: Usz) : Unit = +reserve (v: &mut Vec t) (numElems: Usz) : Unit = if v.len + numElems > v.cap then size = (v.cap + numElems) * size_of (MkType: Type t) ptr = realloc (v.data) size @@ -37,7 +37,7 @@ reserve (v: Ref (Vec t)) (numElems: Usz) : Unit = //push an element onto the end of the vector. //resizes if necessary -push (v: Ref (Vec t)) (elem: t) : Unit = +push (v: &mut Vec t) (elem: t) : Unit = if v.len >= v.cap then reserve v (if v.cap == 0usz then 1 else v.cap) @@ -46,7 +46,7 @@ push (v: Ref (Vec t)) (elem: t) : Unit = //pop the last element off if it exists //this will never resize the vector. -pop (v: Ref (Vec t)) : Maybe t = +pop (v: &mut Vec t) : Maybe t = if is_empty v then None else v.&len := v.len - 1 @@ -54,7 +54,7 @@ pop (v: Ref (Vec t)) : Maybe t = //remove the element at the given index and return it. //will error if the index is out of bounds. -remove_index (v: Ref (Vec t)) (idx:Usz) : t = +remove_index (v: &mut Vec t) (idx:Usz) : t = if idx == v.len - 1 then v.&len := v.len - 1 else if idx >= 0 and idx < v.len - 1 then @@ -71,7 +71,7 @@ remove_index (v: Ref (Vec t)) (idx:Usz) : t = //the vector or none if the element was not found. //Uses == to determine element equality. //returns the index where the element was found. -remove_first (v: Ref (Vec t)) (elem: t) : Maybe Usz = +remove_first (v: &mut Vec t) (elem: t) : Maybe Usz = loop (i = 0) -> if i >= v.len then None @@ -83,7 +83,7 @@ remove_first (v: Ref (Vec t)) (elem: t) : Maybe Usz = //Remove the given indices from the vector //Expects the indices to be in sorted order. //Will error if any index is out of bounds. -remove_indices (v: Ref (Vec t)) (idxs: Vec Usz) : Unit = +remove_indices (v: &mut Vec t) (idxs: Vec Usz) : Unit = moved = mut 0 iter (indices idxs) fn i -> cur = idxs.data#i @@ -105,7 +105,7 @@ remove_indices (v: Ref (Vec t)) (idxs: Vec Usz) : Unit = //remove all matching elements from the vector and //return the number of elements removed. //Uses = to determine element equality. -remove_all (v: Ref (Vec t)) (elem: t) : Usz = +remove_all (v: &mut Vec t) (elem: t) : Usz = idxs = mut empty () iter (indices @v) fn i -> if elem == v.data#i then @@ -118,7 +118,7 @@ remove_all (v: Ref (Vec t)) (elem: t) : Usz = //Remove an element by swapping it with the last element in O(1) time. //Returns true if a swap was performed or false otherwise. //Will not swap if the given index is the index of the last element. -swap_last (v: Ref (Vec t)) (idx:Usz) : Bool = +swap_last (v: &mut Vec t) (idx:Usz) : Bool = if idx >= v.len or idx + 1 == v.len then false else v.&len := v.len - 1 diff --git a/stdlib/prelude.an b/stdlib/prelude.an index 4f735f34..a65d4ead 100644 --- a/stdlib/prelude.an +++ b/stdlib/prelude.an @@ -321,20 +321,20 @@ offset (ptr: Ptr t) (index: Usz) : Ptr t = // new_addr = addr + index * size_of (MkType: Type t) // transmute new_addr -deref (x: Ref t) : t = +deref (x: &t) : t = builtin "Deref" x deref_ptr (p: Ptr t) : t = deref <| transmute p ptr_store (p: Ptr a) (value: a) : Unit = - addr: Ref a = transmute p + addr: &a = transmute p addr := value array_insert (p: Ptr a) (index: Usz) (value: a) : Unit = offset p index |> ptr_store value -ptr_to_ref: Ptr a -> Ref a = transmute +ptr_to_ref: Ptr a -> &a = transmute (@) = deref @@ -416,7 +416,7 @@ impl Print (Maybe a) given Print a with | Some a -> printne a | None -> printne "None" -impl Print (Ref a) given Print a with +impl Print &a given Print a with printne r = printne @r unwrap (m: Maybe t) : t = @@ -517,7 +517,7 @@ with // // Some (if isNeg then -1 * sum else sum) -impl Eq (Ref t) given Eq t with +impl Eq &t given Eq t with (==) l r = deref l == deref r // This `given` clause is required by the current trait checker since t is generalized @@ -559,7 +559,7 @@ impl Iterator InFile String with else Some (infile, next_line infile) // TODO: manually construct a String from parts -// impl Cast (Ref Char) String with +// impl Cast &Char String with // cast c_string = String c_string (cast (strlen c_string)) // impl Eq String