Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ast refactor #140

Merged
merged 7 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/concrete_ast/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,16 @@ pub enum BitwiseOp {

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MatchExpr {
pub value: Box<Expression>,
pub match_expr: Box<Expression>,
kenarab marked this conversation as resolved.
Show resolved Hide resolved
pub variants: Vec<MatchVariant>,
pub span: Span,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct IfExpr {
pub value: Box<Expression>,
pub contents: Vec<Statement>,
pub r#else: Option<Vec<Statement>>,
pub cond: Box<Expression>,
pub block_stmts: Vec<Statement>,
pub else_stmts: Option<Vec<Statement>>,
pub span: Span,
}

Expand Down
20 changes: 10 additions & 10 deletions crates/concrete_ast/src/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ pub enum Statement {

#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LetStmtTarget {
Simple { name: Ident, r#type: TypeSpec },
Simple { id: Ident, r#type: TypeSpec },
Destructure(Vec<Binding>),
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LetStmt {
pub is_mutable: bool,
pub target: LetStmtTarget,
kenarab marked this conversation as resolved.
Show resolved Hide resolved
pub value: Expression,
pub lvalue: LetStmtTarget,
pub rvalue: Expression,
pub span: Span,
}

Expand All @@ -38,30 +38,30 @@ pub struct ReturnStmt {

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AssignStmt {
pub target: PathOp,
pub lvalue: PathOp,
pub derefs: usize,
pub value: Expression,
pub rvalue: Expression,
pub span: Span,
}

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Binding {
pub name: Ident,
pub id: Ident,
pub rename: Option<Ident>,
pub r#type: TypeSpec,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ForStmt {
pub init: Option<LetStmt>,
pub condition: Option<Expression>,
kenarab marked this conversation as resolved.
Show resolved Hide resolved
pub cond: Option<Expression>,
pub post: Option<AssignStmt>,
pub contents: Vec<Statement>,
pub block_stmts: Vec<Statement>,
pub span: Span,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WhileStmt {
pub value: Expression,
pub contents: Vec<Statement>,
pub cond: Expression,
pub block_stmts: Vec<Statement>,
}
56 changes: 29 additions & 27 deletions crates/concrete_check/src/linearity_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,14 +324,14 @@ impl LinearityChecker {
match statement {
Statement::Let(binding) => {
// Handle let bindings, possibly involving pattern matching
self.count_in_expression(name, &binding.value)
self.count_in_expression(name, &binding.rvalue)
}
Statement::If(if_stmt) => {
// Process all components of an if expression
let cond_apps = self.count_in_expression(name, &if_stmt.value);
let then_apps = self.count_in_statements(name, &if_stmt.contents);
let cond_apps = self.count_in_expression(name, &if_stmt.cond);
let then_apps = self.count_in_statements(name, &if_stmt.block_stmts);
let else_apps;
let else_statements = &if_stmt.r#else;
let else_statements = &if_stmt.else_stmts;
if let Some(else_statements) = else_statements {
else_apps = self.count_in_statements(name, else_statements);
} else {
Expand All @@ -340,8 +340,8 @@ impl LinearityChecker {
cond_apps.merge(&then_apps).merge(&else_apps)
}
Statement::While(while_expr) => {
let cond = &while_expr.value;
let block = &while_expr.contents;
let cond = &while_expr.cond;
let block = &while_expr.block_stmts;
// Handle while loops
self.count_in_expression(name, cond)
.merge(&self.count_in_statements(name, block))
Expand All @@ -350,9 +350,9 @@ impl LinearityChecker {
// Handle for loops
//init, cond, post, block
let init = &for_expr.init;
let cond = &for_expr.condition;
let cond = &for_expr.cond;
let post = &for_expr.post;
let block = &for_expr.contents;
let block = &for_expr.block_stmts;
let mut apps = Appearances::zero();
if let Some(init) = init {
if let Some(cond) = cond {
Expand Down Expand Up @@ -396,9 +396,9 @@ impl LinearityChecker {

fn count_in_assign_statement(&self, name: &str, assign_stmt: &AssignStmt) -> Appearances {
let AssignStmt {
target,
lvalue: target,
derefs,
value,
rvalue: value,
span,
} = assign_stmt;
// Handle assignments
Expand All @@ -422,8 +422,8 @@ impl LinearityChecker {
fn count_in_let_statements(&self, name: &str, let_stmt: &LetStmt) -> Appearances {
let LetStmt {
is_mutable,
target,
value,
lvalue: target,
rvalue: value,
span,
} = let_stmt;
self.count_in_expression(name, value)
Expand Down Expand Up @@ -466,10 +466,10 @@ impl LinearityChecker {
Expression::If(if_expr) => {
// Process all components of an if expression
// TODO review this code. If expressions should be processed counting both branches and comparing them
let cond_apps = self.count_in_expression(name, &if_expr.value);
let then_apps = self.count_in_statements(name, &if_expr.contents);
let cond_apps = self.count_in_expression(name, &if_expr.cond);
let then_apps = self.count_in_statements(name, &if_expr.block_stmts);
cond_apps.merge(&then_apps);
if let Some(else_block) = &if_expr.r#else {
if let Some(else_block) = &if_expr.else_stmts {
let else_apps = self.count_in_statements(name, else_block);
cond_apps.merge(&then_apps).merge(&else_apps);
}
Expand Down Expand Up @@ -524,12 +524,12 @@ impl LinearityChecker {
// Handle let bindings, possibly involving pattern matching
let LetStmt {
is_mutable,
target,
value,
lvalue: target,
rvalue: value,
span,
} = binding;
match target {
LetStmtTarget::Simple { name, r#type } => {
LetStmtTarget::Simple { id: name, r#type } => {
match r#type {
TypeSpec::Simple {
name: variable_type,
Expand Down Expand Up @@ -702,19 +702,20 @@ impl LinearityChecker {
//Statement::If(cond, then_block, else_block) => {
Statement::If(if_stmt) => {
// Handle conditional statements
state_tbl = self.check_expr(state_tbl, depth, &if_stmt.value, context)?;
state_tbl = self.check_stmts(state_tbl, depth + 1, &if_stmt.contents, context)?;
if let Some(else_block) = &if_stmt.r#else {
state_tbl = self.check_expr(state_tbl, depth, &if_stmt.cond, context)?;
state_tbl =
self.check_stmts(state_tbl, depth + 1, &if_stmt.block_stmts, context)?;
if let Some(else_block) = &if_stmt.else_stmts {
state_tbl = self.check_stmts(state_tbl, depth + 1, else_block, context)?;
}
Ok(state_tbl)
}
//Statement::While(cond, block) => {
Statement::While(while_stmt) => {
// Handle while loops
state_tbl = self.check_expr(state_tbl, depth, &while_stmt.value, context)?;
state_tbl = self.check_expr(state_tbl, depth, &while_stmt.cond, context)?;
state_tbl =
self.check_stmts(state_tbl, depth + 1, &while_stmt.contents, context)?;
self.check_stmts(state_tbl, depth + 1, &while_stmt.block_stmts, context)?;
Ok(state_tbl)
}
//Statement::For(init, cond, post, block) => {
Expand All @@ -723,22 +724,23 @@ impl LinearityChecker {
if let Some(init) = &for_stmt.init {
state_tbl = self.check_stmt_let(state_tbl, depth, init, context)?;
}
if let Some(condition) = &for_stmt.condition {
if let Some(condition) = &for_stmt.cond {
state_tbl = self.check_expr(state_tbl, depth, condition, context)?;
}
if let Some(post) = &for_stmt.post {
//TODO check assign statement
//self.check_stmt_assign(depth, post)?;
}
state_tbl = self.check_stmts(state_tbl, depth + 1, &for_stmt.contents, context)?;
state_tbl =
self.check_stmts(state_tbl, depth + 1, &for_stmt.block_stmts, context)?;
Ok(state_tbl)
}
Statement::Assign(assign_stmt) => {
// Handle assignments
let AssignStmt {
target,
lvalue: target,
derefs,
value,
rvalue: value,
span,
} = assign_stmt;
tracing::debug!("Checking assignment: {:?}", assign_stmt);
Expand Down
38 changes: 19 additions & 19 deletions crates/concrete_ir/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ fn lower_func(
// Get all locals
for stmt in &func.body {
if let statements::Statement::Let(info) = stmt {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
match &info.lvalue {
LetStmtTarget::Simple { id: name, r#type } => {
let ty = lower_type(&builder.ctx, r#type, builder.local_module)?;
builder
.name_to_local
Expand All @@ -305,8 +305,8 @@ fn lower_func(
}
} else if let statements::Statement::For(info) = stmt {
if let Some(info) = &info.init {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
match &info.lvalue {
LetStmtTarget::Simple { id: name, r#type } => {
let ty = lower_type(&builder.ctx, r#type, builder.local_module)?;
builder
.name_to_local
Expand Down Expand Up @@ -435,7 +435,7 @@ fn lower_while(builder: &mut FnBodyBuilder, info: &WhileStmt) -> Result<(), Lowe
});

let (discriminator, discriminator_type, _disc_span) =
lower_expression(builder, &info.value, None)?;
lower_expression(builder, &info.cond, None)?;

let local = builder.add_temp_local(TyKind::Bool);
let place = Place {
Expand Down Expand Up @@ -463,7 +463,7 @@ fn lower_while(builder: &mut FnBodyBuilder, info: &WhileStmt) -> Result<(), Lowe
// keep idx for switch targets
let first_then_block_idx = builder.body.basic_blocks.len();

for stmt in &info.contents {
for stmt in &info.block_stmts {
lower_statement(
builder,
stmt,
Expand Down Expand Up @@ -515,7 +515,7 @@ fn lower_for(builder: &mut FnBodyBuilder, info: &ForStmt) -> Result<(), Lowering
}),
});

let (discriminator, discriminator_type, _disc_span) = if let Some(condition) = &info.condition {
let (discriminator, discriminator_type, _disc_span) = if let Some(condition) = &info.cond {
let (discriminator, discriminator_type, span) = lower_expression(builder, condition, None)?;

(discriminator, discriminator_type, Some(span))
Expand Down Expand Up @@ -560,7 +560,7 @@ fn lower_for(builder: &mut FnBodyBuilder, info: &ForStmt) -> Result<(), Lowering
// keep idx for switch targets
let first_then_block_idx = builder.body.basic_blocks.len();

for stmt in &info.contents {
for stmt in &info.block_stmts {
lower_statement(
builder,
stmt,
Expand Down Expand Up @@ -602,7 +602,7 @@ fn lower_for(builder: &mut FnBodyBuilder, info: &ForStmt) -> Result<(), Lowering

fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(), LoweringError> {
let (discriminator, discriminator_type, _disc_span) =
lower_expression(builder, &info.value, None)?;
lower_expression(builder, &info.cond, None)?;

let local = builder.add_temp_local(TyKind::Bool);
let place = Place {
Expand Down Expand Up @@ -630,7 +630,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(),
// keep idx for switch targets
let first_then_block_idx = builder.body.basic_blocks.len();

for stmt in &info.contents {
for stmt in &info.block_stmts {
lower_statement(
builder,
stmt,
Expand All @@ -654,7 +654,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(),

let first_else_block_idx = builder.body.basic_blocks.len();

if let Some(contents) = &info.r#else {
if let Some(contents) = &info.else_stmts {
for stmt in contents {
lower_statement(
builder,
Expand Down Expand Up @@ -697,11 +697,11 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(),
}

fn lower_let(builder: &mut FnBodyBuilder, info: &LetStmt) -> Result<(), LoweringError> {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
match &info.lvalue {
LetStmtTarget::Simple { id: name, r#type } => {
let ty = lower_type(&builder.ctx, r#type, builder.local_module)?;
let (rvalue, rvalue_ty, rvalue_span) =
lower_expression(builder, &info.value, Some(ty.clone()))?;
lower_expression(builder, &info.rvalue, Some(ty.clone()))?;

if ty.kind != rvalue_ty.kind {
return Err(LoweringError::UnexpectedType {
Expand Down Expand Up @@ -734,7 +734,7 @@ fn lower_let(builder: &mut FnBodyBuilder, info: &LetStmt) -> Result<(), Lowering
}

fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), LoweringError> {
let (mut place, mut ty, _path_span) = lower_path(builder, &info.target)?;
let (mut place, mut ty, _path_span) = lower_path(builder, &info.lvalue)?;

if !builder.body.locals[place.local].is_mutable() {
return Err(LoweringError::NotMutable {
Expand All @@ -749,8 +749,8 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo
TyKind::Ref(inner, is_mut) | TyKind::Ptr(inner, is_mut) => {
if matches!(is_mut, Mutability::Not) {
Err(LoweringError::BorrowNotMutable {
span: info.target.first.span,
name: info.target.first.name.clone(),
span: info.lvalue.first.span,
name: info.lvalue.first.name.clone(),
type_span: ty.span,
program_id: builder.local_module.program_id,
})?;
Expand All @@ -763,7 +763,7 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo
}

let (rvalue, rvalue_ty, rvalue_span) =
lower_expression(builder, &info.value, Some(ty.clone()))?;
lower_expression(builder, &info.rvalue, Some(ty.clone()))?;

if ty.kind != rvalue_ty.kind {
return Err(LoweringError::UnexpectedType {
Expand All @@ -775,7 +775,7 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo
}

builder.statements.push(Statement {
span: Some(info.target.first.span),
span: Some(info.lvalue.first.span),
kind: StatementKind::Assign(place, rvalue),
});

Expand Down
Loading
Loading