diff --git a/Compiler/Ast.fir b/Compiler/Ast.fir index 6337d053..fa422f17 100644 --- a/Compiler/Ast.fir +++ b/Compiler/Ast.fir @@ -1,6 +1,7 @@ import [ Compiler/Defs, Compiler/Token, + Compiler/TypeCheck/Ty, ] @@ -98,18 +99,6 @@ type ConDecl( ) -#[derive(ToDoc)] -type Kind: - Star - Row(RecordOrVariant) - - -#[derive(ToDoc)] -type RecordOrVariant: - Record - Variant - - #[derive(ToDoc)] type ConFields: Empty @@ -194,10 +183,10 @@ type TraitDecl( ## ## ``` ## impl[ToStr[a]] ToStr[Vec[a]]: -## toStr(self): Str = ... +## toStr(self) Str: ... ## ## impl Iterator[VecIter[a], a]: -## next(self): Option[a] = ... +## next(self) Option[a]: ... ## ``` #[derive(ToDoc)] type ImplDecl( @@ -384,6 +373,12 @@ type NamedType( ) +NamedType.numTyArgs(self) U32: + match self.args: + Option.None: 0 + Option.Some(args): args.args.len() + + impl Tokens[NamedType]: firstToken(self: NamedType) TokenIdx: self.name.token @@ -499,14 +494,12 @@ type ForStmt( astTy: Option[Type], ## `ast_ty`, converted to type checking types by the type checker. - tcTy: Option[Type], # TODO: type-checking type - + tcTy: Option[Ty], expr: Expr, ## Filled in by the type checker: the iterator type. `iter` in `Iterator[iter, item]`. - exprTy: Option[Type], # TODO: type-checking type - + exprTy: Option[Ty], body: Vec[Stmt], @@ -651,8 +644,7 @@ type VarPat( var_: Id, ## Inferred type of the binder. Filled in by the type checker. - ty: Option[Type], # TODO: This should be type-checking type instead of AST type - + ty: Option[Ty], ) @@ -685,9 +677,7 @@ impl Tokens[ConstrPat]: type RecordPat( fields: Vec[Named[Pat]], ignoreRest: Bool, - - # TODO: This should be type-checking type instead of AST type - inferredTy: Option[Type], + inferredTy: Option[Ty], _firstToken: TokenIdx, _lastToken: TokenIdx, ) @@ -947,8 +937,7 @@ type MethodSelectExpr( # TODO: We could also add types to every expression if it's going to help with monomorphisation. # For efficiency though, we should only annotate inferred types and then type check from # the top-level expression every time we need to compute type of an expr. - # TODO: This should be a type-checking type. - objectTy: Option[Type], + objectTy: Option[Ty], ## The type or trait id that defines the method. ## @@ -964,14 +953,13 @@ type MethodSelectExpr( ## Type arguments of `method_ty_id`. ## - ## If the method is for a trait, the first arguments here will be for the trait type parameters. - ## E.g. in `Iterator.next`, the first two argumetns will be the `iter` and `item` parameters of - ## `trait Iterator[iter, item]`. + ## If the method is for a trait, the first arguments here will be for the + ## trait type parameters. E.g. in `Iterator.next`, the first two argumetns + ## will be the `iter` and `item` parameters of `trait Iterator[iter, item]`. ## - ## (If the method is not a trait method, then we don't care about the type parameter order.. I - ## think?) - tyArgs: Vec[Type], # TODO: type-checking type - + ## (If the method is not a trait method, then we don't care about the type + ## parameter order.. I think?) + tyArgs: Vec[Ty], _lastToken: TokenIdx, ) @@ -1001,9 +989,9 @@ type AssocFnSelectExpr( ## Type arguments explicitly passed to the variable. userTyArgs: Option[TyArgs], - ## Inferred type arguments of the type and associated function. Filled in by the type checker. - tyArgs: Vec[Type], # TODO: type-checking type - + ## Inferred type arguments of the type and associated function. Filled in by + ## the type checker. + tyArgs: Vec[Ty], _lastToken: TokenIdx, ) @@ -1506,24 +1494,6 @@ impl ToDoc[ConDecl]: Doc.grouped(Doc.str("ConDecl") + Doc.char('(') + args) -impl ToDoc[Kind]: - toDoc(self: Kind) Doc: - match self: - Kind.Star: Doc.str("Kind.Star") - Kind.Row(i0): - let args = Doc.break_(0) - args += i0.toDoc() - args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') - Doc.grouped(Doc.str("Kind.Row") + Doc.char('(') + args) - - -impl ToDoc[RecordOrVariant]: - toDoc(self: RecordOrVariant) Doc: - match self: - RecordOrVariant.Record: Doc.str("RecordOrVariant.Record") - RecordOrVariant.Variant: Doc.str("RecordOrVariant.Variant") - - impl ToDoc[ConFields]: toDoc(self: ConFields) Doc: match self: diff --git a/Compiler/Defs.fir b/Compiler/Defs.fir index d4f9e8d9..ffca8a38 100644 --- a/Compiler/Defs.fir +++ b/Compiler/Defs.fir @@ -1,26 +1,52 @@ ## Defines identifiers, packages, modules, module-level items. -import [Compiler/Ast] +import [ + Compiler/Ast, + Compiler/Module, + Compiler/TypeCheck/TyCon, +] # ------------------------------------------------------------------------------ # Identifiers +## Local ids refer to function and type -local things, they cannot refer to +## top-level things. +## +## This type is used both in binder position (e.g. function arguments, patterns, +## type arguments) and use position (e.g. expressions, type variables). +#[derive(ToDoc)] +type LocalId( + name: Str, +) + + +impl Eq[LocalId]: + __eq(self: LocalId, other: LocalId) Bool: + self.name == other.name + + +impl Hash[LocalId]: + hash(self: LocalId) U32: + self.name.hash() + + ## A term variable id. ## ## This can refer to: local variables, top-level variables, associated ## functions. #[derive(ToDoc)] type VarId( + ## Index of the identifier's token in its module. token: TokenIdx, - _resolved: Option[VarDefIdx], -) + ## Text of the token at `token`. + name: Str, -VarId.new(token: TokenIdx) VarId: - VarId(token, _resolved = Option.None) + _resolved: Option[VarDefIdx], +) VarId.resolve(self, def: VarDefIdx): @@ -30,7 +56,12 @@ VarId.resolve(self, def: VarDefIdx): #[derive(ToDoc)] type AssocVarId( + ## Index of the associated member identifier's token in its module. token: TokenIdx, + + ## Text of the token at `token`. + name: Str, + _resolved: Option[AssocVarDefIdx], ) @@ -40,13 +71,14 @@ type AssocVarId( ## These can refer to: product and sum types, traits. #[derive(ToDoc)] type TyId( + ## Index of the type identifier's token in its module. token: TokenIdx, - _resolved: Option[TyDefIdx], -) + ## Text of the token at `token`. + name: Str, -TyId.new(token: TokenIdx) TyId: - TyId(token, _resolved = Option.None) + _resolved: Option[TyDefIdx], +) TyId.resolve(self, def: TyDefIdx): @@ -54,84 +86,8 @@ TyId.resolve(self, def: TyDefIdx): self._resolved = Option.Some(def) -# ------------------------------------------------------------------------------ -# Packages and modules - - -#[derive(ToDoc)] -type PackageUri( - _uri: Str, -) - - -#[derive(ToDoc)] -type ModuleIdx( - ## Index of the module in the program. - _idx: U32, -) - - -impl Eq[ModuleIdx]: - __eq(self: ModuleIdx, other: ModuleIdx) Bool: - self._idx == other._idx - - -impl Hash[ModuleIdx]: - hash(self: ModuleIdx) U32: - self._idx - - -#[derive(ToDoc)] -type Module( - ## The module's package. - _package: PackageUri, - - ## File path of the module's file. - # TODO: Is this absolute? Or relative? If relative, relative to what? - # For now we only use this for errors so it doesn't matter too much. - _filePath: Str, - - ## Path of the module in its package. - _path: Vec[Str], - - _idx: ModuleIdx, - - ## Functions defined in the module. - _funItems: HashMap[Str, FunDecl], - - ## Associated functions (including methods, but not trait methods) defined - ## in the module. - _assocItems: HashMap[Str, HashMap[Str, FunDecl]], - - ## Types defined in the module. - _tyItems: HashMap[Str, TypeDecl], - - ## Traits defined in the module. - _traitItems: HashMap[Str, TraitDecl], - - ## Impls defined in the module. - _impls: Vec[ImplDecl], - - ## Imported modules. These are used when resolving identifiers. - _imports: Vec[ModuleIdx], - - ## Tokens of the module. ASTs in the other fields for definitions in the - ## current module refer to tokens in this array. - _tokens: Array[Token], - - ## Term environment of the module, used for name resolving. - _termEnv: HashMap[Str, VarDefIdx], - - ## Associated item environment of the module, used for name resolving. - _assocTermEnv: HashMap[Str, HashMap[Str, AssocVarDefIdx]], - - ## Type environment of the module, used for name resolving. - _tyEnv: HashMap[Str, TyDefIdx], - - ## Strongly connected component index of the module in the program. - ## Generated by the dependency analysis. - _sccIdx: Option[U32], -) +TyId.def(self) TyDefIdx: + self._resolved.unwrapOrElse(||: panic("Type id is not resolved yet")) # ------------------------------------------------------------------------------ @@ -198,50 +154,50 @@ impl Eq[AssocVarDefIdx]: #[derive(ToDoc)] -type TyDefIdx: - Trait(TraitDefIdx) - Type(TyDefIdx_) - - -impl Eq[TyDefIdx]: - __eq(self: TyDefIdx, other: TyDefIdx) Bool: - match (left = self, right = other): - (left = TyDefIdx.Trait(l1), right = TyDefIdx.Trait(r1)): l1 == r1 - (left = TyDefIdx.Type(l1), right = TyDefIdx.Type(r1)): l1 == r1 - _: Bool.False - - -#[derive(ToDoc)] -type TraitDefIdx( - ## The trait's module. +type TyDefIdx( + ## The type's module. _mod: ModuleIdx, - ## Name of the trait in `_mod`. + ## Name of the type in `_mod`. _name: Str, -) - -impl Eq[TraitDefIdx]: - __eq(self: TraitDefIdx, other: TraitDefIdx) Bool: - self._mod == other._mod and self._name == other._name + ## Kind of the definition: type or trait. + _kind: TyDefKind, +) #[derive(ToDoc)] -type TyDefIdx_( - ## The type's module. - _mod: ModuleIdx, +type TyDefKind: + Type + Trait - ## Name of the type in `_mod`. - _name: Str, -) +impl Eq[TyDefIdx]: + __eq(self: TyDefIdx, other: TyDefIdx) Bool: + self._mod == other._mod + and self._name == other._name + and self._kind == other._kind -impl Eq[TyDefIdx_]: - __eq(self: TyDefIdx_, other: TyDefIdx_) Bool: - self._mod == other._mod and self._name == other._name + +impl Eq[TyDefKind]: + __eq(self: TyDefKind, other: TyDefKind) Bool: + match self: + TyDefKind.Type: other is TyDefKind.Type + TyDefKind.Trait: other is TyDefKind.Trait # ------------------------------------------------------------------------------ +# Generated ToDoc implementations + + +impl ToDoc[LocalId]: + toDoc(self: LocalId) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("name =") + Doc.nested(4, Doc.break_(1) + self.name.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("LocalId") + Doc.char('(') + args) impl ToDoc[VarId]: @@ -302,91 +258,6 @@ impl ToDoc[PackageUri]: Doc.grouped(Doc.str("PackageUri") + Doc.char('(') + args) -impl ToDoc[ModuleIdx]: - toDoc(self: ModuleIdx) Doc: - let args = Doc.break_(0) - args += Doc.grouped( - Doc.str("_idx =") + Doc.nested(4, Doc.break_(1) + self._idx.toDoc()), - ) - args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') - Doc.grouped(Doc.str("ModuleIdx") + Doc.char('(') + args) - - -impl ToDoc[Module]: - toDoc(self: Module) Doc: - let args = Doc.break_(0) - args += Doc.grouped( - Doc.str("_package =") - + Doc.nested(4, Doc.break_(1) + self._package.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_path =") - + Doc.nested(4, Doc.break_(1) + self._path.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_idx =") + Doc.nested(4, Doc.break_(1) + self._idx.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_funItems =") - + Doc.nested(4, Doc.break_(1) + self._funItems.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_assocItems =") - + Doc.nested(4, Doc.break_(1) + self._assocItems.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_tyItems =") - + Doc.nested(4, Doc.break_(1) + self._tyItems.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_traitItems =") - + Doc.nested(4, Doc.break_(1) + self._traitItems.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_impls =") - + Doc.nested(4, Doc.break_(1) + self._impls.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_imports =") - + Doc.nested(4, Doc.break_(1) + self._imports.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_tokens =") - + Doc.nested(4, Doc.break_(1) + self._tokens.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_termEnv =") - + Doc.nested(4, Doc.break_(1) + self._termEnv.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_assocTermEnv =") - + Doc.nested(4, Doc.break_(1) + self._assocTermEnv.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_tyEnv =") - + Doc.nested(4, Doc.break_(1) + self._tyEnv.toDoc()), - ) - args += Doc.char(',') + Doc.break_(1) - args += Doc.grouped( - Doc.str("_sccIdx =") - + Doc.nested(4, Doc.break_(1) + self._sccIdx.toDoc()), - ) - args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') - Doc.grouped(Doc.str("Module") + Doc.char('(') + args) - - impl ToDoc[VarDefIdx]: toDoc(self: VarDefIdx) Doc: match self: @@ -435,21 +306,6 @@ impl ToDoc[AssocVarDefIdx]: impl ToDoc[TyDefIdx]: toDoc(self: TyDefIdx) Doc: - match self: - TyDefIdx.Trait(i0): - let args = Doc.break_(0) - args += i0.toDoc() - args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') - Doc.grouped(Doc.str("TyDefIdx.Trait") + Doc.char('(') + args) - TyDefIdx.Type(i0): - let args = Doc.break_(0) - args += i0.toDoc() - args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') - Doc.grouped(Doc.str("TyDefIdx.Type") + Doc.char('(') + args) - - -impl ToDoc[TraitDefIdx]: - toDoc(self: TraitDefIdx) Doc: let args = Doc.break_(0) args += Doc.grouped( Doc.str("_mod =") + Doc.nested(4, Doc.break_(1) + self._mod.toDoc()), @@ -459,20 +315,26 @@ impl ToDoc[TraitDefIdx]: Doc.str("_name =") + Doc.nested(4, Doc.break_(1) + self._name.toDoc()), ) - args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') - Doc.grouped(Doc.str("TraitDefIdx") + Doc.char('(') + args) - - -impl ToDoc[TyDefIdx_]: - toDoc(self: TyDefIdx_) Doc: - let args = Doc.break_(0) - args += Doc.grouped( - Doc.str("_mod =") + Doc.nested(4, Doc.break_(1) + self._mod.toDoc()), - ) args += Doc.char(',') + Doc.break_(1) args += Doc.grouped( - Doc.str("_name =") - + Doc.nested(4, Doc.break_(1) + self._name.toDoc()), + Doc.str("_kind =") + + Doc.nested(4, Doc.break_(1) + self._kind.toDoc()), ) args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') - Doc.grouped(Doc.str("TyDefIdx_") + Doc.char('(') + args) + Doc.grouped(Doc.str("TyDefIdx") + Doc.char('(') + args) + + +impl ToDoc[TyDefKind]: + toDoc(self: TyDefKind) Doc: + match self: + TyDefKind.Type: Doc.str("TyDefKind.Type") + TyDefKind.Trait: Doc.str("TyDefKind.Trait") + + +# ------------------------------------------------------------------------------ +# ToStr implementations. These are used when generating error messages, so they +# should be readable by the users and should not expose implementation details. + +impl ToStr[LocalId]: + toStr(self: LocalId) Str: + panic("TODO") diff --git a/Compiler/Error.fir b/Compiler/Error.fir index c7ecc3f0..df87d2ce 100644 --- a/Compiler/Error.fir +++ b/Compiler/Error.fir @@ -9,8 +9,19 @@ impl ToStr[Error]: "`self.loc.file`:`self.loc.line + 1`:`self.loc.col + 1`: `self.msg`" +#[derive(ToDoc)] type Loc( file: Str, line: U32, col: U32, ) + + +impl ToStr[Loc]: + toStr(self: Loc) Str: + "`self.file`:`self.line + 1`:`self.col + 1`" + + +impl ToDoc[Loc]: + toDoc(self: Loc) Doc: + panic("TODO") diff --git a/Compiler/Grammar.fir b/Compiler/Grammar.fir index 133935a1..00a5d9d4 100644 --- a/Compiler/Grammar.fir +++ b/Compiler/Grammar.fir @@ -209,7 +209,7 @@ namedType(state: ParserState[Token]) NamedType / U32: let value = do: let args = Vec.fromIter(once(arg0).chain(args.iter())) NamedType( - name = newTyId(name), + name = newTyId(name, state._tokens), args = Option.Some( TyArgs( @@ -233,7 +233,10 @@ namedType(state: ParserState[Token]) NamedType / U32: else: throw(state.updateErrorCursor(state._cursor)) let value = do: - NamedType(name = newTyId(name), args = Option.None) + NamedType( + name = newTyId(name, state._tokens), + args = Option.None, + ) value, ) match altResult: @@ -2192,7 +2195,7 @@ simpleExpr(state: ParserState[Token]) Expr / U32: let value = do: Expr.Var( VarExpr( - id = newVarId(var_), + id = newVarId(var_, state._tokens), userTyArgs = tyArgs, tyArgs = Vec.empty(), ), @@ -2817,7 +2820,7 @@ constructor(state: ParserState[Token]) Constructor / U32: let value = do: Constructor( variant = Bool.False, - ty = newTyId(ty), + ty = newTyId(ty, state._tokens), constr = con.map(|con: U32|: newId(con)), userTyArgs, tyArgs = Vec.empty(), @@ -2875,7 +2878,7 @@ constructor(state: ParserState[Token]) Constructor / U32: let value = do: Constructor( variant = Bool.True, - ty = newTyId(ty), + ty = newTyId(ty, state._tokens), constr = con.map(|con: U32|: newId(con)), userTyArgs, tyArgs = Vec.empty(), @@ -4677,7 +4680,8 @@ topFunSig(state: ParserState[Token]) (parentTy: Option[TyId], name: Id, sig: Fun let value = do: let ret = ret.unwrapOr((ret = Option.None, exn = Option.None)) ( - parentTy = parentTy.map(|id: U32|: newTyId(id)), + parentTy = + parentTy.map(|id: U32|: newTyId(id, state._tokens)), name = newId(name), sig = FunSig( @@ -5090,7 +5094,7 @@ implDecl(state: ParserState[Token]) ImplDecl / U32: let value = do: ImplDecl( context = ctx, - trait_ = newTyId(name), + trait_ = newTyId(name, state._tokens), tys = Vec.fromIter(once(t0).chain(ts.iter())), items = rhs, _firstToken = TokenIdx(idx = first), diff --git a/Compiler/Grammar.peg b/Compiler/Grammar.peg index 0d857a03..980d94df 100644 --- a/Compiler/Grammar.peg +++ b/Compiler/Grammar.peg @@ -115,7 +115,7 @@ type_ Type: name=^ "UpperId" argFirst=^ "[" arg0=type_ args=(_"," type_)* ","? argLast=^ "]": let args = Vec.fromIter(once(arg0).chain(args.iter())) NamedType( - name = newTyId(name), + name = newTyId(name, state._tokens), args = Option.Some(TyArgs( args, _firstToken = TokenIdx(idx = argFirst), @@ -124,7 +124,7 @@ type_ Type: ) name=^ "UpperId": - NamedType(name = newTyId(name), args = Option.None) + NamedType(name = newTyId(name, state._tokens), args = Option.None) # - () @@ -516,7 +516,7 @@ simpleExpr Expr: # Variables var_=^ "LowerId" tyArgs=tyArgs?: Expr.Var(VarExpr( - id = newVarId(var_), + id = newVarId(var_, state._tokens), userTyArgs = tyArgs, tyArgs = Vec.empty(), )) @@ -705,7 +705,7 @@ constructor Constructor: ty=^ "UpperId" con=(_"." ^ _"UpperId")? userTyArgs=tyArgs?: Constructor( variant = Bool.False, - ty = newTyId(ty), + ty = newTyId(ty, state._tokens), constr = con.map(|con: U32|: newId(con)), userTyArgs, tyArgs = Vec.empty(), @@ -715,7 +715,7 @@ constructor Constructor: ty=^ "TildeUpperId" con=(_"." ^ _"UpperId")? userTyArgs=tyArgs?: Constructor( variant = Bool.True, - ty = newTyId(ty), + ty = newTyId(ty, state._tokens), constr = con.map(|con: U32|: newId(con)), userTyArgs, tyArgs = Vec.empty(), @@ -1006,7 +1006,7 @@ topFunSig (parentTy: Option[TyId], name: Id, sig: FunSig): parentTy=(^ _"UpperId" _".")? name=^ "LowerId" ctx=context? params=paramList ret=returnTy?: let ret = ret.unwrapOr((ret = Option.None, exn = Option.None)) ( - parentTy = parentTy.map(|id: U32|: newTyId(id)), + parentTy = parentTy.map(|id: U32|: newTyId(id, state._tokens)), name = newId(name), sig = FunSig( context = ctx, @@ -1076,7 +1076,7 @@ implDecl ImplDecl: first=^ "impl" ctx=context? name=^ "UpperId" "[" t0=type_ ts=(_"," type_)* ","? "]" rhs=implDeclRhs: ImplDecl( context = ctx, - trait_ = newTyId(name), + trait_ = newTyId(name, state._tokens), tys = Vec.fromIter(once(t0).chain(ts.iter())), items = rhs, _firstToken = TokenIdx(idx = first), diff --git a/Compiler/Main.fir b/Compiler/Main.fir index d2face60..d98889e8 100644 --- a/Compiler/Main.fir +++ b/Compiler/Main.fir @@ -1,6 +1,7 @@ import [ Compiler/NameResolver, Compiler/Program, + Compiler/TypeCheck, ] @@ -33,3 +34,4 @@ main(): pgm.loadCachedModule(parts) pgm.prepModuleEnvs() resolveNames(pgm) + prepTcEnvs(pgm) diff --git a/Compiler/Module.fir b/Compiler/Module.fir new file mode 100644 index 00000000..fabe7e71 --- /dev/null +++ b/Compiler/Module.fir @@ -0,0 +1,222 @@ +import [ + Compiler/Ast, + Compiler/Defs, + Compiler/TypeCheck, + Compiler/TypeCheck/Ty, + Compiler/TypeCheck/TyCon, +] + + +#[derive(ToDoc)] +type PackageUri( + _uri: Str, +) + + +#[derive(ToDoc)] +type ModuleIdx( + ## Index of the module in the program. + _idx: U32, +) + + +impl Eq[ModuleIdx]: + __eq(self: ModuleIdx, other: ModuleIdx) Bool: + self._idx == other._idx + + +impl Hash[ModuleIdx]: + hash(self: ModuleIdx) U32: + self._idx + + +#[derive(ToDoc)] +type Module( + ## The module's package. + _package: PackageUri, + + ## File path of the module's file. + # TODO: Is this absolute? Or relative? If relative, relative to what? + # For now we only use this for errors so it doesn't matter too much. + _filePath: Str, + + ## Path of the module in its package. + _path: Vec[Str], + + _idx: ModuleIdx, + + # -------------------------------------------------------------------------- + # ASTs of definitions in the module. + # + # These don't include imported things. + + ## Functions defined in the module. + _funItems: HashMap[Str, FunDecl], + + ## Associated functions (including methods, but not trait methods) defined + ## in the module. + _assocItems: HashMap[Str, HashMap[Str, FunDecl]], + + ## Types defined in the module. + _tyItems: HashMap[Str, TypeDecl], + + ## Traits defined in the module. + _traitItems: HashMap[Str, TraitDecl], + + ## Impls defined in the module. + _impls: Vec[ImplDecl], + + ## Imported modules. These are used when resolving identifiers. + _imports: Vec[ModuleIdx], + + ## Tokens of the module. ASTs in the other fields for definitions in the + ## current module refer to tokens in this array. + _tokens: Array[Token], + + # -------------------------------------------------------------------------- + # Environments for name resolving. + # + # At this stage we still can't create type constructors and schemes, but the + # environments can refer to things in other modules. + # + # Conceptually the values here are references to items, potentially in other + # modules. + + ## Term environment of the module, used for name resolving. + _termEnv: HashMap[Str, VarDefIdx], + + ## Associated item environment of the module, used for name resolving. + _assocTermEnv: HashMap[Str, HashMap[Str, AssocVarDefIdx]], + + ## Type environment of the module, used for name resolving. + _tyEnv: HashMap[Str, TyDefIdx], + + # -------------------------------------------------------------------------- + # Details of the definitions in the module: + # + # - Kinds, constructors, fields, methods of type constructors + # (types and traits) + # + # - Schemes of top-level and associated functions. + + ## Type constructors. + _cons: HashMap[Str, TyCon], + + ## Top-level function schemes. + _topSchemes: HashMap[Str, Scheme], + + ## Associated function schemes. + _assocFnSchemes: HashMap[Str, HashMap[Str, Scheme]], + + # -------------------------------------------------------------------------- + + _tcEnv: Option[ModuleTcEnv], + + # -------------------------------------------------------------------------- + + ## Strongly connected component index of the module in the program. + ## Generated by the dependency analysis. + _sccIdx: Option[U32], +) + + +Module.tokenText(self, idx: TokenIdx) Str: + self._tokens.get(idx.idx).text + + +Module.idText(self, id: Id) Str: + self.tokenText(id.token) + + +# ------------------------------------------------------------------------------ +# Generated ToDoc implementations + + +impl ToDoc[ModuleIdx]: + toDoc(self: ModuleIdx) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("_idx =") + Doc.nested(4, Doc.break_(1) + self._idx.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("ModuleIdx") + Doc.char('(') + args) + + +impl ToDoc[Module]: + toDoc(self: Module) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("_package =") + + Doc.nested(4, Doc.break_(1) + self._package.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_filePath =") + + Doc.nested(4, Doc.break_(1) + self._filePath.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_path =") + + Doc.nested(4, Doc.break_(1) + self._path.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_idx =") + Doc.nested(4, Doc.break_(1) + self._idx.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_funItems =") + + Doc.nested(4, Doc.break_(1) + self._funItems.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_assocItems =") + + Doc.nested(4, Doc.break_(1) + self._assocItems.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_tyItems =") + + Doc.nested(4, Doc.break_(1) + self._tyItems.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_traitItems =") + + Doc.nested(4, Doc.break_(1) + self._traitItems.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_impls =") + + Doc.nested(4, Doc.break_(1) + self._impls.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_imports =") + + Doc.nested(4, Doc.break_(1) + self._imports.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_tokens =") + + Doc.nested(4, Doc.break_(1) + self._tokens.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_termEnv =") + + Doc.nested(4, Doc.break_(1) + self._termEnv.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_assocTermEnv =") + + Doc.nested(4, Doc.break_(1) + self._assocTermEnv.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_tyEnv =") + + Doc.nested(4, Doc.break_(1) + self._tyEnv.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("_sccIdx =") + + Doc.nested(4, Doc.break_(1) + self._sccIdx.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("Module") + Doc.char('(') + args) diff --git a/Compiler/NameResolver.fir b/Compiler/NameResolver.fir index ad226bc0..31ef1140 100644 --- a/Compiler/NameResolver.fir +++ b/Compiler/NameResolver.fir @@ -6,6 +6,7 @@ import [ Compiler/Ast, Compiler/AstVisitor, Compiler/Defs, + Compiler/Module, Compiler/ScopeMap, ] diff --git a/Compiler/ParseUtils.fir b/Compiler/ParseUtils.fir index 3cbb65cd..b409c939 100644 --- a/Compiler/ParseUtils.fir +++ b/Compiler/ParseUtils.fir @@ -55,24 +55,36 @@ parseCharLit(text: Str) Char: parseStrParts(text: Str) Vec[StrPart]: - Vec.empty() # TODO - + # TODO + Vec.empty() newId(cursor: U32) Id: Id(token = TokenIdx(idx = cursor)) -newVarId(cursor: U32) VarId: - VarId(token = TokenIdx(idx = cursor), _resolved = Option.None) +newVarId(cursor: U32, tokens: Vec[Token]) VarId: + VarId( + token = TokenIdx(idx = cursor), + name = tokens.get(cursor).text, + _resolved = Option.None, + ) -newAssocId(cursor: U32) AssocVarId: - AssocVarId(token = TokenIdx(idx = cursor), _resolved = Option.None) +newAssocId(cursor: U32, tokens: Vec[Token]) AssocVarId: + AssocVarId( + token = TokenIdx(idx = cursor), + name = tokens.get(cursor).text, + _resolved = Option.None, + ) -newTyId(cursor: U32) TyId: - TyId(token = TokenIdx(idx = cursor), _resolved = Option.None) +newTyId(cursor: U32, tokens: Vec[Token]) TyId: + TyId( + token = TokenIdx(idx = cursor), + name = tokens.get(cursor).text, + _resolved = Option.None, + ) processFields(fields: Vec[(name: Option[Id], ty: Type)]) ConFields: diff --git a/Compiler/Program.fir b/Compiler/Program.fir index 509049c5..c600955d 100644 --- a/Compiler/Program.fir +++ b/Compiler/Program.fir @@ -6,17 +6,25 @@ import [ Compiler/Defs, Compiler/Error, Compiler/Grammar, + Compiler/Module, Compiler/Scanner, + Compiler/TypeCheck, ] type Program( + ## Indexed by `ModuleIdx`. _modules: Vec[Module], + + ## Strongly connected graphs of modules. Indexed by SCC indices. + ## + ## Initialized by `prepModuleEnvs`. + _depGraph: Option[Vec[HashSet[ModuleIdx]]], ) Program.new() Program: - Program(_modules = Vec.empty()) + Program(_modules = Vec.empty(), _depGraph = Option.None,) ## Load a module, or return it from the cache if it's already loaded. @@ -84,7 +92,7 @@ Program.loadCachedModule(self, path: Vec[Str]) ModuleIdx: let old = funItems.insert(funName, funDecl) if old is Option.Some(_): panic( - "Function`funName` defined multipe times in `filePath`", + "Function `funName` defined multipe times in `filePath`", ) TopDecl.Trait(traitDecl): @@ -122,6 +130,10 @@ Program.loadCachedModule(self, path: Vec[Str]) ModuleIdx: _termEnv = HashMap.withCapacity(100), _assocTermEnv = HashMap.withCapacity(10), _tyEnv = HashMap.withCapacity(100), + _cons = HashMap.withCapacity(100), + _topSchemes = HashMap.withCapacity(10), + _assocFnSchemes = HashMap.withCapacity(100), + _tcEnv = Option.None, _sccIdx = Option.None, ) @@ -147,6 +159,7 @@ Program.prepModuleEnvs(self): # in the SCC. So we create SCCs, and then create one module env for each # SCC, which is shared with all of the modules of the SCC. let sccs = _sccs(self) + self._depGraph = Option.Some(sccs) print("Program modules:") for i: U32 in range(u32(0), self._modules.len()): @@ -259,9 +272,12 @@ _makeSccEnv( let tokens = mod._tokens for ty: HashMapEntry[Str, TypeDecl] in mod._tyItems.iter(): - let tyDefIdx = TyDefIdx.Type( - TyDefIdx_(_mod = modIdx, _name = ty.key), + let tyDefIdx = TyDefIdx( + _mod = modIdx, + _name = ty.key, + _kind = TyDefKind.Type, ) + let old = typeEnv.insert(ty.key, tyDefIdx) if old is Option.Some(_): panic( @@ -269,8 +285,10 @@ _makeSccEnv( ) for trait_: HashMapEntry[Str, TraitDecl] in mod._traitItems.iter(): - let traitDefIdx = TyDefIdx.Trait( - TraitDefIdx(_mod = modIdx, _name = trait_.key), + let traitDefIdx = TyDefIdx( + _mod = modIdx, + _name = trait_.key, + _kind = TyDefKind.Trait, ) let old = typeEnv.insert(trait_.key, traitDefIdx) if old is Option.Some(_): diff --git a/Compiler/ScopeMap.fir b/Compiler/ScopeMap.fir index 6da53498..e1b2daef 100644 --- a/Compiler/ScopeMap.fir +++ b/Compiler/ScopeMap.fir @@ -4,6 +4,12 @@ type ScopeMap[k, v]( ) +ScopeMap.empty() ScopeMap[k, v]: + let scopes = Vec.withCapacity(10) + scopes.push(HashMap.empty()) + ScopeMap(scopes) + + ScopeMap.fromMap(map: HashMap[k, v]) ScopeMap[k, v]: let scopes = Vec.withCapacity(10) scopes.push(map) diff --git a/Compiler/TypeCheck.fir b/Compiler/TypeCheck.fir new file mode 100644 index 00000000..a17bcf06 --- /dev/null +++ b/Compiler/TypeCheck.fir @@ -0,0 +1,252 @@ +# Import all modules temporarily to type check all. +import [ + Compiler/Program, + Compiler/ScopeMap, + Compiler/TypeCheck/Convert, + Compiler/TypeCheck/Error, + Compiler/TypeCheck/Normalization, + Compiler/TypeCheck/RowUtils, + Compiler/TypeCheck/TraitEnv, + Compiler/TypeCheck/Ty, + Compiler/TypeCheck/TyCon, + Compiler/TypeCheck/TyMap, + Compiler/TypeCheck/Unification, +] + + +prepTcEnvs(pgm: Program): + () + + +prepModuleTyEnv(pgm: Program, modIdx: ModuleIdx): + if pgm._modules.get(modIdx._idx)._tcEnv is Option.Some(env): + return + + # Reminder: with the current import semantics each SCC has access to all of + # the things in the SCC. + # + # However we still need per-module (rather than per-SCC) environments + # because + # + # (1) the import paths of items available in each of the modules will be + # different. + # + # (2) in the future we will add imported item lists so the environments will + # be different in places other than just the import paths. + + let depGraph = pgm._depGraph.unwrap() + let sccIdx = pgm._modules.get(modIdx._idx)._sccIdx.unwrap() + let sccModules: HashSet[ModuleIdx] = pgm._depGraph.unwrap().get(sccIdx) + + # Initialize module tc envs with just the defined things in each of the + # modules. + for modIdx: ModuleIdx in sccModules.iter(): + pgm._modules.get(modIdx._idx)._tcEnv = Option.Some( + moduleTcEnvFromDefinedItems(pgm._modules.get(modIdx._idx)), + ) + + let updated = Bool.True + + # Add imports until we can't add any more imports. + while updated: + updated = Bool.False + + for modIdx: ModuleIdx in sccModules.iter(): + if addImports(pgm, modIdx): + updated = Bool.True + + +moduleTcEnvFromDefinedItems(mod: Module) ModuleTcEnv: + panic("TODO") + + +addImports(pgm: Program, modIdx: ModuleIdx) Bool: + panic("TODO") + + +# ------------------------------------------------------------------------------ +# Module type checking environment + + +## The module environment holds type constructors, traits (with impls), and +## associated and top-level functions available in the module. +type ModuleTcEnv( + ## The traits, with impls. + _traitEnv: TraitEnv, + + ## Type constructors. + _cons: HashMap[Str, TcItem[TyId, TyCon]], + + ## Top-level function schemes. + _topSchemes: HashMap[Str, TcItem[VarId, Scheme]], + + ## Associated function schemes. + _assocFnSchemes: HashMap[Str, HashMap[Str, TcItem[VarId, Scheme]]], + + ## Type schemes of methods. + ## + ## Maps method names to (type or trait name, type scheme) pairs. + ## + ## These are associated functions (so they're also in + ## `associated_fn_schemes`) that take a `self` parameter. + ## + ## The first parameters of the function types here are the `self` types. + ## + ## Because these schemes are only used in method call syntax, the keys are + ## not type names but method names. The values are type schemes of methods + ## with the name. + _methodSchemes: HashMap[Str, Vec[(id: Str, scheme: Scheme)]], +) + + +## An item (function, associated function, method, type, trait) defined +## somewhere in the program. (current module or imported) +type TcItem[idx, info]( + ## Reference to the definition of the item. + idx: idx, + + ## The details of the definition we need to type check a module. + ## + ## For functions this will be the type scheme of the function. + ## + ## For types this will be the type info: + ## - Constructors (and fields etc.) of types + ## - Methods of traits + ## - Kinds + ## - etc. + info: info, + + ## Import paths of the item. + ## + ## An item can be imported via different paths. This will have all of those + ## paths. + ## + ## For items defined in the current module this will be empty. + ## + ## TODO: Introduce an import path type and use it here. + imports: Vec[Vec[Id]], +) + + +type Import( + ## Name the item imported as. E.g. in `import [Foo/Bar as Baz]` this will be + ## `Baz`. + name: Id, + + ## The path the item was imported from. In the example above: `Foo/Bar`. + ## + ## The same item can be imported via different paths. + importLoc: Vec[Id], +) + + +ModuleTcEnv.empty() ModuleTcEnv: + ModuleTcEnv( + _traitEnv = TraitEnv.empty(), + _cons = HashMap.empty(), + _topSchemes = HashMap.empty(), + _assocFnSchemes = HashMap.empty(), + _methodSchemes = HashMap.empty(), + ) + + +## Add imported things from an imported SCC to the current module. +ModuleTcEnv.addImportedThings(self, importedEnv: ModuleTcEnv): + # TODO: Panics below should be errors. + # TODO: The code below is not right: the same type or term can be imported + # via different modules. So importing something multiple times is not a + # problem as long as we import the same thing. + + # Add type constructors. + for con: HashMapEntry[Str, TcItem[TyId, TyCon]] in importedEnv._cons.iter(): + let old = self._cons.insert(con.key, con.value) + if old is Option.Some(_): + # TODO: Handle importing the same type. (compare TyIds, update + # import list) + panic("Type `con.key` imported multiple times") + + # Add traits. + # TODO: TraitEnv doesn't let you rename traits. Needs refactoring. + + # Add top-level functions. + for top: HashMapEntry[Str, TcItem[VarId, Scheme]] in + importedEnv._topSchemes.iter(): + let old = self._topSchemes.insert(top.key, top.value) + if old is Option.Some(_): + panic("Top-level function `top.key` imported multiple times") + + # Add associated functions. + # TODO: We may import a type with another name, the associated function + # should be added to the right type when we do that. + for assoc: HashMapEntry[Str, HashMap[Str, TcItem[VarId, Scheme]]] in + importedEnv._assocFnSchemes.iter(): + let tyMap = match self._assocFnSchemes.get(assoc.key): + Option.Some(tyMap): tyMap + Option.None: + let tyMap = HashMap.withCapacity(10) + self._assocFnSchemes.insert(assoc.key, tyMap) + tyMap + + for fun: HashMapEntry[Str, TcItem[VarId, Scheme]] in assoc.value.iter(): + let old = tyMap.insert(fun.key, fun.value) + if old is Option.Some(_): + panic( + "Associated function `assoc.key`.`fun.key` imported multiple times.", + ) + + +# Add methods. +# TODO + + +# ------------------------------------------------------------------------------ +# Function type checking environment + + +## Type checking state for a single function (top-level, associated, or method). +type FunTcEnv( + _modEnv: ModuleTcEnv, + + _termEnv: ScopeMap[Str, Ty], + + ## Unification variable generator. + _varGen: UVarGen, + + ## Exception type of the current function. + ## + ## Exceptions thrown by called functions are unified with this type. + ## + ## For now we don't do exception type inference, so this will always be a + ## concrete type (with rigid type variables). + _exnTy: Ty, + + ## Return type of the current function. + ## + ## This is used when checking expressions in return positions and in + ## `return` expressions. + _retTy: Ty, + + ## Predicates generated when checking the function body. + ## + ## After checking the body, these predicates should all be resolved with the + ## function context and trait environment. + ## + ## This is a `Vec` instead of `HashSet` as the type checker never visits an + ## expression twice, so every `Pred` here will have a different `Loc`. + _preds: Vec[Pred], + + ## The function context. + _assumps: Vec[Pred], +) + + +FunTcEnv.new(modEnv: ModuleTcEnv, exnTy: Ty, retTy: Ty, assumps: Vec[Pred]) FunTcEnv: + FunTcEnv( + _modEnv = modEnv, + _termEnv = ScopeMap.empty(), + _varGen = UVarGen.new(), + _exnTy = exnTy, + _retTy = retTy, + _preds = Vec.withCapacity(10), + _assumps = assumps, + ) diff --git a/Compiler/TypeCheck/Convert.fir b/Compiler/TypeCheck/Convert.fir new file mode 100644 index 00000000..b5e2934a --- /dev/null +++ b/Compiler/TypeCheck/Convert.fir @@ -0,0 +1,157 @@ +import [ + Compiler/Ast, + Compiler/TypeCheck/Error, + Compiler/TypeCheck/Ty, + Compiler/TypeCheck/TyMap, +] + + +convertAstTy(module: Module, tys: TyMap, astTy: Type, loc: Loc) Ty / TypeError: + match astTy: + Type.Named(namedTy): convertNamedTy(module, tys, namedTy, loc) + + Type.Var(var_): + let varText = module.idText(var_) + tys.getVar(varText).unwrapOrElse( + ||: + throw( + TypeError(loc, msg = "Unbound type variable `varText`",), + ), + ) + + Type.Record(RecordType(fields, extension, isRow, ..)): + let labels: HashMap[LocalId, Ty] = HashMap.withCapacity(fields.len()) + + for namedField: Named[Type] in fields.iter(): + let name = module.idText( + namedField.name.unwrapOrElse( + ||: + throw( + TypeError( + loc, + msg = + "Records with unnamed fields not supported yet", + ), + ), + ), + ) + let ty = convertAstTy(module, tys, namedField.node, loc) + let old = labels.insert(LocalId(name), ty) + if old is Option.Some(_): + throw( + TypeError( + loc, + msg = + "Field `name` defined multiple times in record", + ), + ) + + let extension = extension.map( + |extId|: + let extIdText = module.idText(extId) + tys.getVar(extIdText).unwrapOrElse( + ||: + throw( + TypeError( + loc, + msg = "Unbound type variable `extIdText`", + ), + ), + ), + ) + + Ty.Anonymous(labels, extension, kind = RecordOrVariant.Record, isRow) + + Type.Variant(VariantType(alts, extension, isRow, ..)): + let labels: HashMap[LocalId, Ty] = HashMap.withCapacity(alts.len()) + + for alt: NamedType in alts.iter(): + let ty = convertNamedTy(module, tys, alt, loc) + let tyNameText = alt.name.name + let old = labels.insert(LocalId(name = tyNameText), ty) + if old is Option.Some(_): + throw( + TypeError( + loc, + msg = + "Type `tyNameText` used multiple times in variant type", + ), + ) + + let extension = extension.map( + |extId|: + let extIdText = module.idText(extId) + tys.getVar(extIdText).unwrapOrElse( + ||: + throw( + TypeError( + loc, + msg = "Unbound type variable `extIdText`", + ), + ), + ), + ) + + Ty.Anonymous( + labels, + extension, + kind = RecordOrVariant.Variant, + isRow, + ) + + Type.Fn_(FnType(args, ret, exceptions, ..)): + let args = FunArgs.Positional( + Vec.fromIter( + args.iter().map( + |ty: Type| Ty: convertAstTy(module, tys, ty, loc), + ), + ), + ) + + let ret = match ret: + Option.None: Ty.unit() + Option.Some(ret): convertAstTy(module, tys, ret, loc) + + let exceptions = exceptions.unwrapOrElse( + ||: + throw( + TypeError( + loc, + msg = "Function type without exception type", + ), + ), + ) + + let exceptions = convertAstTy(module, tys, exceptions, loc) + + Ty.Fun(args, ret, exn = Option.Some(exceptions)) + + +convertNamedTy(module: Module, tys: TyMap, namedTy: NamedType, loc: Loc) Ty / TypeError: + let tyName = namedTy.name.name + let tyArgs = namedTy.args + let numTyArgs = namedTy.numTyArgs() + + let tyCon = tys.getCon(tyName).unwrapOrElse( + ||: throw(TypeError(loc, msg = "Unknown type `tyName`")), + ) + + if tyCon.arity() != numTyArgs: + throw( + TypeError( + loc, + msg = + "Incorrect number of type arguments to `tyName`, expected `tyCon.arity()`, found `numTyArgs`", + ), + ) + + if numTyArgs == 0: + return Ty.Con(id = namedTy.name) + + let convertedArgs: Vec[Ty] = Vec.fromIter( + tyArgs.unwrap().args.iter().map( + |astTy: Type| Ty: convertAstTy(module, tys, astTy, loc), + ), + ) + + Ty.App(conId = namedTy.name, args = convertedArgs) diff --git a/Compiler/TypeCheck/Error.fir b/Compiler/TypeCheck/Error.fir new file mode 100644 index 00000000..d3182c2a --- /dev/null +++ b/Compiler/TypeCheck/Error.fir @@ -0,0 +1,22 @@ +import [Compiler/Error] + + +#[derive(ToDoc)] +type TypeError( + loc: Loc, + msg: Str, +) + + +impl ToDoc[TypeError]: + toDoc(self: TypeError) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("loc =") + Doc.nested(4, Doc.break_(1) + self.loc.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("msg =") + Doc.nested(4, Doc.break_(1) + self.msg.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("TypeError") + Doc.char('(') + args) diff --git a/Compiler/TypeCheck/Normalization.fir b/Compiler/TypeCheck/Normalization.fir new file mode 100644 index 00000000..852571bf --- /dev/null +++ b/Compiler/TypeCheck/Normalization.fir @@ -0,0 +1,83 @@ +import [Compiler/TypeCheck/Ty] + + +Ty.normalize(self) Ty: + match self: + Ty.UVar(uvar): uvar.normalize() + _: self + + +UVar.normalize(self) Ty: + match self.link: + Option.None: Ty.UVar(self) + Option.Some(link): + let linkNormal = link.normalize() + self.link = Option.Some(linkNormal) + linkNormal + + +Ty.deepNormalize(self) Ty: + match self: + Ty.Con(..): self + + Ty.App(conId, args): + Ty.App( + conId, + args = + Vec.fromIter(args.iter().map(|arg: Ty|: arg.deepNormalize())), + ) + + Ty.QVar(_): panic("QVar in deepNormalize") + + Ty.UVar(uvar): + if uvar.link is Option.None: + return self + + let uvarLink = uvar.normalize().deepNormalize() + uvar.link = Option.Some(uvarLink) + uvarLink + + Ty.RVar(_): self + + Ty.Fun(args, ret, exn): + Ty.Fun( + args = + match args: + FunArgs.Positional(tys): + FunArgs.Positional( + Vec.fromIter( + tys.iter().map(|ty: Ty|: ty.deepNormalize()), + ), + ) + FunArgs.Named(tys): + FunArgs.Named( + HashMap.fromIter( + tys.iter().map( + |ty: HashMapEntry[LocalId, Ty]|: + ( + key = ty.key, + value = ty.value.deepNormalize(), + ), + ), + ), + ), + ret = ret.deepNormalize(), + exn = exn.map(|ty: Ty|: ty.deepNormalize()), + ) + + Ty.Anonymous(labels, extension, kind, isRow): + Ty.Anonymous( + labels = + HashMap.fromIter( + labels.iter().map( + |label: HashMapEntry[LocalId, Ty]|: + ( + key = label.key, + value = label.value.deepNormalize(), + ), + ), + ), + extension = extension.map(|ty: Ty|: ty.deepNormalize()), + kind, + isRow, + ) diff --git a/Compiler/TypeCheck/RowUtils.fir b/Compiler/TypeCheck/RowUtils.fir new file mode 100644 index 00000000..2f1e66f3 --- /dev/null +++ b/Compiler/TypeCheck/RowUtils.fir @@ -0,0 +1,54 @@ +import [ + Compiler/Assert, + Compiler/TypeCheck/Ty, +] + + +collectRows( + ty: Ty, + tyKind: RecordOrVariant, + labels: HashMap[LocalId, Ty], + extension: Option[Ty], +) (rows: HashMap[LocalId, Ty], extension: Option[Ty]): + let allLabels: HashMap[LocalId, Ty] = HashMap.fromIter( + labels.iter().map( + |label: HashMapEntry[LocalId, Ty]|: + (key = label.key, value = label.value.deepNormalize()), + ), + ) + + while extension is Option.Some(ext): + match ext: + Ty.Anonymous(labels, extension = nextExt, kind, isRow): + assert(kind == tyKind) + assert(isRow) + for label: HashMapEntry[LocalId, Ty] in labels.iter(): + if allLabels.insert(label.key, label.value) + is Option.Some(_): + panic( + "BUG: Duplicate label in anonymous type `ty.toDoc().render(50)`", + ) + extension = nextExt + + Ty.UVar(uvar): + assert(uvar.kind is Kind.Row(_)) + match uvar.normalize(): + Ty.Anonymous(labels, extension = nextExt, kind, isRow): + assert(isRow) + assert(kind == tyKind) + for label: HashMapEntry[LocalId, Ty] in labels.iter(): + if allLabels.insert(label.key, label.value) + is Option.Some(_): + panic( + "BUG: Duplicate label in anonymous type `ty.toDoc().render(50)`", + ) + extension = nextExt + + other: + return ( + rows = allLabels, extension = Option.Some(other) + ) + + other: return (rows = allLabels, extension = Option.Some(other)) + + (rows = allLabels, extension = Option.None) diff --git a/Compiler/TypeCheck/TraitEnv.fir b/Compiler/TypeCheck/TraitEnv.fir new file mode 100644 index 00000000..87590587 --- /dev/null +++ b/Compiler/TypeCheck/TraitEnv.fir @@ -0,0 +1,140 @@ +import [ + Compiler/TypeCheck/Ty, + Compiler/TypeCheck/Unification, +] + + +#[derive(ToDoc)] +type TraitEnv( + _map: HashMap[TyId, Vec[TraitImpl]], +) + + +TraitEnv.empty() TraitEnv: + TraitEnv(_map = HashMap.empty(),) + + +# Example: `impl[Iterator[iter, a]] Iterator[Map[iter, a, b], b]: ...` +#[derive(ToDoc)] +type TraitImpl( + # Free variables of the `impl`. + # + # In the example: `iter`, `a`, `b`. + qvars: Vec[QVar], + + # Arguments of the trait. + # + # In the example: `[Map[iter, a, b], b]`, where `iter`, `a` and `b` are + # `QVar`s in `qvars`. + traitArgs: Vec[Ty], + + # Predicates of the implementation. + # + # In the example: `[(Iterator, [iter, a])]`, where `iter` and `a` are + # `QVar`s in `qvars`. + # + # Note: these types should be instantiated together with `traitArgs` so that + # the same `QVar` in arguments and preds will be the same instantiated type + # variable, and as we match args the preds will be updated. + preds: Vec[TraitPred], +) + + +# Similar to `Pred`, but doesn't have a `Loc`. +#[derive(ToDoc)] +type TraitPred( + trait_: TyId, + params: Vec[Ty], +) + + +TraitImpl.tryMatch(self, args: Vec[Ty], varGen: UVarGen, tys: TyMap, loc: Loc) Option[Vec[Pred]]: + if args.len() != self.traitArgs.len(): + panic( + "`loc`: BUG: Number of arguments applied to the trait don't match the arity", + ) + + # Maps `QVar`s to instantiations. + let varMap: HashMap[LocalId, Ty] = HashMap.fromIter( + self.qvars.iter().map( + |qvar: QVar| (key: LocalId, value: Ty): + let instantiatedVar = Ty.UVar(varGen.newVar(0, qvar.kind, loc)) + (key = qvar.id, value = instantiatedVar), + ), + ) + + for argIdx: U32 in range(u32(0), self.traitArgs.len()): + let implArg = self.traitArgs.get(argIdx) + let tyArg = args.get(argIdx) + let instantiatedImplArg = implArg.substQVars(varMap) + if not unifyOneWay(instantiatedImplArg, tyArg, varGen, 0, loc): + return Option.None + + Option.Some( + Vec.fromIter( + self.preds.iter().map( + |pred: TraitPred|: + Pred( + trait_ = pred.trait_, + params = + Vec.fromIter( + args.iter().map( + |arg: Ty|: arg.substQVars(varMap), + ), + ), + loc, + ), + ), + ), + ) + + +# ------------------------------------------------------------------------------ +# Generated ToDoc implementations + + +impl ToDoc[TraitEnv]: + toDoc(self: TraitEnv) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("_map =") + Doc.nested(4, Doc.break_(1) + self._map.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("TraitEnv") + Doc.char('(') + args) + + +impl ToDoc[TraitImpl]: + toDoc(self: TraitImpl) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("qvars =") + + Doc.nested(4, Doc.break_(1) + self.qvars.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("traitArgs =") + + Doc.nested(4, Doc.break_(1) + self.traitArgs.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("preds =") + + Doc.nested(4, Doc.break_(1) + self.preds.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("TraitImpl") + Doc.char('(') + args) + + +impl ToDoc[TraitPred]: + toDoc(self: TraitPred) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("trait_ =") + + Doc.nested(4, Doc.break_(1) + self.trait_.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("params =") + + Doc.nested(4, Doc.break_(1) + self.params.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("TraitPred") + Doc.char('(') + args) diff --git a/Compiler/TypeCheck/Ty.fir b/Compiler/TypeCheck/Ty.fir new file mode 100644 index 00000000..af9e41ae --- /dev/null +++ b/Compiler/TypeCheck/Ty.fir @@ -0,0 +1,564 @@ +import [ + Compiler/Defs, + Compiler/Error, +] + + +# ------------------------------------------------------------------------------ + + +## A type checking type. +#[derive(ToDoc)] +type Ty: + ## A type constructor with kind `*`. E.g. `U32`, `Bool`. + Con( + id: TyId, + ) + + ## A type application, e.g. `Vec[U32]`, `Result[E, T]`. + ## + ## Invariant: the `args` vector is not empty. Nullary constructor + ## applications are represented as `Ty.Con`. + App( + conId: TyId, + args: Vec[Ty], + ) + + # -------------------------------------------------------------------------- + # Variables + + ## A quantified type variable, in type schemes. + QVar(QVar) + + ## A unification variable, created from a `Ty.QVar` in instantiation. + UVar(UVar) + + ## A rigid type variable. These are quantified type variables when checking + ## the body of the function with the quantified type variables. + RVar(RVar) + + # -------------------------------------------------------------------------- + + ## A function type, e.g. `Fn(U32) Str`, `Fn(x: U32, y: U32) T / Err`. + Fun( + args: FunArgs, + ret: Ty, + + ## Exception type of a function is always a `row(variant)`-kinded type + ## variable. In type schemes, this will be a `QVar`. + ## + ## Not available in constructors. + exn: Option[Ty], + ) + + ## An anonymous record or variant type or row type. E.g. `(a: Str, ..r)`, + ## `[Err1(Str), ..r]`. + Anonymous( + labels: HashMap[LocalId, Ty], + + ## Row extension. When available, this will be one of: + ## + ## - `Ty.Var`: a unification variable. + ## - `Ty.Con`: a rigid type variable. + extension: Option[Ty], + + kind: RecordOrVariant, + + ## Whether this is a row type. A row type has its own kind `row`. When + ## not a row, the type has kind `*`. + isRow: Bool, + ) + + +## Only in type schemes: a quantified type variable. +## +## When instantiated, these become unification variables (`Ty.UVar`). +## +## When checking function bodies, quantified type variables of the function's +## type scheme becomes rigid type variables (`Ty.RVar`). +#[derive(ToDoc)] +type QVar( + id: LocalId, + kind: Kind, +) + + +## A unification variable. +## +## Note: `Hash` and `Eq` are implemented based on `id`. +#[derive(ToDoc)] +type UVar( + ## Identity of the unification variable. + ## + ## This is used to compare unification variables for equality. + id: U32, + + ## Kind of the variable. + kind: Kind, + + ## Binding level: depth of the scope the unification variable was created + ## in. + level: U32, + + ## When unified with a type, this holds the type. + link: Option[Ty], +) + + +impl Eq[UVar]: + __eq(self: UVar, other: UVar) Bool: + self.id == other.id + + +UVar.setLink(self, link: Ty): + assert(self.link is Option.None) + self.link = Option.Some(link) + + +## A rigid type variable. These are quantified type variables when checking +## the body of the function with the quantified type variables. +#[derive(ToDoc)] +type RVar( + id: LocalId, + kind: Kind, +) + + +#[derive(ToDoc)] +type FunArgs: + Positional(Vec[Ty]) + Named(HashMap[LocalId, Ty]) + + +FunArgs.len(self) U32: + match self: + FunArgs.Positional(tys): tys.len() + FunArgs.Named(tys): tys.len() + + +## Kind of a type. +## +## We don't support higher-kinded variables yet, so this is either a `*` or +## `row` for now. +#[derive(ToDoc)] +type Kind: + Star + Row(RecordOrVariant) + + +impl Eq[Kind]: + __eq(self: Kind, other: Kind) Bool: + match (left = self, right = other): + (left = Kind.Star, right = Kind.Star): Bool.True + (left = Kind.Row(row1), right = Kind.Row(row2)): row1 == row2 + _: Bool.False + + +#[derive(ToDoc)] +type RecordOrVariant: + Record + Variant + + +impl Eq[RecordOrVariant]: + __eq(self: RecordOrVariant, other: RecordOrVariant) Bool: + match (left = self, right = other): + (left = RecordOrVariant.Record, right = RecordOrVariant.Record): + Bool.True + (left = RecordOrVariant.Variant, right = RecordOrVariant.Variant): + Bool.True + _: Bool.False + + +Ty.kind(self) Kind: + match self: + Ty.Con(..) | Ty.App(..) | Ty.Fun(..): Kind.Star + + Ty.QVar(QVar(kind, ..)) | Ty.RVar(RVar(kind, ..)): kind + + Ty.UVar(uvar): uvar.kind + + Ty.Anonymous(kind, ..): Kind.Row(kind) + + +Ty.unit() Ty: + Ty.Anonymous( + labels = HashMap.empty(), + extension = Option.None, + kind = RecordOrVariant.Record, + isRow = Bool.False, + ) + + +Ty.emptyVariant() Ty: + Ty.Anonymous( + labels = HashMap.empty(), + extension = Option.None, + kind = RecordOrVariant.Variant, + isRow = Bool.False, + ) + + +Ty.unitRow(kind: RecordOrVariant) Ty: + Ty.Anonymous( + labels = HashMap.empty(), + extension = Option.None, + kind, + isRow = Bool.True, + ) + + +# ------------------------------------------------------------------------------ +# Type schemes and instantiation + + +#[derive(ToDoc)] +type Scheme( + ## Generalized (quantified) variables with kinds. + ## + ## When the scheme is for a trait method, the first type parameters will be + ## the type parameters for the trait, in the right order. + qvars: Vec[QVar], + + ## Predicates of the type scheme. These can refer to `qvars` and need to be + ## instantiated with `qvars`. + ## + ## This is a `Vec`, so in principle it can contain duplicates. Schemes are + ## currently only generated from top-level functions, which have explicit + ## type signatures. So the only way to have duplicates here is when the user + ## writes duplicate predicates. + ## + ## Duplicates in predicates are harmless, just cause extra work when + ## resolving the predicates at the call sites. + preds: Vec[Pred], + + ## The generalized type. + # TODO: Should we have separate fields for arguments types and return type? + ty: Ty, + + ## Location of the scheme's function (or field etc.). + loc: Loc, +) + + +#[derive(ToDoc)] +type UVarGen( + _nextId: U32, +) + + +UVarGen.new() UVarGen: + UVarGen(_nextId = 0) + + +UVarGen.nextId(self) U32: + let next = self._nextId + self._nextId += 1 + next + + +UVarGen.newVar(self, level: U32, kind: Kind, loc: Loc) UVar: + UVar(id = self.nextId(), kind, level, link = Option.None) + + +## Instantiate the type scheme. Generated predicates are added to `preds`. +## Returns the instantiated type and instantiated type variables of the scheme. +Scheme.instantiate( + self, + level: U32, + varGen: UVarGen, + preds: Vec[Pred], + loc: Loc, +) (ty: Ty, vars: Vec[UVar]): + # TODO: We should rename type variables in a renaming pass, or disallow + # shadowing, or handle shadowing here. + + # Maps `QVar`s to unification variables. + let varMap = HashMap[LocalId, Ty].withCapacity(10) + + # Instantiated type parameters, in the same order as `self.qvars`. + let instantiations = Vec[UVar].withCapacity(self.qvars.len()) + + # Instantiate qvars of the scheme. + for qvar: QVar in self.qvars.iter(): + let qvarId = qvar.id + let qvarKind = qvar.kind + + let uvar = varGen.newVar(level, qvarKind, self.loc) + + let old = varMap.insert(qvarId, Ty.UVar(uvar)) + assert(old is Option.None) + + instantiations.push(uvar) + + # Generate predicates. + for pred: Pred in self.preds.iter(): + let instantiatedPred = Pred( + trait_ = pred.trait_, + params = + Vec.fromIter( + pred.params.iter().map(|param: Ty|: param.substQVars(varMap)), + ), + loc, + ) + preds.push(instantiatedPred) + + (ty = self.ty.substQVars(varMap), vars = instantiations) + + +# ------------------------------------------------------------------------------ +# Predicates + + +#[derive(ToDoc)] +type Pred( + trait_: TyId, + params: Vec[Ty], + loc: Loc, +) + + +## A set of predicates. +#[derive(ToDoc)] +type PredSet( + ## The set is actually a `Vec`, at least for now. + ## + ## The reason why we don't use a `HashSet` here is because the type checker + ## never adds a duplicate predicates, because it visits every expression + ## just once, and predicates hold source code location of the expression + ## that created them. So even if we have e.g. `Eq[U32]` multiple times here, + ## each of those predicates will have a different location. + ## + ## The only way to have duplicate predicates is when a user has a signature + ## with the same predicate multiple times. In that case `Scheme.preds` will + ## have duplicates, and calling the function will create duplicate + ## predicates. Those are harmless, just cause more work when resolving them. + preds: Vec[Pred], +) + + +# ------------------------------------------------------------------------------ +# Substitutions + + +Ty.substQVars(self, vars: HashMap[LocalId, Ty]) Ty: + match self: + Ty.Con(..) | Ty.UVar(..) | Ty.RVar(..): self + + Ty.App(conId, args): + Ty.App( + conId, + args = + Vec.fromIter(args.iter().map(|ty: Ty|: ty.substQVars(vars))), + ) + + Ty.QVar(QVar(id, ..)): + vars.get(id).unwrapOrElse( + ||: panic("Ty.substQVars: unbound QVar `id.name`"), + ) + + Ty.Fun(args, ret, exn): + Ty.Fun( + args = + match args: + FunArgs.Positional(tys): + FunArgs.Positional( + Vec.fromIter( + tys.iter().map(|ty: Ty|: ty.substQVars(vars)), + ), + ) + FunArgs.Named(tys): + FunArgs.Named( + HashMap.fromIter( + tys.iter().map( + |entry: HashMapEntry[LocalId, Ty]|: + ( + key = entry.key, + value = + entry.value.substQVars(vars), + ), + ), + ), + ), + ret = ret.substQVars(vars), + exn = exn.map(|exn: Ty|: exn.substQVars(vars)), + ) + + Ty.Anonymous(labels, extension, kind, isRow): + Ty.Anonymous( + labels = + HashMap.fromIter( + labels.iter().map( + |entry: HashMapEntry[LocalId, Ty]|: + ( + key = entry.key, + value = entry.value.substQVars(vars), + ), + ), + ), + extension = extension.map(|ty: Ty|: ty.substQVars(vars)), + kind, + isRow, + ) + + +# ------------------------------------------------------------------------------ +# Generated ToDoc implementations + + +impl ToDoc[Ty]: + toDoc(self: Ty) Doc: + panic("TODO") + + +impl ToDoc[QVar]: + toDoc(self: QVar) Doc: + panic("TODO") + + +impl ToDoc[UVar]: + toDoc(self: UVar) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("id =") + Doc.nested(4, Doc.break_(1) + self.id.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("kind =") + Doc.nested(4, Doc.break_(1) + self.kind.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("level =") + + Doc.nested(4, Doc.break_(1) + self.level.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("link =") + Doc.nested(4, Doc.break_(1) + self.link.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("UVar") + Doc.char('(') + args) + + +impl ToDoc[RVar]: + toDoc(self: RVar) Doc: + panic("TODO") + + +impl ToDoc[FunArgs]: + toDoc(self: FunArgs) Doc: + match self: + FunArgs.Positional(i0): + let args = Doc.break_(0) + args += i0.toDoc() + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("FunArgs.Positional") + Doc.char('(') + args) + FunArgs.Named(i0): + let args = Doc.break_(0) + args += i0.toDoc() + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("FunArgs.Named") + Doc.char('(') + args) + + +impl ToDoc[Kind]: + toDoc(self: Kind) Doc: + match self: + Kind.Star: Doc.str("Kind.Star") + Kind.Row(i0): + let args = Doc.break_(0) + args += i0.toDoc() + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("Kind.Row") + Doc.char('(') + args) + + +impl ToDoc[RecordOrVariant]: + toDoc(self: RecordOrVariant) Doc: + match self: + RecordOrVariant.Record: Doc.str("RecordOrVariant.Record") + RecordOrVariant.Variant: Doc.str("RecordOrVariant.Variant") + + +impl ToDoc[Scheme]: + toDoc(self: Scheme) Doc: + let args = Doc.break_(0) + + # args += Doc.grouped( + # Doc.str("qvars =") + # + Doc.nested(4, Doc.break_(1) + self.qvars.toDoc()), + # ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("preds =") + + Doc.nested(4, Doc.break_(1) + self.preds.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("ty =") + Doc.nested(4, Doc.break_(1) + self.ty.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("loc =") + Doc.nested(4, Doc.break_(1) + self.loc.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("Scheme") + Doc.char('(') + args) + + +impl ToDoc[UVarGen]: + toDoc(self: UVarGen) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("_nextId =") + + Doc.nested(4, Doc.break_(1) + self._nextId.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("UVarGen") + Doc.char('(') + args) + + +impl ToDoc[Pred]: + toDoc(self: Pred) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("trait_ =") + + Doc.nested(4, Doc.break_(1) + self.trait_.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("params =") + + Doc.nested(4, Doc.break_(1) + self.params.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("loc =") + Doc.nested(4, Doc.break_(1) + self.loc.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("Pred") + Doc.char('(') + args) + + +impl ToDoc[PredSet]: + toDoc(self: PredSet) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("preds =") + + Doc.nested(4, Doc.break_(1) + self.preds.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("PredSet") + Doc.char('(') + args) + + +# ------------------------------------------------------------------------------ +# ToStr implementations. These are used when generating error messages, so they +# should be readable by the users and should not expose implementation details. + +impl ToStr[Ty]: + toStr(self: Ty) Str: + panic("TODO") + + +impl ToStr[Kind]: + toStr(self: Kind) Str: + panic("TODO") + panic("TODO") + + +impl ToStr[RecordOrVariant]: + toStr(self: RecordOrVariant) Str: + panic("TODO") diff --git a/Compiler/TypeCheck/TyCon.fir b/Compiler/TypeCheck/TyCon.fir new file mode 100644 index 00000000..96527f72 --- /dev/null +++ b/Compiler/TypeCheck/TyCon.fir @@ -0,0 +1,128 @@ +import [Compiler/TypeCheck/Ty] + + +type TyCon( + ## Name of the type. + name: Str, + + ## Type parameters with kinds. + tyParams: Vec[TyParam], + + ## Methods for traits, constructor for sums, fields for products. + ## + ## Types can refer to `ty_params` and need to be substituted by the + ## instantiated the types in `ty_params` before use. + details: TyConDetails, +) + + +TyCon.arity(self) U32: + self.tyParams.len() + + +## A type parameter with kind. +type TyParam( + name: Id, + kind: Kind, +) + + +#[derive(ToDoc)] +type TyConDetails: + Trait(TraitDetails) + Type(TypeDetails) + + +#[derive(ToDoc)] +type TraitDetails( + ## Methods of the trait, with optional default implementations. + methods: HashMap[Str, TraitMethod], +) + + +#[derive(ToDoc)] +type TraitMethod( + ## Scheme of the trait method. Can refer to the type parameters of the + ## trait. + scheme: Scheme, + + ## The declaration of the trait method. When this has a body, the body is + ## used as the default implementation of the method. + ## + ## This can refer to the type parameters of the trait. + funDecl: FunDecl, +) + + +#[derive(ToDoc)] +type TypeDetails( + ## Names of value constructors of the type. + cons: Vec[Str], + + ## Whether the type is a sum type. + ## + ## A product type will always have one constructor in `cons`. + ## + ## A sum type can have any number of constructors in `cons`. + sum: Bool, +) + + +# ------------------------------------------------------------------------------ +# Generated ToDoc implementations + + +impl ToDoc[TyConDetails]: + toDoc(self: TyConDetails) Doc: + match self: + TyConDetails.Trait(i0): + let args = Doc.break_(0) + args += i0.toDoc() + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("TyConDetails.Trait") + Doc.char('(') + args) + TyConDetails.Type(i0): + let args = Doc.break_(0) + args += i0.toDoc() + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("TyConDetails.Type") + Doc.char('(') + args) + + +impl ToDoc[TraitDetails]: + toDoc(self: TraitDetails) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("methods =") + + Doc.nested(4, Doc.break_(1) + self.methods.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("TraitDetails") + Doc.char('(') + args) + + +impl ToDoc[TraitMethod]: + toDoc(self: TraitMethod) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("scheme =") + + Doc.nested(4, Doc.break_(1) + self.scheme.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("funDecl =") + + Doc.nested(4, Doc.break_(1) + self.funDecl.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("TraitMethod") + Doc.char('(') + args) + + +impl ToDoc[TypeDetails]: + toDoc(self: TypeDetails) Doc: + let args = Doc.break_(0) + args += Doc.grouped( + Doc.str("cons =") + Doc.nested(4, Doc.break_(1) + self.cons.toDoc()), + ) + args += Doc.char(',') + Doc.break_(1) + args += Doc.grouped( + Doc.str("sum =") + Doc.nested(4, Doc.break_(1) + self.sum.toDoc()), + ) + args = args.nest(4).group() + Doc.break_(0) + Doc.char(')') + Doc.grouped(Doc.str("TypeDetails") + Doc.char('(') + args) diff --git a/Compiler/TypeCheck/TyMap.fir b/Compiler/TypeCheck/TyMap.fir new file mode 100644 index 00000000..a1f60716 --- /dev/null +++ b/Compiler/TypeCheck/TyMap.fir @@ -0,0 +1,18 @@ +import [ + Compiler/TypeCheck/Ty, + Compiler/TypeCheck/TyCon, +] + + +type TyMap( + cons: HashMap[Str, TyCon], + vars: HashMap[Str, Ty], +) + + +TyMap.getVar(self, var_: Str) Option[Ty]: + self.vars.get(var_) + + +TyMap.getCon(self, con: Str) Option[TyCon]: + self.cons.get(con) diff --git a/Compiler/TypeCheck/Unification.fir b/Compiler/TypeCheck/Unification.fir new file mode 100644 index 00000000..48591f02 --- /dev/null +++ b/Compiler/TypeCheck/Unification.fir @@ -0,0 +1,575 @@ +import [ + Compiler/Assert, + Compiler/Error, + Compiler/TypeCheck/Error, + Compiler/TypeCheck/Normalization, + Compiler/TypeCheck/RowUtils, + Compiler/TypeCheck/Ty, +] + + +unify(ty1: Ty, ty2: Ty, varGen: UVarGen, level: U32, loc: Loc) / TypeError: + ty1 = ty1.normalize() + ty2 = ty2.normalize() + + match (ty1 = ty1, ty2 = ty2): + (ty1 = Ty.Con(id = id1), ty2 = Ty.Con(id = id2)): + if id1.def() != id2.def(): + throw( + TypeError( + loc, + msg = "Unable to unify types `ty1` and `ty2`", + ), + ) + + ( + ty1 = Ty.App(conId = id1, args = args1), + ty2 = Ty.App(conId = id2, args = args2), + ): + if id1.def() != id2.def(): + throw( + TypeError( + loc, + msg = "Unable to unify types `ty1` and `ty2`", + ), + ) + + if args1.len() != args2.len(): + panic( + "`loc`: BUG: Kind error: type constructor TODO applied to different number of arguments in unify", + ) + + for argIdx: U32 in range(u32(0), args1.len()): + unify(args1.get(argIdx), args2.get(argIdx), varGen, level, loc) + + (ty1 = Ty.QVar(_), ..) | (ty2 = Ty.QVar(_), ..): + panic("`loc`: BUG: QVar in unification") + + (ty1 = Ty.UVar(var1), ty2 = Ty.UVar(var2)): + if var1.id == var2.id: + return + + # We've normalized the types, so the links must be followed to the + # end. + assert(var1.link is Option.None) + assert(var2.link is Option.None) + + if var1.level < var2.level: + linkVar(var1, ty2) + else: + linkVar(var2, ty1) + + (ty1 = Ty.UVar(var1), ty2 = ty2): + if var1.kind != ty2.kind(): + throw( + TypeError( + loc, + msg = + "Unable to unify var with kind `var1.kind` with type with kind `ty2.kind()`", + ), + ) + + linkVar(var1, ty2) + + (ty1 = ty1, ty2 = Ty.UVar(var2)): + if var2.kind != ty1.kind(): + throw( + TypeError( + loc, + msg = + "Unable to unify var with kind `var2.kind` with type with kind `ty1.kind()`", + ), + ) + + linkVar(var2, ty1) + + ( + ty1 = Ty.Fun(args = args1, ret = ret1, exn = exn1), + ty2 = Ty.Fun(args = args2, ret = ret2, exn = exn2), + ): + if args1.len() != args2.len(): + throw( + TypeError( + loc, + msg = + "Unable to unify functions `ty1` and `ty2` (argument numbers don't match)", + ), + ) + + match (args1 = args1, args2 = args2): + ( + args1 = FunArgs.Positional(args1), + args2 = FunArgs.Positional(args2), + ): + for argIdx: U32 in range(u32(0), args1.len()): + unify( + args1.get(argIdx), + args2.get(argIdx), + varGen, + level, + loc, + ) + + (args1 = FunArgs.Named(args1), args2 = FunArgs.Named(args2)): + for arg1: HashMapEntry[LocalId, Ty] in args1.iter(): + let arg2 = args2.get(arg1.key).unwrapOrElse( + ||: + throw( + TypeError( + loc, + msg = + "Unable to unify functions with different named arguments", + ), + ), + ) + unify(arg1.value, arg2, varGen, level, loc) + + _: + throw( + TypeError( + loc, + msg = + "Unable to unify functions with positional and named arguments", + ), + ) + + match (exn1 = exn1, exn2 = exn2): + (exn1 = Option.Some(exn1), exn2 = Option.Some(exn2)): + unify(exn1, exn2, varGen, level, loc) + + _: + # In all other cases we have `None` as at least one of the + # types. `None` comes from a constructor, and constructors + # can't throw. So we let these unify. + () + + unify(ret1, ret2, varGen, level, loc) + + ( + ty1 = + Ty.Anonymous( + labels = labels1, + extension = extension1, + kind = kind1, + isRow = isRow1 + ), + ty2 = + Ty.Anonymous( + labels = labels2, + extension = extension2, + kind = kind2, + isRow = isRow2 + ), + ): + # Kind mismatches can happen when try to unify a record with a + # variant (e.g. pass a record when a variant is expected), and fail. + if kind1 != kind2: + throw( + TypeError( + loc, + msg = "Unable to unify `kind1` `ty1` with `kind2` `ty2`", + ), + ) + + # If we checked the kinds in type applications properly, we should + # only try to unify rows with rows and stars with stars. + assert(isRow1 == isRow2) + + let rows1 = collectRows(ty1, kind1, labels1, extension1) + let labels1 = rows1.rows + let extension1 = rows1.extension + + let rows2 = collectRows(ty2, kind2, labels2, extension2) + let labels2 = rows2.rows + let extension2 = rows2.extension + + let keys1: HashSet[LocalId] = HashSet.fromIter(labels1.keys()) + let keys2: HashSet[LocalId] = HashSet.fromIter(labels2.keys()) + + # Unify common labels. + for key: LocalId in keys1.intersection(keys2): + unify( + labels1.get(key).unwrap(), + labels2.get(key).unwrap(), + varGen, + level, + loc, + ) + + # Extra labels in one type will be added to the extension of the + # other. + let extras1: HashSet[LocalId] = HashSet.fromIter( + keys1.difference(keys2), + ) + let extras2: HashSet[LocalId] = HashSet.fromIter( + keys2.difference(keys1), + ) + + let kindStr = match kind1: + RecordOrVariant.Record: "record" + RecordOrVariant.Variant: "variant" + + let labelStr = match kind1: + RecordOrVariant.Record: "field" + RecordOrVariant.Variant: "constructor" + + # Add extras to rows. + if not extras1.isEmpty(): + match extension2: + Option.Some(Ty.UVar(uvar)): + # TODO: Not sure about level + extension2 = Option.Some( + Ty.UVar( + linkExtension( + kind2, + extras1, + labels1, + uvar, + varGen, + level, + loc, + ), + ), + ) + _: + throw( + TypeError( + loc, + msg = + "Unable to unify `kindStr` with `labelStr`s `keys1` with `kindStr` with `labelStr`s `keys2`", + ), + ) + + if not extras2.isEmpty(): + match extension1: + Option.Some(Ty.UVar(uvar)): + # TODO: Not sure about level + extension1 = Option.Some( + Ty.UVar( + linkExtension( + kind1, + extras2, + labels2, + uvar, + varGen, + level, + loc, + ), + ), + ) + _: + throw( + TypeError( + loc, + msg = + "Unable to unify `kindStr` with `labelStr`s `keys1` with `kindStr` with `labelStr`s `keys2`", + ), + ) + + match (ext1 = extension1, ext2 = extension2): + (ext1 = Option.None, ext2 = Option.None): () + + (ext1 = Option.Some(ext1), ext2 = Option.None): + unify(ext1, Ty.unitRow(kind1), varGen, level, loc) + + (ext1 = Option.None, ext2 = Option.Some(ext2)): + unify(ext2, Ty.unitRow(kind2), varGen, level, loc) + + (ext1 = Option.Some(ext1), ext2 = Option.Some(ext2)): + unify(ext1, ext2, varGen, level, loc) + + _: + throw( + TypeError(loc, msg = "Unable to unify types `ty1` and `ty2`",), + ) + + +## Unify `ty1` with `ty2`, without updating `t2`. +## +## Unlike `unify`, this does not throw on errors. Returns whether unification +## was successful. +# +# Currently this has two use cases: +# +# - When matching an `impl` arguments with given type arguments, to be able to +# select implementation of a predicate. Here `ty1` is instantiated type +# arguments of an `impl` and `ty2` is the argument we have from a trait type +# application. +# +# - When selecting a method in a method call, we unify a candidate's receiver +# type (`ty1`) with the actual receiver type (`ty2`). If the unification is +# successful, we consider the method as a potential target. +# +unifyOneWay(ty1: Ty, ty2: Ty, varGen: UVarGen, level: U32, loc: Loc) Bool: + ty1 = ty1.normalize() + ty2 = ty2.normalize() + + match (ty1 = ty1, ty2 = ty2): + (ty1 = Ty.Con(id = id1), ty2 = Ty.Con(id = id2)): id1.def() == id2.def() + + ( + ty1 = Ty.App(conId = id1, args = args1), + ty2 = Ty.App(conId = id2, args = args2), + ): + if id1.def() != id2.def(): + return Bool.False + + if args1.len() != args2.len(): + return Bool.False + + for argIdx: U32 in range(u32(0), args1.len()): + if not unifyOneWay( + args1.get(argIdx), + args2.get(argIdx), + varGen, + level, + loc, + ): + return Bool.False + + Bool.True + + (ty1 = Ty.QVar(_), ..) | (ty2 = Ty.QVar(_), ..): + panic("`loc`: BUG: QVar in unification") + + (ty1 = Ty.UVar(var1), ty2 = Ty.UVar(var2)): + # TODO: Interpreter doesn't have kind check here. + if var1.kind != ty2.kind(): + return Bool.False + + if var1 == var2: + return Bool.True + + # TODO: Copied from the interpreter, this updates ty2. Is this + # expected? We promise to not update ty2 in the documentation. + if var1.level > var2.level: + var1.setLink(ty2) + else: + var2.setLink(ty1) + + Bool.True + + ( + ty1 = Ty.Fun(args = args1, ret = ret1, exn = exn1), + ty2 = Ty.Fun(args = args2, ret = ret2, exn = exn2), + ): + if args1.len() != args2.len(): + return Bool.False + + match (args1 = args1, args2 = args2): + ( + args1 = FunArgs.Positional(args1), + args2 = FunArgs.Positional(args2), + ): + for argIdx: U32 in range(u32(0), args1.len()): + if not unifyOneWay( + args1.get(argIdx), + args2.get(argIdx), + varGen, + level, + loc, + ): + return Bool.False + + (args1 = FunArgs.Named(args1), args2 = FunArgs.Named(args2)): + let argNames1: HashSet[LocalId] = HashSet.fromIter( + args1.keys(), + ) + let argNames2: HashSet[LocalId] = HashSet.fromIter( + args2.keys(), + ) + + if argNames1 != argNames2: + return Bool.False + + for argName: LocalId in argNames1.iter(): + if not unifyOneWay( + args1.get(argName).unwrap(), + args2.get(argName).unwrap(), + varGen, + level, + loc, + ): + return Bool.False + + _: return Bool.False + + match (exn1 = exn1, exn2 = exn2): + (exn1 = Option.None, exn2 = Option.None): () + + (exn1 = Option.None, exn2 = Option.Some(_)) + | (exn1 = Option.Some(_), exn2 = Option.None): + # None is the same as [..r] with a fresh r, so it unifies + # with everything. + () + + (exn1 = Option.Some(exn1), exn2 = Option.Some(exn2)): + if not unifyOneWay(exn1, exn2, varGen, level, loc): + return Bool.False + + unifyOneWay(ret1, ret2, varGen, level, loc) + + ( + ty1 = + Ty.Anonymous( + labels = labels1, + extension = extension1, + kind = kind1, + isRow = isRow1 + ), + ty2 = + Ty.Anonymous( + labels = labels2, + extension = extension2, + kind = kind2, + isRow = isRow2 + ), + ): + # Kind mismatches can happen when try to unify a record with a + # variant (e.g. pass a record when a variant is expected), and fail. + if kind1 != kind2: + return Bool.False + + # If we checked the kinds in type applications properly, we should + # only try to unify rows with rows and stars with stars. + assert(isRow1 == isRow2) + + let rows1 = collectRows(ty1, kind1, labels1, extension1) + let labels1 = rows1.rows + let extension1 = rows1.extension + + let rows2 = collectRows(ty2, kind2, labels2, extension2) + let labels2 = rows2.rows + let extension2 = rows2.extension + + let keys1: HashSet[LocalId] = HashSet.fromIter(labels1.keys()) + let keys2: HashSet[LocalId] = HashSet.fromIter(labels2.keys()) + + # Extra labels in one type will be added to the extension of the + # other. + let extras1: HashSet[LocalId] = HashSet.fromIter( + keys1.difference(keys2), + ) + let extras2: HashSet[LocalId] = HashSet.fromIter( + keys2.difference(keys1), + ) + + # Unify common labels. + for key: LocalId in keys1.intersection(keys2): + let ty1 = labels1.get(key).unwrap() + let ty2 = labels2.get(key).unwrap() + if not unifyOneWay(ty1, ty2, varGen, level, loc): + return Bool.False + + if not extras1.isEmpty(): + return Bool.False + + if not extras2.isEmpty(): + match extension1: + Option.Some(Ty.UVar(uvar)): + # TODO: Not sure about level + extension1 = Option.Some( + Ty.UVar( + linkExtension( + kind1, + extras2, + labels2, + uvar, + varGen, + level, + loc, + ), + ), + ) + _: return Bool.False + + match (ext1 = extension1, ext2 = extension2): + (ext1 = Option.None, ext2 = Option.None): Bool.True + + (ext1 = Option.Some(ext1), ext2 = Option.None): + unifyOneWay(ext1, Ty.unitRow(kind1), varGen, level, loc) + + (ext1 = Option.None, ext2 = Option.Some(ext2)): Bool.False + + (ext1 = Option.Some(ext1), ext2 = Option.Some(ext2)): + unifyOneWay(ext1, ext2, varGen, level, loc) + + _: Bool.False + + +# TODO: This can't be `UVar.link` as `link` is also a field. +# Maybe for now rename the field (add underscore). +linkVar(var_: UVar, ty: Ty): + # TODO: Occurs check. + ty.pruneLevel(var_.level) + var_.setLink(ty) + + +linkExtension( + kind: RecordOrVariant, + extraLabels: HashSet[LocalId], + labelValues: HashMap[LocalId, Ty], + uvar: UVar, + varGen: UVarGen, + level: U32, + loc: Loc, +) UVar: + let extensionLabels: HashMap[LocalId, Ty] = HashMap.fromIter( + extraLabels.iter().map( + |label: LocalId|: + (key = label, value = labelValues.get(label).unwrap()), + ), + ) + + # TODO: Not sure about the level. + let newExtensionVar = varGen.newVar(level, Kind.Row(kind), loc) + let newExtensionTy = Ty.Anonymous( + labels = extensionLabels, + extension = Option.Some(Ty.UVar(newExtensionVar)), + kind, + isRow = Bool.True, + ) + + uvar.setLink(newExtensionTy) + newExtensionVar + + +Ty.pruneLevel(self, maxLevel: U32): + match self: + Ty.Con(..): () + + Ty.App(args, ..): + for arg: Ty in args.iter(): + arg.pruneLevel(maxLevel) + + Ty.QVar(..): panic("QVar in pruneLevel") + + Ty.UVar(uvar): uvar.pruneLevel(maxLevel) + + Ty.RVar(..): () + + Ty.Fun(args, ret, exn): + match args: + FunArgs.Positional(tys): + for ty: Ty in tys.iter(): + ty.pruneLevel(maxLevel) + + FunArgs.Named(tys): + for ty: Ty in tys.values(): + ty.pruneLevel(maxLevel) + + ret.pruneLevel(maxLevel) + + if exn is Option.Some(exn): + exn.pruneLevel(maxLevel) + + Ty.Anonymous(labels, extension, ..): + for ty: Ty in labels.values(): + ty.pruneLevel(maxLevel) + + if extension is Option.Some(extension): + extension.pruneLevel(maxLevel) + + +UVar.pruneLevel(self, maxLevel: U32): + self.level = min(self.level, maxLevel) diff --git a/Fir/HashMap.fir b/Fir/HashMap.fir index b1870508..507d75dc 100644 --- a/Fir/HashMap.fir +++ b/Fir/HashMap.fir @@ -155,6 +155,10 @@ HashMap.get[Hash[k], Eq[k]](self: HashMap[k, v], key: k) Option[v]: panic("unreachable") +HashMap.containsKey[Hash[k], Eq[k]](self: HashMap[k, v], key: k) Bool: + self.get(key) is Option.Some(_) + + impl[ToStr[k], ToStr[v]] ToStr[HashMap[k, v]]: toStr(self: HashMap[k, v]) Str: let buf = StrBuf.withCapacity(100) diff --git a/Fir/Vec.fir b/Fir/Vec.fir index 2c25f6d4..692ca5a2 100644 --- a/Fir/Vec.fir +++ b/Fir/Vec.fir @@ -237,6 +237,14 @@ impl[Eq[t]] Eq[Vec[t]]: Bool.True +impl[Hash[t]] Hash[Vec[t]]: + hash(self: Vec[t]) U32: + let hash = self.len() + for t: t in self.iter(): + hash += t.hash() + hash + + type VecIter[t]( _vec: Vec[t], _idx: U32, diff --git a/src/ast.rs b/src/ast.rs index 25066f6e..ac3fd9aa 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -756,10 +756,10 @@ pub struct Context { /// /// ```ignore /// impl[ToStr[a]] ToStr[Vec[a]]: -/// toStr(self): Str = ... +/// toStr(self) Str: ... /// /// impl Iterator[VecIter[a], a]: -/// next(self): Option[a] = ... +/// next(self) Option[a]: ... /// ``` #[derive(Debug, Clone)] pub struct ImplDecl { diff --git a/src/scanner.rs b/src/scanner.rs index 9d55ac2b..7ce9ec03 100644 --- a/src/scanner.rs +++ b/src/scanner.rs @@ -17,7 +17,7 @@ where None => return vec![], }; scan_indented(&mut tokens, module, &mut new_tokens, start_loc); - assert_eq!(tokens.next(), None); + assert_eq!(tokens.next(), None, "module = {module}"); new_tokens }