Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DeclaredType in TypeChecker #2008

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
29 changes: 19 additions & 10 deletions pil-analyzer/src/pil_analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use powdr_ast::parsed::asm::{
use powdr_ast::parsed::types::Type;
use powdr_ast::parsed::visitor::{AllChildren, Children};
use powdr_ast::parsed::{
self, FunctionKind, LambdaExpression, PILFile, PilStatement, SymbolCategory,
self, FunctionKind, LambdaExpression, PILFile, PilStatement, SourceReference, SymbolCategory,
TraitImplementation, TypedExpression,
};
use powdr_number::{FieldElement, GoldilocksField};
Expand All @@ -29,7 +29,7 @@ use powdr_parser_util::Error;

use crate::traits_resolver::TraitsResolver;
use crate::type_builtins::constr_function_statement_type;
use crate::type_inference::infer_types;
use crate::type_inference::{infer_types, DeclaredType};
use crate::{side_effect_checker, AnalysisDriver};

use crate::statement_processor::{Counters, PILItem, StatementProcessor};
Expand Down Expand Up @@ -303,39 +303,48 @@ impl PILAnalyzer {
)
})
.flat_map(|(name, (symbol, value))| {
let (type_scheme, expr) = match (symbol.kind, value) {
let (declared_type, expr) = match (symbol.kind, value) {
(SymbolKind::Poly(PolynomialType::Committed), Some(value)) => {
// Witness column, move its value (query function) into the expressions to be checked separately.
let type_scheme = type_from_definition(symbol, &None);

let FunctionValueDefinition::Expression(TypedExpression { e, .. }) = value
else {
panic!("Invalid value for query function")
};

let source = e.source_reference().clone();
expressions.push((e, query_type.clone().into()));

(type_scheme, None)
let declared_type = type_from_definition(symbol, &None)
.map(|ts| ts.into())
.map(|dec: DeclaredType| dec.with_source(source));
(declared_type, None)
}
(
_,
Some(FunctionValueDefinition::Expression(TypedExpression {
type_scheme,
e,
})),
) => (type_scheme.clone(), Some(e)),
) => {
let source = e.source_reference();
let declared_type = type_scheme
.clone()
.map(|ts| ts.into())
.map(|dec: DeclaredType| dec.with_source(source.clone()));
(declared_type, Some(e))
}
(_, value) => {
let type_scheme = type_from_definition(symbol, value);
let declared_type = type_from_definition(symbol, value).map(|ts| ts.into());

if let Some(FunctionValueDefinition::Array(items)) = value {
// Expect all items in the arrays to be field elements.
expressions.extend(items.children_mut().map(|e| (e, Type::Fe.into())));
}

(type_scheme, None)
(declared_type, None)
}
};
Some((name.clone(), (type_scheme, expr)))
Some((name.clone(), (declared_type, expr)))
})
.collect();
for expr in &mut self.proof_items {
Expand Down
153 changes: 109 additions & 44 deletions pil-analyzer/src/type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::{
/// Sets the generic arguments for references and the literal types in all expressions.
/// Returns the types for symbols without explicit type.
pub fn infer_types(
definitions: HashMap<String, (Option<TypeScheme>, Option<&mut Expression>)>,
definitions: HashMap<String, (Option<DeclaredType>, Option<&mut Expression>)>,
expressions: &mut [(&mut Expression, ExpectedType)],
) -> Result<Vec<(String, Type)>, Vec<Error>> {
TypeChecker::new().infer_types(definitions, expressions)
Expand Down Expand Up @@ -60,13 +60,66 @@ impl From<Type> for ExpectedType {
}
}

#[derive(Debug, Clone)]
pub struct DeclaredType {
pub source: SourceRef,
pub vars: TypeBounds,
pub ty: DeclaredTypeKind,
}

impl DeclaredType {
fn scheme(&self) -> TypeScheme {
match &self.ty {
DeclaredTypeKind::Struct(ty, _) | DeclaredTypeKind::Type(ty) => TypeScheme {
vars: self.vars.clone(),
ty: ty.clone(),
},
}
}

fn type_mut(&mut self) -> &mut Type {
match &mut self.ty {
DeclaredTypeKind::Struct(ty, _) => ty,
DeclaredTypeKind::Type(ty) => ty,
}
}

fn declared_type(&self) -> &Type {
match &self.ty {
DeclaredTypeKind::Struct(ty, _) | DeclaredTypeKind::Type(ty) => ty,
}
}

pub fn with_source(mut self, source: SourceRef) -> Self {
self.source = source;
self
}
}

#[derive(Debug, Clone)]
pub enum DeclaredTypeKind {
#[allow(dead_code)] // Remove when #1910 is merged
Struct(Type, HashMap<String, Type>),
Type(Type),
}

impl From<TypeScheme> for DeclaredType {
fn from(scheme: TypeScheme) -> Self {
Self {
source: SourceRef::unknown(),
vars: scheme.vars.clone(),
ty: DeclaredTypeKind::Type(scheme.ty.clone()),
}
}
}

struct TypeChecker {
/// Types for local variables, might contain type variables.
local_var_types: Vec<Type>,
/// Declared types for all symbols and their source references.
/// Contains the unmodified type scheme for symbols with generic types and newly
/// created type variables for symbols without declared type.
declared_types: HashMap<String, (SourceRef, TypeScheme)>,
declared_types: HashMap<String, DeclaredType>,
/// Current mapping of declared type vars to type. Reset before checking each definition.
declared_type_vars: HashMap<String, Type>,
unifier: Unifier,
Expand All @@ -89,7 +142,7 @@ impl TypeChecker {
/// returns the types for symbols without explicit type.
pub fn infer_types(
mut self,
mut definitions: HashMap<String, (Option<TypeScheme>, Option<&mut Expression>)>,
mut definitions: HashMap<String, (Option<DeclaredType>, Option<&mut Expression>)>,
expressions: &mut [(&mut Expression, ExpectedType)],
) -> Result<Vec<(String, Type)>, Vec<Error>> {
let type_var_mapping = self
Expand All @@ -100,11 +153,11 @@ impl TypeChecker {
.into_iter()
.filter(|(_, (ty, _))| ty.is_none())
.map(|(name, _)| {
let (_, mut scheme) = self.declared_types.remove(&name).unwrap();
assert!(scheme.vars.is_empty());
self.substitute(&mut scheme.ty);
assert!(scheme.ty.is_concrete_type());
(name, scheme.ty)
let mut declared_type = self.declared_types.remove(&name).unwrap();
assert!(declared_type.vars.is_empty());
self.substitute(declared_type.type_mut());
assert!(declared_type.scheme().ty.is_concrete_type());
(name, declared_type.scheme().ty)
})
.collect())
}
Expand All @@ -113,7 +166,7 @@ impl TypeChecker {
/// the type variables used by the type checker to those used in the declaration.
fn infer_types_inner(
&mut self,
definitions: &mut HashMap<String, (Option<TypeScheme>, Option<&mut Expression>)>,
definitions: &mut HashMap<String, (Option<DeclaredType>, Option<&mut Expression>)>,
expressions: &mut [(&mut Expression, ExpectedType)],
) -> Result<HashMap<String, HashMap<String, Type>>, Error> {
// TODO in order to fix type inference on recursive functions, we need to:
Expand All @@ -129,6 +182,14 @@ impl TypeChecker {
);

self.setup_declared_types(definitions);
// After we setup declared types, every definition
// related with a Struct Declaration is not nedded any more
let mut definitions: HashMap<_,_> = definitions
.iter_mut()
.filter(|(_, (ty, _))| {
!matches!(ty, Some(declared) if matches!(declared.ty, DeclaredTypeKind::Struct(_, _)))
})
.collect();

// These are the inferred types for symbols that are declared
// as type schemes. They are compared to the declared types
Expand All @@ -145,10 +206,10 @@ impl TypeChecker {
continue;
};

let (_, declared_type) = self.declared_types[&name].clone();
let declared_type = self.declared_types[&name].clone();
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
if declared_type.vars.is_empty() {
self.declared_type_vars.clear();
self.process_concrete_symbol(declared_type.ty.clone(), value)?;
self.process_concrete_symbol(declared_type.declared_type().clone(), value)?;
} else {
self.declared_type_vars = declared_type
.vars
Expand All @@ -168,13 +229,13 @@ impl TypeChecker {

// Now we check for all symbols that are not declared as a type scheme that they
// can resolve to a concrete type.
for (name, (source_ref, declared_type)) in &self.declared_types {
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
for (name, declared_type) in &self.declared_types {
if declared_type.vars.is_empty() {
// It is not a type scheme, see if we were able to derive a concrete type.
let inferred = self.type_into_substituted(declared_type.ty.clone());
let inferred = self.type_into_substituted(declared_type.declared_type().clone());
if !inferred.is_concrete_type() {
let inferred_scheme = self.to_type_scheme(inferred);
return Err(source_ref.with_error(
return Err(declared_type.source.with_error(
format!(
"Could not derive a concrete type for symbol {name}.\nInferred type scheme: {}\n",
format_type_scheme_around_name(
Expand All @@ -197,38 +258,39 @@ impl TypeChecker {
/// Fills self.declared_types and checks that declared builtins have the correct type.
fn setup_declared_types(
&mut self,
definitions: &mut HashMap<String, (Option<TypeScheme>, Option<&mut Expression>)>,
definitions: &mut HashMap<String, (Option<DeclaredType>, Option<&mut Expression>)>,
) {
// Add types from declarations. Type schemes are added without instantiating.
self.declared_types = definitions
.iter()
.map(|(name, (type_scheme, value))| {
let source_ref = value
.as_ref()
.map(|v| v.source_reference())
.cloned()
.unwrap_or_default();
// Check if it is a builtin symbol.
let ty = match (builtin_schemes().get(name), type_scheme) {
.map(|(name, (declared_type, _))| {
let declared_type = match (builtin_schemes().get(name), declared_type) {
(Some(builtin), declared) => {
if let Some(declared) = declared {
if let Some(declared_inner) = declared {
let declared_scheme = declared_inner.scheme();
assert_eq!(
builtin,
declared,
*builtin,
declared_scheme,
"Invalid type for built-in scheme. Got {} but expected {}",
format_type_scheme_around_name(name, &Some(declared.clone())),
format_type_scheme_around_name(
name,
&Some(declared_scheme.clone())
),
format_type_scheme_around_name(name, &Some(builtin.clone()))
);
};
builtin.clone()
builtin.clone().into()
}
// Store an (uninstantiated) type scheme for symbols with a declared polymorphic type.
(None, Some(type_scheme)) => type_scheme.clone(),
(None, Some(declared_type)) => declared_type.clone(),
// Store a new (unquantified) type variable for symbols without declared type.
// This forces a single concrete type for them.
(None, None) => self.unifier.new_type_var().into(),
(None, None) => {
let scheme: TypeScheme = self.unifier.new_type_var().into();
scheme.into()
}
};
(name.clone(), (source_ref, ty))
(name.clone(), declared_type)
})
.collect();

Expand All @@ -237,7 +299,7 @@ impl TypeChecker {
for (name, scheme) in builtin_schemes() {
self.declared_types
.entry(name.clone())
.or_insert_with(|| (SourceRef::unknown(), scheme.clone()));
.or_insert_with(|| Into::<DeclaredType>::into(scheme.clone()));
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
definitions.remove(name);
}
}
Expand Down Expand Up @@ -327,7 +389,7 @@ impl TypeChecker {
/// the type variable names used by the type checker to those from the declaration.
fn update_type_args(
&mut self,
definitions: &mut HashMap<String, (Option<TypeScheme>, Option<&mut Expression>)>,
definitions: &mut HashMap<String, (Option<DeclaredType>, Option<&mut Expression>)>,
expressions: &mut [(&mut Expression, ExpectedType)],
type_var_mapping: &HashMap<String, HashMap<String, Type>>,
) -> Result<(), Vec<Error>> {
Expand Down Expand Up @@ -521,9 +583,7 @@ impl TypeChecker {
source_ref,
Reference::Poly(PolynomialReference { name, type_args }),
) => {
let (ty, args) = self
.unifier
.instantiate_scheme(self.declared_types[name].1.clone());
let (ty, args) = self.instantiate_scheme_by_declared_name(name);
if let Some(requested_type_args) = type_args {
if requested_type_args.len() != args.len() {
return Err(source_ref.with_error(format!(
Expand Down Expand Up @@ -856,9 +916,8 @@ impl TypeChecker {
Pattern::Enum(source_ref, name, data) => {
// We just ignore the generic args here, storing them in the pattern
// is not helpful because the type is obvious from the value.
let (ty, _generic_args) = self
.unifier
.instantiate_scheme(self.declared_types[&name.to_string()].1.clone());
let (ty, _generic_args) =
self.instantiate_scheme_by_declared_name(&name.to_string());
let ty = type_for_reference(&ty);

match data {
Expand Down Expand Up @@ -913,18 +972,19 @@ impl TypeChecker {
inferred_types: HashMap<String, Type>,
) -> Result<HashMap<String, HashMap<String, Type>>, Error> {
inferred_types.into_iter().map(|(name, inferred_type)| {
let (source_ref, declared_type) = self.declared_types[&name].clone();
let declared_type = self.declared_types[&name].clone();
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
let inferred_type = self.type_into_substituted(inferred_type.clone());
let inferred = self.to_type_scheme(inferred_type.clone());
let declared = declared_type.clone().simplify_type_vars();
let declared = declared_type.scheme().clone().simplify_type_vars();
if inferred != declared {
return Err(source_ref.with_error(format!(
return Err(declared_type.source.with_error(format!(
"Inferred type scheme for symbol {name} does not match the declared type.\nInferred: let{}\nDeclared: let{}",
format_type_scheme_around_name(&name, &Some(inferred)),
format_type_scheme_around_name(&name, &Some(declared_type),
format_type_scheme_around_name(&name, &Some(declared),
))));
}
let declared_type_vars = declared_type.ty.contained_type_vars();
let declared_ty = declared_type.scheme().ty;
let declared_type_vars = declared_ty.contained_type_vars();
let inferred_type_vars = inferred_type.contained_type_vars();
Ok((name.clone(),
inferred_type_vars
Expand Down Expand Up @@ -977,6 +1037,11 @@ impl TypeChecker {
pub fn local_var_type(&self, id: u64) -> Type {
self.local_var_types[id as usize].clone()
}

fn instantiate_scheme_by_declared_name(&mut self, name: &str) -> (Type, Vec<Type>) {
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
self.unifier
.instantiate_scheme(self.declared_types[&name.to_string()].scheme().clone())
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
}
}

fn update_type_if_literal(
Expand Down
Loading