From c8ea6f6c9b81ca37a4798f4007f9596fcc6adcbb Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 6 Jun 2023 00:04:18 +0900 Subject: [PATCH] fix: nested polymorphic type check bug --- crates/erg_compiler/context/compare.rs | 35 +++++++++---- crates/erg_compiler/context/eval.rs | 62 +++++++++++++++++++++++ crates/erg_compiler/ty/mod.rs | 68 ++++++++++++++++++++++++++ crates/erg_compiler/ty/predicate.rs | 15 ++++++ crates/erg_compiler/ty/typaram.rs | 23 +++++++++ tests/should_err/subtyping.er | 6 +++ tests/test.rs | 2 +- 7 files changed, 200 insertions(+), 11 deletions(-) diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 05d925be5..519344549 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -301,6 +301,13 @@ impl Context { panic!("err: {err}"); } } + } else if typ.has_undoable_linked_var() { + if let Err(err) = self.overwrite_typarams(typ, rhs) { + Self::undo_substitute_typarams(typ); + if DEBUG_MODE { + panic!("err: {err}"); + } + } } for rhs_sup in rhs_ctx.super_traits.iter() { // Not `supertype_of` (only structures are compared) @@ -459,8 +466,8 @@ impl Context { // => ?P.undoable_link(Int) // => Mul Int :> Int (FreeVar(lfv), rhs) => { - if let FreeKind::Linked(t) | FreeKind::UndoableLinked { t, .. } = &*lfv.borrow() { - return self.supertype_of(t, rhs); + if let Some(t) = lfv.get_linked() { + return self.supertype_of(&t, rhs); } if let Some((_sub, sup)) = lfv.get_subsup() { lfv.undoable_link(rhs); @@ -482,8 +489,8 @@ impl Context { } } (lhs, FreeVar(rfv)) => { - if let FreeKind::Linked(t) | FreeKind::UndoableLinked { t, .. } = &*rfv.borrow() { - return self.supertype_of(lhs, t); + if let Some(t) = rfv.get_linked() { + return self.supertype_of(lhs, &t); } if let Some((sub, _sup)) = rfv.get_subsup() { rfv.undoable_link(lhs); @@ -908,15 +915,23 @@ impl Context { } } _ => { - if let (Ok(sup), Ok(sub)) = ( + match ( self.convert_tp_into_type(sup_p.clone()), self.convert_tp_into_type(sub_p.clone()), ) { - return match variance { - Variance::Contravariant => self.subtype_of(&sup, &sub), - Variance::Covariant => self.supertype_of(&sup, &sub), - Variance::Invariant => self.same_type_of(&sup, &sub), - }; + (Ok(sup), Ok(sub)) => { + return match variance { + Variance::Contravariant => self.subtype_of(&sup, &sub), + Variance::Covariant => self.supertype_of(&sup, &sub), + Variance::Invariant => self.same_type_of(&sup, &sub), + }; + } + (Err(le), Err(re)) => { + log!(err "cannot convert {le}, {re} to types") + } + (Err(err), _) | (_, Err(err)) => { + log!(err "cannot convert {err} to a type"); + } } self.eq_tp(sup_p, sub_p) } diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 0a5c3099c..6d89949fa 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -1338,6 +1338,11 @@ impl Context { TyParam::FreeVar(fv) if fv.is_linked() => self.convert_tp_into_type(fv.crack().clone()), TyParam::Type(t) => Ok(t.as_ref().clone()), TyParam::Mono(name) => Ok(Type::Mono(name)), + TyParam::App { name, args } => Ok(Type::Poly { name, params: args }), + TyParam::Proj { obj, attr } => { + let lhs = self.convert_tp_into_type(*obj)?; + Ok(lhs.proj(attr)) + } // TyParam::Erased(_t) => Ok(Type::Obj), TyParam::Value(v) => self.convert_value_into_type(v).map_err(TyParam::Value), // TODO: Dict, Set @@ -1650,6 +1655,20 @@ impl Context { Ok(()) } + pub(crate) fn overwrite_typarams(&self, qt: &Type, st: &Type) -> EvalResult<()> { + let qtps = qt.typarams(); + let stps = st.typarams(); + if qt.qual_name() != st.qual_name() || qtps.len() != stps.len() { + log!(err "{qt} / {st}"); + log!(err "[{}] [{}]", erg_common::fmt_vec(&qtps), erg_common::fmt_vec(&stps)); + return Ok(()); // TODO: e.g. Sub(Int) / Eq and Sub(?T) + } + for (qtp, stp) in qtps.into_iter().zip(stps.into_iter()) { + self.overwrite_typaram(qtp, stp)?; + } + Ok(()) + } + fn substitute_typaram(&self, qtp: TyParam, stp: TyParam) -> EvalResult<()> { match qtp { TyParam::FreeVar(ref fv) if fv.is_generalized() => { @@ -1693,6 +1712,49 @@ impl Context { Ok(()) } + fn overwrite_typaram(&self, qtp: TyParam, stp: TyParam) -> EvalResult<()> { + match qtp { + TyParam::FreeVar(ref fv) if fv.is_undoable_linked() => { + if !stp.is_unbound_var() || !stp.is_generalized() { + fv.undoable_link(&stp); + } + if let Err(errs) = self.sub_unify_tp(&stp, &qtp, None, &(), false) { + log!(err "{errs}"); + } + Ok(()) + } + TyParam::Type(qt) => self.overwrite_type(stp, *qt), + TyParam::Value(ValueObj::Type(qt)) => self.overwrite_type(stp, qt.into_typ()), + _ => Ok(()), + } + } + + fn overwrite_type(&self, stp: TyParam, qt: Type) -> EvalResult<()> { + let st = self.convert_tp_into_type(stp).map_err(|tp| { + EvalError::not_a_type_error( + self.cfg.input.clone(), + line!() as usize, + ().loc(), + self.caused_by(), + &tp.to_string(), + ) + })?; + if qt.has_undoable_linked_var() { + if let Ok(qt) = <&FreeTyVar>::try_from(&qt) { + if !st.is_unbound_var() || !st.is_generalized() { + qt.undoable_link(&st); + } + } + } + if !st.is_unbound_var() || !st.is_generalized() { + self.overwrite_typarams(&qt, &st)?; + } + if let Err(errs) = self.sub_unify(&st, &qt, &(), None) { + log!(err "{errs}"); + } + Ok(()) + } + pub(crate) fn undo_substitute_typarams(substituted_q: &Type) { for tp in substituted_q.typarams().into_iter() { match tp { diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index c6895cdbf..a77c7c0b0 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -444,6 +444,22 @@ impl SubrType { || self.return_t.has_qvar() } + pub fn has_undoable_linked_var(&self) -> bool { + self.non_default_params + .iter() + .any(|pt| pt.typ().has_undoable_linked_var()) + || self + .var_params + .as_ref() + .map(|pt| pt.typ().has_undoable_linked_var()) + .unwrap_or(false) + || self + .default_params + .iter() + .any(|pt| pt.typ().has_undoable_linked_var()) + || self.return_t.has_undoable_linked_var() + } + pub fn typarams(&self) -> Vec { [ self.non_default_params @@ -2450,6 +2466,58 @@ impl Type { } } + pub fn has_undoable_linked_var(&self) -> bool { + match self { + Self::FreeVar(fv) if fv.is_undoable_linked() => true, + Self::FreeVar(fv) if fv.is_linked() => fv.crack().has_undoable_linked_var(), + Self::FreeVar(fv) => { + if let Some((sub, sup)) = fv.get_subsup() { + fv.dummy_link(); + let res_sub = sub.has_undoable_linked_var(); + let res_sup = sup.has_undoable_linked_var(); + fv.undo(); + res_sub || res_sup + } else { + let opt_t = fv.get_type(); + opt_t.map_or(false, |t| t.has_undoable_linked_var()) + } + } + Self::Ref(ty) => ty.has_undoable_linked_var(), + Self::RefMut { before, after } => { + before.has_undoable_linked_var() + || after + .as_ref() + .map(|t| t.has_undoable_linked_var()) + .unwrap_or(false) + } + Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { + lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var() + } + Self::Not(ty) => ty.has_undoable_linked_var(), + Self::Callable { param_ts, return_t } => { + param_ts.iter().any(|t| t.has_undoable_linked_var()) + || return_t.has_undoable_linked_var() + } + Self::Subr(subr) => subr.has_undoable_linked_var(), + Self::Quantified(quant) => quant.has_undoable_linked_var(), + Self::Record(r) => r.values().any(|t| t.has_undoable_linked_var()), + Self::Refinement(refine) => { + refine.t.has_undoable_linked_var() || refine.pred.has_undoable_linked_var() + } + Self::Poly { params, .. } => params.iter().any(|tp| tp.has_undoable_linked_var()), + Self::Proj { lhs, .. } => lhs.has_undoable_linked_var(), + Self::ProjCall { lhs, args, .. } => { + lhs.has_undoable_linked_var() || args.iter().any(|tp| tp.has_undoable_linked_var()) + } + Self::Structural(ty) => ty.has_undoable_linked_var(), + Self::Guard(guard) => guard.to.has_undoable_linked_var(), + Self::Bounded { sub, sup } => { + sub.has_undoable_linked_var() || sup.has_undoable_linked_var() + } + _ => false, + } + } + pub fn has_no_qvar(&self) -> bool { !self.has_qvar() } diff --git a/crates/erg_compiler/ty/predicate.rs b/crates/erg_compiler/ty/predicate.rs index 7d2265526..c256fb121 100644 --- a/crates/erg_compiler/ty/predicate.rs +++ b/crates/erg_compiler/ty/predicate.rs @@ -341,6 +341,21 @@ impl Predicate { } } + pub fn has_undoable_linked_var(&self) -> bool { + match self { + Self::Value(_) => false, + Self::Const(_) => false, + Self::Equal { rhs, .. } + | Self::GreaterEqual { rhs, .. } + | Self::LessEqual { rhs, .. } + | Self::NotEqual { rhs, .. } => rhs.has_undoable_linked_var(), + Self::Or(lhs, rhs) | Self::And(lhs, rhs) => { + lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var() + } + Self::Not(pred) => pred.has_undoable_linked_var(), + } + } + pub fn min_max<'a>( &'a self, min: Option<&'a TyParam>, diff --git a/crates/erg_compiler/ty/typaram.rs b/crates/erg_compiler/ty/typaram.rs index 4295458b4..274c446ce 100644 --- a/crates/erg_compiler/ty/typaram.rs +++ b/crates/erg_compiler/ty/typaram.rs @@ -1062,6 +1062,29 @@ impl TyParam { !self.has_unbound_var() } + pub fn has_undoable_linked_var(&self) -> bool { + match self { + Self::FreeVar(fv) => fv.is_undoable_linked(), + Self::Type(t) => t.has_undoable_linked_var(), + Self::Proj { obj, .. } => obj.has_undoable_linked_var(), + Self::Array(ts) | Self::Tuple(ts) => ts.iter().any(|t| t.has_undoable_linked_var()), + Self::Set(ts) => ts.iter().any(|t| t.has_undoable_linked_var()), + Self::Dict(kv) => kv + .iter() + .any(|(k, v)| k.has_undoable_linked_var() || v.has_undoable_linked_var()), + Self::Record(rec) => rec.iter().any(|(_, v)| v.has_undoable_linked_var()), + Self::Lambda(lambda) => lambda.body.iter().any(|t| t.has_undoable_linked_var()), + Self::UnaryOp { val, .. } => val.has_undoable_linked_var(), + Self::BinOp { lhs, rhs, .. } => { + lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var() + } + Self::App { args, .. } => args.iter().any(|p| p.has_undoable_linked_var()), + Self::Erased(t) => t.has_undoable_linked_var(), + Self::Value(ValueObj::Type(t)) => t.typ().has_undoable_linked_var(), + _ => false, + } + } + pub fn union_size(&self) -> usize { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().union_size(), diff --git a/tests/should_err/subtyping.er b/tests/should_err/subtyping.er index 14e6fe99a..dfb153176 100644 --- a/tests/should_err/subtyping.er +++ b/tests/should_err/subtyping.er @@ -70,3 +70,9 @@ _: Array!({"a", "b"}, 2) = !["a", "b"] # OK _: Array!({"a", "b", "c"}, 2) = !["a", "b"] # OK _: Array!({"a", "c"}, 2) = !["a", "b"] # ERR _: Array!({"a"}, 2) = !["a", "b"] # ERR + +ii _: Iterable(Iterable(Str)) = None +ii [1] # ERR +ii [[1]] # ERR +ii [["a"]] +ii ["aaa"] # Str <: Iterable Str diff --git a/tests/test.rs b/tests/test.rs index 81a7bd064..3991727bf 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -345,7 +345,7 @@ fn exec_structural_err() -> Result<(), ()> { #[test] fn exec_subtyping_err() -> Result<(), ()> { - expect_failure("tests/should_err/subtyping.er", 0, 15) + expect_failure("tests/should_err/subtyping.er", 0, 17) } #[test]