Skip to content

Commit

Permalink
fix: trait impl bugs (2)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Dec 26, 2024
1 parent 017b13f commit 40b53a3
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 25 deletions.
5 changes: 5 additions & 0 deletions crates/erg_common/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ impl<K, V> Dict<K, V> {
self.dict.retain(f);
}

pub fn retained(mut self, f: impl FnMut(&K, &mut V) -> bool) -> Self {
self.retain(f);
self
}

pub fn get_by(&self, k: &K, cmp: impl Fn(&K, &K) -> bool) -> Option<&V> {
for (k_, v) in self.dict.iter() {
if cmp(k, k_) {
Expand Down
54 changes: 37 additions & 17 deletions crates/erg_compiler/context/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ use super::instantiate::TyVarCache;
use super::instantiate_spec::ParamKind;
use super::{ControlKind, MethodContext, ParamSpec, TraitImpl, TypeContext};

type ClassTrait<'c> = (Type, Option<(Type, &'c TypeSpecWithOp)>);
type ClassTraitErrors<'c> = (
Option<Type>,
Option<(Type, &'c TypeSpecWithOp)>,
TyCheckErrors,
);

pub fn valid_mod_name(name: &str) -> bool {
!name.is_empty() && !name.starts_with('/') && name.trim() == name
}
Expand Down Expand Up @@ -987,7 +994,7 @@ impl Context {
pub(crate) fn get_class_and_impl_trait<'c>(
&mut self,
class_spec: &'c ast::TypeSpec,
) -> TyCheckResult<(Type, Option<(Type, &'c TypeSpecWithOp)>)> {
) -> Result<ClassTrait<'c>, ClassTraitErrors<'c>> {
let mut errs = TyCheckErrors::empty();
let mut dummy_tv_cache = TyVarCache::new(self.level, self);
match class_spec {
Expand All @@ -1013,14 +1020,21 @@ impl Context {
(t, &tasc.t_spec)
}
other => {
return Err(TyCheckErrors::from(TyCheckError::syntax_error(
self.cfg.input.clone(),
line!() as usize,
other.loc(),
self.caused_by(),
format!("expected type ascription, but found {}", other.name()),
return Err((
None,
None,
)))
TyCheckErrors::from(TyCheckError::syntax_error(
self.cfg.input.clone(),
line!() as usize,
other.loc(),
self.caused_by(),
format!(
"expected type ascription, but found {}",
other.name()
),
None,
)),
))
}
};
let class = match self.instantiate_typespec_full(
Expand All @@ -1039,7 +1053,7 @@ impl Context {
if errs.is_empty() {
Ok((class, Some((impl_trait, t_spec))))
} else {
Err(errs)
Err((Some(class), Some((impl_trait, t_spec)), errs))
}
}
ast::TypeAppArgsKind::SubtypeOf(trait_spec) => {
Expand All @@ -1052,8 +1066,10 @@ impl Context {
) {
Ok(t) => t,
Err((t, es)) => {
errs.extend(es);
t
if !PYTHON_MODE {
errs.extend(es);
}
t.replace(&Type::Failure, &Type::Never)
}
};
let class = match self.instantiate_typespec_full(
Expand All @@ -1072,7 +1088,7 @@ impl Context {
if errs.is_empty() {
Ok((class, Some((impl_trait, trait_spec.as_ref()))))
} else {
Err(errs)
Err((Some(class), Some((impl_trait, trait_spec.as_ref())), errs))
}
}
}
Expand All @@ -1094,7 +1110,7 @@ impl Context {
if errs.is_empty() {
Ok((t, None))
} else {
Err(errs)
Err((Some(t), None, errs))
}
}
}
Expand Down Expand Up @@ -1243,10 +1259,14 @@ impl Context {
.instantiate_vis_modifier(class_def.def.sig.vis())
.unwrap_or(VisibilityModifier::Public);
for methods in class_def.methods_list.iter() {
let Ok((class, impl_trait)) = self.get_class_and_impl_trait(&methods.class)
else {
continue;
};
let (class, impl_trait) =
match self.get_class_and_impl_trait(&methods.class) {
Ok(x) => x,
Err((class, trait_, errs)) => {
total_errs.extend(errs);
(class.unwrap_or(Type::Obj), trait_)
}
};
// assume the class has implemented the trait, regardless of whether the implementation is correct
if let Some((trait_, trait_loc)) = &impl_trait {
if let Err(errs) = self.register_trait_impl(&class, trait_, *trait_loc)
Expand Down
25 changes: 18 additions & 7 deletions crates/erg_compiler/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2705,9 +2705,9 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
let (class, impl_trait) =
match self.module.context.get_class_and_impl_trait(&methods.class) {
Ok(x) => x,
Err(errs) => {
Err((class, trait_, errs)) => {
errors.extend(errs);
continue;
(class.unwrap_or(Type::Obj), trait_)
}
};
if let Some(class_root) = self.module.context.get_nominal_type_ctx(&class) {
Expand Down Expand Up @@ -3224,16 +3224,27 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
) -> (Set<&VarName>, CompileErrors) {
let mut errors = CompileErrors::empty();
let mut unverified_names = self.module.context.locals.keys().collect::<Set<_>>();
let mut super_impls = set! {};
let tys_decls = if let Some(sups) = self.module.context.get_super_types(trait_type) {
sups.map(|sup| {
if implemented.linear_contains(&sup) {
return (sup, Dict::new());
}
let decls = self
.module
.context
.get_nominal_type_ctx(&sup)
.map_or(Dict::new(), |ctx| ctx.decls.clone());
let decls =
self.module
.context
.get_nominal_type_ctx(&sup)
.map_or(Dict::new(), |ctx| {
super_impls.extend(ctx.locals.keys());
for methods in &ctx.methods_list {
super_impls.extend(methods.locals.keys());
}
ctx.decls.clone().retained(|k, _| {
let implemented_in_super = super_impls.contains(k);
let class_decl = ctx.kind.is_class();
!implemented_in_super && !class_decl
})
});
(sup, decls)
})
.collect::<Vec<_>>()
Expand Down
2 changes: 1 addition & 1 deletion crates/erg_compiler/module/promise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ impl SharedPromises {
}

pub fn wait_until_finished(&self, path: &NormalizedPathBuf) {
if self.promises.borrow().get(path).is_none() {
if !self.graph.entries().contains(path) {
panic!("not registered: {path}");
}
while !self.is_finished(path) {
Expand Down

0 comments on commit 40b53a3

Please sign in to comment.