Skip to content

Commit

Permalink
Merge pull request #481 from erg-lang/general_pred
Browse files Browse the repository at this point in the history
Enhance dependent types & refinement types
  • Loading branch information
mtshiba authored Jan 30, 2024
2 parents 2939c74 + 83cb225 commit 107e5a0
Show file tree
Hide file tree
Showing 34 changed files with 2,859 additions and 304 deletions.
9 changes: 8 additions & 1 deletion crates/erg_common/dict.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::borrow::Borrow;
use std::collections::hash_map::{Entry, IntoValues, Iter, IterMut, Keys, Values, ValuesMut};
use std::collections::hash_map::{
Entry, IntoKeys, IntoValues, Iter, IterMut, Keys, Values, ValuesMut,
};
use std::fmt::{self, Write};
use std::hash::{Hash, Hasher};
use std::iter::FromIterator;
Expand Down Expand Up @@ -138,6 +140,11 @@ impl<K, V> Dict<K, V> {
self.dict.into_values()
}

#[inline]
pub fn into_keys(self) -> IntoKeys<K, V> {
self.dict.into_keys()
}

#[inline]
pub fn iter(&self) -> Iter<K, V> {
self.dict.iter()
Expand Down
14 changes: 14 additions & 0 deletions crates/erg_common/triple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ impl<T, E> Triple<T, E> {
Triple::Err(err) => Triple::Err(f(err)),
}
}

pub fn is_ok_and(self, f: impl FnOnce(T) -> bool) -> bool {
match self {
Triple::Ok(ok) => f(ok),
_ => false,
}
}

pub fn is_err_and(self, f: impl FnOnce(E) -> bool) -> bool {
match self {
Triple::Err(err) => f(err),
_ => false,
}
}
}

impl<T> Triple<T, T> {
Expand Down
172 changes: 165 additions & 7 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use erg_common::traits::StructuralEq;
use erg_common::{assume_unreachable, log};
use erg_common::{Str, Triple};

use crate::context::eval::UndoableLinkedList;
use crate::context::initialize::const_func::sub_tpdict_get;
use crate::ty::constructors::{self, and, bounded, not, or, poly};
use crate::ty::free::{Constraint, FreeKind, FreeTyVar};
Expand Down Expand Up @@ -614,7 +615,22 @@ impl Context {
return false;
}
}
self.is_super_pred_of(&l.pred, &r.pred)
for tp in r.pred.possible_tps() {
let substituted = l.pred.clone().substitute(&l.var, tp);
if self.bool_eval_pred(substituted).is_ok_and(|b| b) {
return true;
}
}
if self.is_super_pred_of(&l.pred, &r.pred) {
true
} else {
let list = UndoableLinkedList::new();
for tp in l.t.typarams() {
list.push_tp(&tp);
}
let _ = self.undoable_sub_unify(&r.t, &l.t, &(), &list, None);
self.is_super_pred_of(&l.pred, &r.pred)
}
}
(Nat | Bool, re @ Refinement(_)) => {
let refine = Type::Refinement(lhs.clone().into_refinement());
Expand All @@ -633,6 +649,7 @@ impl Context {
// Array({1, 2}, _) :> {[3, 4]} == false
(l, Refinement(r)) => {
// Type / {S: Set(Str) | S == {"a", "b"}}
// TODO: GeneralEq
if let Pred::Equal { rhs, .. } = r.pred.as_ref() {
if self.subtype_of(l, &Type) && self.convert_tp_into_type(rhs.clone()).is_ok() {
return true;
Expand Down Expand Up @@ -895,7 +912,7 @@ impl Context {
},
(TyParam::Array(sup), TyParam::Array(sub))
| (TyParam::Tuple(sup), TyParam::Tuple(sub)) => {
if sup.len() > sub.len() {
if sup.len() > sub.len() || (variance.is_invariant() && sup.len() != sub.len()) {
return false;
}
for (sup_p, sub_p) in sup.iter().zip(sub.iter()) {
Expand All @@ -907,7 +924,9 @@ impl Context {
}
// {Int: Str} :> {Int: Str, Bool: Int}
(TyParam::Dict(sup_d), TyParam::Dict(sub_d)) => {
if sup_d.len() > sub_d.len() {
if sup_d.len() > sub_d.len()
|| (variance.is_invariant() && sup_d.len() != sub_d.len())
{
return false;
}
for (sub_k, sub_v) in sub_d.iter() {
Expand All @@ -924,11 +943,77 @@ impl Context {
}
true
}
(TyParam::Record(sup_r), TyParam::Record(sub_r)) => {
if sup_r.len() > sub_r.len()
|| (variance.is_invariant() && sup_r.len() != sub_r.len())
{
return false;
}
for (sub_k, sub_v) in sub_r.iter() {
if let Some(sup_v) = sup_r.get(sub_k) {
if !self.supertype_of_tp(sup_v, sub_v, variance) {
return false;
}
} else {
return false;
}
}
true
}
(TyParam::UnsizedArray(sup), TyParam::UnsizedArray(sub)) => {
self.supertype_of_tp(sup, sub, variance)
}
(TyParam::Type(sup), TyParam::Type(sub)) => match variance {
Variance::Contravariant => self.subtype_of(sup, sub),
Variance::Covariant => self.supertype_of(sup, sub),
Variance::Invariant => self.same_type_of(sup, sub),
},
(
TyParam::App { name, args },
TyParam::App {
name: sub_name,
args: sub_args,
},
) => {
if name != sub_name || args.len() != sub_args.len() {
return false;
}
for (sup_p, sub_p) in args.iter().zip(sub_args.iter()) {
if !self.supertype_of_tp(sup_p, sub_p, variance) {
return false;
}
}
true
}
(TyParam::Lambda(sup_l), TyParam::Lambda(sub_l)) => {
for (sup_nd, sub_nd) in sup_l.nd_params.iter().zip(sub_l.nd_params.iter()) {
if !self.subtype_of(sup_nd.typ(), sub_nd.typ()) {
return false;
}
}
if let Some((sup_var, sub_var)) =
sup_l.var_params.as_ref().zip(sub_l.var_params.as_ref())
{
if !self.subtype_of(sup_var.typ(), sub_var.typ()) {
return false;
}
}
for (sup_d, sub_d) in sup_l.d_params.iter().zip(sub_l.d_params.iter()) {
if !self.subtype_of(sup_d.typ(), sub_d.typ()) {
return false;
}
}
if let Some((sup_kw_var, sub_kw_var)) = sup_l
.kw_var_params
.as_ref()
.zip(sub_l.kw_var_params.as_ref())
{
if !self.subtype_of(sup_kw_var.typ(), sub_kw_var.typ()) {
return false;
}
}
true
}
(TyParam::FreeVar(fv), _) if fv.is_unbound() => {
let Some(fv_t) = fv.get_type() else {
return false;
Expand Down Expand Up @@ -1514,6 +1599,24 @@ impl Context {
| Predicate::LessEqual { rhs, .. } => self.get_tp_t(rhs).unwrap_or(Obj),
Predicate::Not(pred) => self.get_pred_type(pred),
Predicate::Value(val) => val.class(),
Predicate::Call { receiver, name, .. } => {
let receiver_t = self.get_tp_t(receiver).unwrap_or(Obj);
if let Some(name) = name {
let ctx = self.get_nominal_type_ctx(&receiver_t).unwrap();
if let Some((_, method)) = ctx.get_var_info(name) {
method.t.return_t().cloned().unwrap_or(Obj)
} else {
Obj
}
} else {
receiver_t.return_t().cloned().unwrap_or(Obj)
}
}
// REVIEW
Predicate::GeneralEqual { rhs, .. }
| Predicate::GeneralGreaterEqual { rhs, .. }
| Predicate::GeneralLessEqual { rhs, .. }
| Predicate::GeneralNotEqual { rhs, .. } => self.get_pred_type(rhs),
// x == 1 or x == "a" => Int or Str
Predicate::Or(lhs, rhs) => {
self.union(&self.get_pred_type(lhs), &self.get_pred_type(rhs))
Expand Down Expand Up @@ -1612,10 +1715,6 @@ impl Context {
return true;
}
match (lhs, rhs) {
(Pred::Value(ValueObj::Bool(b)), _) => *b,
(_, Pred::Value(ValueObj::Bool(b))) => !b,
(Pred::LessEqual { rhs, .. }, _) if !rhs.has_upper_bound() => true,
(Pred::GreaterEqual { rhs, .. }, _) if !rhs.has_lower_bound() => true,
(
Pred::Equal { .. },
Pred::GreaterEqual { .. } | Pred::LessEqual { .. } | Pred::NotEqual { .. },
Expand Down Expand Up @@ -1645,6 +1744,34 @@ impl Context {
.map(|ord| ord.canbe_eq())
.unwrap_or(false)
}
(
Pred::GeneralEqual { lhs, rhs },
Pred::GeneralEqual {
lhs: lhs2,
rhs: rhs2,
},
)
| (
Pred::GeneralGreaterEqual { lhs, rhs },
Pred::GeneralGreaterEqual {
lhs: lhs2,
rhs: rhs2,
},
)
| (
Pred::GeneralLessEqual { lhs, rhs },
Pred::GeneralLessEqual {
lhs: lhs2,
rhs: rhs2,
},
)
| (
Pred::GeneralNotEqual { lhs, rhs },
Pred::GeneralNotEqual {
lhs: lhs2,
rhs: rhs2,
},
) => self.is_super_pred_of(lhs, lhs2) && self.is_super_pred_of(rhs, rhs2),
// {T >= 0} :> {T >= 1}, {T >= 0} :> {T == 1}
(
Pred::GreaterEqual { rhs, .. },
Expand Down Expand Up @@ -1693,6 +1820,37 @@ impl Context {
}
true
}
(
Pred::Call {
receiver,
name,
args,
},
Pred::Call {
receiver: sub_receiver,
name: name2,
args: args2,
},
) => {
self.supertype_of_tp(receiver, sub_receiver, Variance::Covariant)
&& name == name2
&& args.len() == args2.len()
&& args
.iter()
.zip(args2.iter())
.all(|(l, r)| self.supertype_of_tp(l, r, Variance::Covariant))
}
(pred @ Predicate::Call { .. }, Predicate::Value(ValueObj::Bool(b))) => {
if let Ok(Predicate::Value(ValueObj::Bool(evaled))) = self.eval_pred(pred.clone()) {
b == &evaled
} else {
false
}
}
(Pred::Value(ValueObj::Bool(b)), _) => *b,
(_, Pred::Value(ValueObj::Bool(b))) => !b,
(Pred::LessEqual { rhs, .. }, _) if !rhs.has_upper_bound() => true,
(Pred::GreaterEqual { rhs, .. }, _) if !rhs.has_lower_bound() => true,
(lhs, Pred::And(l, r)) => {
self.is_super_pred_of(lhs, l) || self.is_super_pred_of(lhs, r)
}
Expand Down
Loading

0 comments on commit 107e5a0

Please sign in to comment.