From 4e13bb2ffc57fdf73add4f6dfc604a4de136252e Mon Sep 17 00:00:00 2001 From: Logan Date: Thu, 24 Oct 2024 14:40:41 -0500 Subject: [PATCH] finished var and func numbering --- demo.hal | 3 + demo.lang | 2 - src/frontend.rs | 8 +- src/ir.rs | 58 +++++- src/main.rs | 30 ++- src/parse/expression.rs | 34 ++-- src/parse/statement.rs | 26 ++- src/semantic/analyzer.rs | 356 ++++++++++++++++++++++++++++++++++ src/semantic/mod.rs | 159 +++++++++++++--- src/semantic/types.rs | 400 ++------------------------------------- 10 files changed, 632 insertions(+), 444 deletions(-) create mode 100644 demo.hal delete mode 100644 demo.lang create mode 100644 src/semantic/analyzer.rs diff --git a/demo.hal b/demo.hal new file mode 100644 index 0000000..d51f4dd --- /dev/null +++ b/demo.hal @@ -0,0 +1,3 @@ +bar :: (a: integer) { + local := a; +} diff --git a/demo.lang b/demo.lang deleted file mode 100644 index 41054f6..0000000 --- a/demo.lang +++ /dev/null @@ -1,2 +0,0 @@ -a : i32 = 10; -b := 20 + a; diff --git a/src/frontend.rs b/src/frontend.rs index f556e24..7e3bd0c 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -1,6 +1,10 @@ use std::path::Path; -use crate::{err::*, semantic::typecheck, Parser, Statement, Tokenizer}; +use crate::{ + err::*, + semantic::{self}, + Parser, Statement, Tokenizer, +}; #[derive(Debug, Clone)] pub struct Module { @@ -26,7 +30,7 @@ impl Module { pub fn from_string(file_name: String, source: String) -> Self { let tokens = Tokenizer::new(source.chars()).filter(|t| t.0.is_meaningful()); let statements = Parser::new(tokens); - let program = typecheck(statements.collect()); + let program = semantic::Analyzer::typecheck(statements.collect()); let mut errors = vec![]; Self { file_name, diff --git a/src/ir.rs b/src/ir.rs index a1f212c..aed22fe 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -1,17 +1,57 @@ -use crate::{semantic::Type, BinaryOp, Immediate, UnaryOp}; +use crate::{ + BinaryOp, Expression, ExpressionKind, Immediate, UnaryOp, + semantic::{Type, VarKind, uid}, +}; #[derive(Debug, Clone)] pub enum IR { BinOp { op: BinaryOp, type_: Type }, UnOp { op: UnaryOp, type_: Type }, Imm(Immediate), - NewLocal { uid: usize, type_: Type }, - AssignLocal { uid: usize }, - AccessLocal { uid: usize }, - NewGlobal { uid: usize, type_: Type }, - AssignGlobal { uid: usize }, - AccessGlobal { uid: usize }, - StartFunc { uid: usize }, - NewParam { uid: usize, type_: Type }, + NewLocal { uid: uid, type_: Type }, + AssignLocal { uid: uid }, + GetLocal { uid: uid }, + NewGlobal { uid: uid, type_: Type }, + AssignGlobal { uid: uid }, + GetGlobal { uid: uid }, + StartFunc { uid: uid }, + NewParam { uid: uid, type_: Type }, EndFunc, } + +pub struct Compiler { + ir: Vec, +} + +impl Compiler { + fn expression(&mut self, expression: Expression) { + use ExpressionKind::*; + match expression.kind { + Immediate(immediate) => { + self.ir.push(IR::Imm(immediate)); + }, + Identifier(name, var_kind) => match var_kind { + VarKind::Global(uid) => self.ir.push(IR::GetGlobal { uid }), + VarKind::Local(uid) | VarKind::Param(uid) => { + self.ir.push(IR::GetLocal { uid }) + }, + VarKind::Function(_) => todo!(), + VarKind::Undefined => todo!(), + }, + Binary { op, left, right } => todo!(), + Unary { op, child } => todo!(), + Parenthesis(expression) => todo!(), + FunctionDef { + params, + returns_str, + returns_actual, + body, + id, + } => todo!(), + FunctionCall { callee, args } => todo!(), + StructDef(vec) => todo!(), + StructLiteral { name, args } => todo!(), + Field { namespace, field } => todo!(), + } + } +} diff --git a/src/main.rs b/src/main.rs index 7c0c231..ad802b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -62,10 +62,36 @@ fn test_expression(expr: &str) { println!("{:?}", parser.expression(0).unwrap()); } +fn prints(st: &Statement) { + use StatementKind as s; + println!("{st:?}"); + match &st.kind { + s::Declaration { + value: + Expression { + kind: ExpressionKind::FunctionDef { body: block, .. }, + .. + }, + .. + } + | s::If { block, .. } + | s::While { block, .. } + | s::Block(block) => { + for s in block { + prints(s); + } + }, + _ => {}, + }; +} + fn main() -> Result<()> { - let module = frontend::Module::from_file("./demo.lang")?; + test_expression("asdf.asdf()"); + /* + let module = frontend::Module::from_file("./demo.hal")?; for s in &module.program { - println!("{s:?}"); + prints(s); } + */ Ok(()) } diff --git a/src/parse/expression.rs b/src/parse/expression.rs index 39ab44c..a5b07d4 100644 --- a/src/parse/expression.rs +++ b/src/parse/expression.rs @@ -26,7 +26,7 @@ pub enum Immediate { #[derive(Clone)] pub enum ExpressionKind { Immediate(Immediate), - Identifier(String), + Identifier(String, VarKind), Binary { op: BinaryOp, left: Box, @@ -37,21 +37,22 @@ pub enum ExpressionKind { child: Box, }, Parenthesis(Box), - Function { + FunctionDef { params: Vec, returns_str: Option, returns_actual: Type, body: Vec, + id: uid, }, - Struct(Vec), + FunctionCall { + callee: Box, + args: Vec, + }, + StructDef(Vec), StructLiteral { name: String, args: Vec<(String, Expression)>, }, - Call { - callee: Box, - args: Vec, - }, Field { namespace: Box, field: Box, @@ -97,19 +98,21 @@ impl std::fmt::Debug for ExpressionKind { e::Unary { op: token, child } => { write!(f, "({token:?} {child:?})") }, - e::Identifier(i) => write!(f, "{i}"), - e::Call { callee, args } => write!(f, "({callee:?} call {args:?})"), + e::Identifier(i, _) => write!(f, "{i}"), + e::FunctionCall { callee, args } => { + write!(f, "({callee:?} call {args:?})") + }, e::Field { namespace, field } => { write!(f, "({namespace:?} . {field:?})") }, - e::Function { + e::FunctionDef { params, returns_actual, .. } => { write!(f, "(fn({params:?}) -> {returns_actual:?})") }, - e::Struct(params) => write!(f, "struct {{ {params:?} }}"), + e::StructDef(params) => write!(f, "struct {{ {params:?} }}"), e::StructLiteral { name, args } => write!(f, "{name} {{ {args:?} }}"), } } @@ -217,7 +220,7 @@ impl> Parser { } let Token(_, span2) = self.eat(t::RightParen).span(&span)?; current = Expression::new( - e::Call { + e::FunctionCall { callee: current.into(), args, }, @@ -326,11 +329,12 @@ impl> Parser { .block() .trace_span(span, "while parsing function body")?; span = span + span2; - e::Function { + e::FunctionDef { params, returns_str, returns_actual: Type::Ambiguous, body, + id: 0, } }, // Struct definition @@ -361,7 +365,7 @@ impl> Parser { } } self.eat(t::RightBrace)?; - e::Struct(params) + e::StructDef(params) }, // Struct literal t::Identifier(name) if self.look(1, t::LeftBrace).is_ok() => { @@ -390,7 +394,7 @@ impl> Parser { }, t::Identifier(i) => { self.skip(1); - e::Identifier(i) + e::Identifier(i, VarKind::Undefined) }, // Parenthetical t::LeftParen => { diff --git a/src/parse/statement.rs b/src/parse/statement.rs index fde53ab..7d1d060 100644 --- a/src/parse/statement.rs +++ b/src/parse/statement.rs @@ -1,4 +1,4 @@ -use crate::semantic::Type; +use crate::semantic::{Type, VarKind}; use super::*; @@ -10,10 +10,12 @@ pub enum StatementKind { type_actual: Type, value: Expression, mutable: bool, + varkind: VarKind, }, Assignment { name: String, value: Expression, + varkind: VarKind, }, If { predicate: Expression, @@ -68,13 +70,14 @@ impl> Parser { .expression(0) .trace_span(span, "while parsing declaration")?; span = span + value.span; - let no_semicolon = if let ExpressionKind::Function { .. } = value.kind { - true - } else if let ExpressionKind::Struct(_) = value.kind { - true - } else { - false - }; + let no_semicolon = + if let ExpressionKind::FunctionDef { .. } = value.kind { + true + } else if let ExpressionKind::StructDef(_) = value.kind { + true + } else { + false + }; let s = Statement { kind: s::Declaration { name, @@ -82,6 +85,7 @@ impl> Parser { type_actual: Type::Ambiguous, value, mutable, + varkind: VarKind::Undefined, }, span, }; @@ -99,7 +103,11 @@ impl> Parser { .trace_span(span, "while parsing assignment")?; Statement { span, - kind: s::Assignment { name, value }, + kind: s::Assignment { + name, + value, + varkind: VarKind::Undefined, + }, } }, // If diff --git a/src/semantic/analyzer.rs b/src/semantic/analyzer.rs new file mode 100644 index 0000000..288c9b1 --- /dev/null +++ b/src/semantic/analyzer.rs @@ -0,0 +1,356 @@ +use crate::{ + BinaryOp, Expression, ExpressionKind, Immediate, Parameter, Statement, + StatementKind, UnaryOp, + semantic::{Symbol, SymbolTable, Type, VarKind}, +}; + +use super::primitives::*; +use crate::err::*; + +pub struct Analyzer { + table: SymbolTable, +} + +impl Analyzer { + pub fn typecheck(statements: Vec) -> Vec { + let mut this = Self { + table: SymbolTable::new(), + }; + this.block(statements) + } + + fn block(&mut self, block: Vec) -> Vec { + let mut new_block = vec![]; + for st in block { + let span = st.span; + let st = match self.statement(st.into()) { + Ok(st) => *st, + Err(e) => Statement { + kind: StatementKind::Error(e), + span, + }, + }; + new_block.push(st); + } + new_block + } + + fn statement(&mut self, mut stmt: Box) -> Result> { + use Primitive as p; + use StatementKind as s; + match stmt.kind { + // Variable declaration + s::Declaration { + name, + type_str, + value, + mutable, + .. + } => { + let type_lhs = if let Some(ref s) = type_str { + self.table.get_type(s).span(&stmt.span)? + } else { + Type::Ambiguous + }; + let mut value = self.expression(value.into())?; + let type_actual = Type::coerce(&type_lhs, &value.type_) + .reason(format!( + "Expected type '{:?}', found type '{:?}'", + type_lhs, value.type_ + )) + .span(&stmt.span)?; + value.type_ = type_actual.clone(); + let varkind = self + .table + .define_symbol(name.clone(), type_actual.clone(), mutable) + .span(&stmt.span)?; + stmt.kind = s::Declaration { + name, + type_str, + type_actual, + value: *value, + mutable, + varkind, + }; + }, + // Variable assignment + s::Assignment { name, value, .. } => { + let symbol = self.table.find_symbol(&name).span(&stmt.span)?; + // Check that it is mutable + if !symbol.mutable { + return error() + .reason(format!("Cannot assign to immutable '{}'", name)) + .span(&stmt.span); + } + let mut value = *self.expression(value.into())?; + let type_actual = + Type::coerce(&symbol.type_, &value.type_).span(&stmt.span)?; + value.type_ = type_actual; + stmt.kind = s::Assignment { + name, + value, + varkind: symbol.kind, + }; + }, + s::If { + predicate, + block, + else_, + } => { + self.table.start_block(); + let predicate = *self.expression(predicate.into())?; + Type::coerce(&Type::Prim(p::boolean), &predicate.type_) + .span(&predicate.span)?; + let block = self.block(block); + let else_ = if let Some(else_) = else_ { + Some(self.statement(else_)?) + } else { + None + }; + stmt.kind = s::If { + predicate, + block, + else_, + }; + self.table.end_block(); + }, + s::While { predicate, block } => { + self.table.start_block(); + let predicate = *self.expression(predicate.into())?; + Type::coerce(&Type::Prim(p::boolean), &predicate.type_) + .span(&predicate.span)?; + let block = self.block(block); + stmt.kind = s::While { predicate, block }; + self.table.end_block(); + }, + s::Print(e) => { + stmt.kind = s::Print(*self.expression(e.into())?); + }, + s::Expression(e) => { + stmt.kind = s::Expression(*self.expression(e.into())?); + }, + s::Block(block) => { + self.table.start_block(); + let block = self.block(block); + stmt.kind = s::Block(block); + self.table.end_block(); + }, + s::Error(e) => return Err(e), + } + Ok(stmt) + } + + fn expression( + &mut self, + mut expr: Box, + ) -> Result> { + use ExpressionKind as e; + use Immediate as i; + use Primitive as p; + let type_ = match expr.kind { + e::Immediate(ref i) => Type::Prim(match i { + i::Integer(_) => p::integer_ambiguous, + i::Real(_) => p::real_ambiguous, + i::String(_) => p::string, + i::Boolean(_) => p::boolean, + }), + e::Identifier(ref i, ref mut kind) => { + let symbol = self.table.find_symbol(i)?; + *kind = symbol.kind; + self.table.find_symbol(i)?.type_ + }, + e::Binary { op, left, right } => { + let left = self.expression(left)?; + let right = self.expression(right)?; + let type_ = Type::binary_op(&left.type_, op, &right.type_)?; + expr.kind = e::Binary { left, right, op }; + type_ + }, + e::Unary { op, child } => { + let child = self.expression(child)?; + let type_ = Type::unary_op(op, &child.type_)?; + expr.kind = e::Unary { child, op }; + type_ + }, + e::Parenthesis(inner) => { + let inner = self.expression(inner)?; + let type_ = inner.type_.clone(); + expr.kind = e::Parenthesis(inner); + type_ + }, + e::FunctionDef { + mut params, + returns_str, + mut returns_actual, + body, + id, + } => { + self.table.start_func(); + for p in &mut params { + p.type_actual = self.table.get_type(&p.type_str).span(&expr.span)?; + self + .table + .define_param(p.name.clone(), p.type_actual.clone())?; + } + returns_actual = match &returns_str { + Some(s) => self.table.get_type(s).span(&expr.span)?, + None => Type::Nothing, + }; + let body = self.block(body); + self.table.end_func(); + expr.kind = e::FunctionDef { + params: params.clone(), + returns_str, + returns_actual: returns_actual.clone(), + body, + id, + }; + Type::Function { + params: params.into_iter().map(|p| p.type_actual).collect(), + returns: returns_actual.into(), + } + }, + e::FunctionCall { callee, mut args } => { + let callee = self.expression(callee)?; + // Check that this is actually a function + let Type::Function { + ref params, + ref returns, + } = callee.type_ + else { + return error() + .reason(format!("Cannot call type {:?}", callee.type_)) + .span(&callee.span); + }; + // Check for correct number of args + if params.len() != args.len() { + return error() + .reason(format!( + "Wrong number of arguments, function expects {}, found {}", + params.len(), + args.len() + )) + .span(&callee.span); + } + // Check for correct arg types + for (expect, actual) in params.iter().zip(args.iter_mut()) { + *actual = *self.expression(actual.clone().into())?; + let coerced_type = Type::coerce(expect, &actual.type_); + if let Ok(t) = coerced_type { + actual.type_ = t; + } else { + return error() + .reason(format!( + "Expected type {expect:?}, found {:?}", + actual.type_ + )) + .span(&actual.span); + } + } + let returns = *returns.clone(); + expr.kind = e::FunctionCall { callee, args }; + returns + }, + e::StructDef(mut params) => { + for p in &mut params { + p.type_actual = self.table.get_type(&p.type_str).span(&expr.span)?; + } + expr.kind = e::StructDef(params.clone()); + Type::StructDef(params) + }, + e::StructLiteral { name, args } => { + let type_ = self.table.get_type(&name).span(&expr.span)?; + let Type::Struct(params) = type_ else { + return error().reason(format!( + "Cannot construct type {:?} as struct literal", + type_ + )); + }; + if args.len() != params.len() { + return error().reason(format!( + "Incorrect number of parameters for struct '{}'; expected {}, \ + found {}", + name, + params.len(), + args.len() + )); + } + // TODO out of order params + let mut new_args = vec![]; + for ( + (argname, argexpr), + Parameter { + name: pname, + type_actual: ptype, + .. + }, + ) in args.iter().zip(params.iter()) + { + if argname != pname { + return error() + .reason(format!( + "In struct literal, expected parameter '{pname}', found \ + '{argname}'" + )) + .span(&argexpr.span); + } + let argspan = argexpr.span; + let mut arg = *self + .expression(argexpr.clone().into()) + .trace_span(expr.span, "while parsing struct literal")?; + let coerced_type = Type::coerce(ptype, &arg.type_); + if let Ok(t) = coerced_type { + arg.type_ = t; + } else { + return error() + .reason(format!( + "In struct literal, expected type '{ptype:?}', found '{:?}", + arg.type_, + )) + .span(&argspan); + } + new_args.push((argname.clone(), arg)); + } + expr.kind = e::StructLiteral { + name, + args: new_args, + }; + Type::Struct(params) + }, + e::Field { namespace, field } => { + let namespace = self.expression(namespace)?; + // Check that namespace is struct + // TODO: fields in other types + let Type::Struct(ref params) = namespace.type_ else { + return error() + .reason(format!("Type {:?} does not have fields", namespace.type_)) + .span(&namespace.span); + }; + // Check that field is identifier + // TODO: tuple fields? + let e::Identifier(ref name, _) = field.kind else { + return error() + .reason("Field must be an identifier") + .span(&field.span); + }; + let mut type_ = None; + for p in params { + if &p.name == name { + type_ = Some(p.type_actual.clone()); + break; + } + } + let type_ = type_ + .reason(format!( + "Type {:?} does not contain field {}", + namespace, name + )) + .span(&field.span)?; + expr.kind = e::Field { namespace, field }; + type_ + }, + }; + expr.type_ = type_; + Ok(expr) + } +} diff --git a/src/semantic/mod.rs b/src/semantic/mod.rs index fd64fc1..d6ae31b 100644 --- a/src/semantic/mod.rs +++ b/src/semantic/mod.rs @@ -1,39 +1,138 @@ +mod analyzer; mod primitives; mod types; -use crate::err::*; +use crate::{Parameter, err::*}; +pub use analyzer::*; pub use primitives::*; pub use types::*; +#[allow(non_camel_case_types)] +pub type uid = u32; + +// Variable and function numbering +#[derive(Clone, Copy, Debug)] +pub enum VarKind { + Global(uid), + Local(uid), + Param(uid), + Function(uid), + Undefined, +} + +impl VarKind { + pub fn unwrap(self) -> uid { + match self { + VarKind::Global(i) + | VarKind::Local(i) + | VarKind::Param(i) + | VarKind::Function(i) => i, + VarKind::Undefined => unreachable!("Failed unwrapping uid"), + } + } +} + +#[derive(Clone, Debug)] +pub struct Symbol { + pub name: String, + pub type_: Type, + pub mutable: bool, + pub kind: VarKind, +} #[derive(Debug, Clone)] -pub enum Symbol { - Var(String, Type, bool), - Type(String, Type), +pub enum Definition { + Symbol(Symbol), BlockStart, FuncStart, } +fn next(array: &mut [uid]) -> uid { + let current = array.last_mut().unwrap(); + let ret = *current; + *current += 1; + ret +} + #[derive(Debug, Clone)] pub struct SymbolTable { - syms: Vec, + syms: Vec, + nesting: usize, + local_varno: Vec, + global_varno: Vec, + funcno: Vec, } impl SymbolTable { - fn define_var(&mut self, name: String, type_: Type, mutable: bool) { - self.syms.push(Symbol::Var(name, type_, mutable)); + pub fn new() -> Self { + Self { + syms: vec![], + nesting: 0, + global_varno: vec![0], + local_varno: vec![0], + funcno: vec![0], + } } - fn define_type(&mut self, name: String, type_: Type) { - self.syms.push(Symbol::Type(name, type_)); + fn define_symbol( + &mut self, + name: String, + type_: Type, + mutable: bool, + ) -> Result { + let kind = match type_ { + Type::Prim(_) | Type::Struct(_) => { + if self.nesting == 0 { + VarKind::Global(next(&mut self.global_varno)) + } else { + VarKind::Local(next(&mut self.local_varno)) + } + }, + Type::StructDef(_) => { + if mutable { + return error().reason("Struct definition must be immutable"); + } + VarKind::Undefined + }, + Type::Function { .. } => { + if !mutable { + VarKind::Function(next(&mut self.funcno)) + } else { + return error().reason("Function declaration must be immutable"); + } + }, + _ => VarKind::Undefined, + }; + self.syms.push(Definition::Symbol(Symbol { + name, + type_, + mutable, + kind, + })); + Ok(kind) + } + + fn define_param(&mut self, name: String, type_: Type) -> Result { + let kind = VarKind::Param(next(&mut self.local_varno)); + self.syms.push(Definition::Symbol(Symbol { + name, + type_, + mutable: false, + kind, + })); + Ok(kind) } fn start_func(&mut self) { - self.syms.push(Symbol::FuncStart); + self.nesting += 1; + self.local_varno.push(0); + self.syms.push(Definition::FuncStart); } fn end_func(&mut self) { + self.nesting -= 1; + self.local_varno.pop(); while !self.syms.is_empty() { - if let Some(Symbol::FuncStart) = self.syms.pop() { + if let Some(Definition::FuncStart) = self.syms.pop() { return; } } @@ -41,34 +140,50 @@ impl SymbolTable { } fn start_block(&mut self) { - self.syms.push(Symbol::BlockStart); + self.syms.push(Definition::BlockStart); } fn end_block(&mut self) { while !self.syms.is_empty() { - if let Some(Symbol::BlockStart) = self.syms.pop() { + if let Some(Definition::BlockStart) = self.syms.pop() { return; } } unreachable!("Tried to exit global scope in symbol table") } - fn get_var(&self, name: &str) -> Result<(Type, bool)> { + fn find_symbol(&self, find_name: &str) -> Result { + let mut nesting = self.nesting; + println!("Looking for {find_name}, scope = {nesting}"); for s in self.syms.iter().rev() { - if let Symbol::Var(name2, type_, mutable) = s { - if name == name2 { - return Ok((type_.clone(), *mutable)); - } - } + match s { + Definition::Symbol(sym) + // Only search function local and global scope + if nesting == self.nesting || nesting == 0 => { + println!("{}, {:?}, {nesting}", sym.name, sym.type_); + if find_name == sym.name { + return Ok(sym.clone()); + } + }, + Definition::FuncStart => { + nesting -= 1; + }, + _ => {}, + }; } - error().reason(format!("Identifier {name} is not defined")) + error().reason(format!("Symbol '{find_name}' is not defined")) } fn get_type(&self, name: &str) -> Result { for s in self.syms.iter().rev() { - if let Symbol::Type(name2, t) = s { + if let Definition::Symbol(Symbol { + name: name2, + type_: Type::StructDef(params), + .. + }) = s + { if name == name2 { - return Ok(t.clone()); + return Ok(Type::Struct(params.clone())); } } } diff --git a/src/semantic/types.rs b/src/semantic/types.rs index 1cddc98..6308b91 100644 --- a/src/semantic/types.rs +++ b/src/semantic/types.rs @@ -1,5 +1,6 @@ use crate::{ - semantic::SymbolTable, BinaryOp, Expression, ExpressionKind, Immediate, Parameter, Statement, + semantic::{Symbol, SymbolTable}, + BinaryOp, Expression, ExpressionKind, Immediate, Parameter, Statement, StatementKind, UnaryOp, }; @@ -12,6 +13,7 @@ pub enum Type { Nothing, Prim(Primitive), Struct(Vec), + StructDef(Vec), Function { params: Vec, returns: Box, @@ -54,7 +56,7 @@ impl Type { (t::Prim(a), t::Prim(b)) => { let p = Primitive::binary_op(*a, op, *b)?; Ok(t::Prim(p)) - } + }, _ => e, } } @@ -69,392 +71,24 @@ impl Type { } } - fn coerce(expect: &Type, actual: &Type) -> Option { + pub fn coerce(expect: &Type, actual: &Type) -> Result { use Primitive as p; use Type::*; + let e = || { + error().reason(format!( + "Could not coerce type '{actual:?}' into '{expect:?}'" + )) + }; match (expect, actual) { - (Ambiguous, Ambiguous) => None, - (Ambiguous, Prim(p::integer_ambiguous)) => Some(Prim(p::integer)), - (Ambiguous, Prim(p::real_ambiguous)) => Some(Prim(p::real)), - (Ambiguous, t) => Some(t.clone()), + (Ambiguous, Ambiguous) => e(), + (Ambiguous, Prim(p::integer_ambiguous)) => Ok(Prim(p::integer)), + (Ambiguous, Prim(p::real_ambiguous)) => Ok(Prim(p::real)), + (Ambiguous, t) => Ok(t.clone()), (Prim(p1), Prim(p2)) => { let (p1, p2) = Primitive::coerce_ambiguous(*p1, *p2); - if p1 != p2 { - None - } else { - Some(Type::Prim(p1)) - } - } - _ => None, - } - } -} - -pub fn typecheck(program: Vec) -> Vec { - use StatementKind as s; - let mut table = SymbolTable { syms: vec![] }; - let mut ret = vec![]; - for s in program { - let span = s.span; - ret.push(match statement(s.into(), &mut table) { - Ok(s) => *s, - Err(e) => Statement { - kind: s::Error(e), - span, + if p1 != p2 { e() } else { Ok(Type::Prim(p1)) } }, - }) + _ => e(), + } } - ret -} - -fn statement(mut stmt: Box, table: &mut SymbolTable) -> Result> { - use Primitive as p; - use StatementKind as s; - match stmt.kind { - s::Declaration { - name, - type_str, - value, - mutable, - .. - } => { - let type_expect = match type_str { - Some(ref s) => table.get_type(s).span(&stmt.span)?, - None => Type::Ambiguous, - }; - let value = expression(value.into(), table)?; - let type_actual = Type::coerce(&type_expect, &value.type_).reason(format!( - "Expected type '{:?}', found type '{:?}'", - type_expect, value.type_ - ))?; - // Check that structs are const - if let Type::Struct(_) = type_actual { - if mutable { - return error() - .reason("Struct declarations must be immutable") - .span(&stmt.span); - } - table.define_type(name.clone(), type_actual.clone()); - } - // Check that functions are const - else if let Type::Function { .. } = type_actual { - if mutable { - return error() - .reason("Function declarations must be immutable") - .span(&stmt.span); - } - table.define_var(name.clone(), type_actual.clone(), false); - } else { - table.define_var(name.clone(), type_actual.clone(), mutable); - } - stmt.kind = s::Declaration { - name, - type_str, - type_actual, - value: *value, - mutable, - }; - } - s::Assignment { name, value } => { - let (type_, mutable) = table.get_var(&name).span(&stmt.span)?; - // Check that it is mutable - if !mutable { - return error() - .reason(format!("Cannot assign to immutable '{}'", name)) - .span(&stmt.span); - } - let value = *expression(value.into(), table)?; - // Check for correct type - if type_ != value.type_ { - return error().reason(format!( - "Attempted to assign '{:?}' to '{type_:?}'", - value.type_ - )); - } - stmt.kind = s::Assignment { name, value }; - } - s::If { - predicate, - block, - else_, - } => { - table.start_block(); - let predicate = *expression(predicate.into(), table)?; - if predicate.type_ != Type::Prim(p::boolean) { - return error() - .reason(format!( - "Predicate of if statement must be a boolean, found {:?}", - predicate.type_ - )) - .span(&predicate.span); - } - let mut new_block = vec![]; - for stmt in block { - new_block.push(*statement(stmt.into(), table)?); - } - let else_ = if let Some(else_) = else_ { - Some(statement(else_, table)?) - } else { - None - }; - stmt.kind = s::If { - predicate, - block: new_block, - else_, - }; - table.end_block(); - } - s::While { predicate, block } => { - table.start_block(); - let predicate = *expression(predicate.into(), table)?; - if predicate.type_ != Type::Prim(p::boolean) { - return error() - .reason(format!( - "Predicate of while statement must be a boolean, found {:?}", - predicate.type_ - )) - .span(&predicate.span); - } - let mut new_block = vec![]; - for stmt in block { - new_block.push(*statement(stmt.into(), table)?); - } - stmt.kind = s::While { - predicate, - block: new_block, - }; - table.end_block(); - } - s::Print(e) => { - stmt.kind = s::Print(*expression(e.into(), table)?); - } - s::Expression(mut e) => { - use ExpressionKind as e; - let is_func = if let e::Function { params, .. } = &mut e.kind { - // TODO: start/end function instead - table.start_block(); - for p in params { - p.type_actual = table.get_type(&p.type_str)?; - table.define_var(p.name.clone(), p.type_actual.clone(), false); - } - true - } else { - false - }; - stmt.kind = s::Expression(*expression(e.into(), table)?); - if is_func { - table.end_block(); - } - } - s::Block(block) => { - table.start_block(); - let mut new_block = vec![]; - for stmt in block { - new_block.push(*statement(stmt.into(), table)?); - } - stmt.kind = s::Block(new_block); - table.end_block(); - } - s::Error(e) => return Err(e), - } - Ok(stmt) -} - -fn expression(mut expr: Box, table: &SymbolTable) -> Result> { - use ExpressionKind as e; - use Immediate as i; - use Primitive as p; - let type_ = match expr.kind { - e::Immediate(ref i) => Type::Prim(match i { - i::Integer(_) => p::integer_ambiguous, - i::Real(_) => p::real_ambiguous, - i::String(_) => p::string, - i::Boolean(_) => p::boolean, - }), - e::Identifier(ref i) => table.get_var(i)?.0, - e::Binary { op, left, right } => { - let left = expression(left, table)?; - let right = expression(right, table)?; - let type_ = Type::binary_op(&left.type_, op, &right.type_)?; - expr.kind = e::Binary { left, right, op }; - type_ - } - e::Unary { op, child } => { - let child = expression(child, table)?; - let type_ = Type::unary_op(op, &child.type_)?; - expr.kind = e::Unary { child, op }; - type_ - } - e::Parenthesis(inner) => { - let inner = expression(inner, table)?; - let type_ = inner.type_.clone(); - expr.kind = e::Parenthesis(inner); - type_ - } - e::Function { - mut params, - returns_str, - mut returns_actual, - body, - } => { - for p in &mut params { - p.type_actual = table.get_type(&p.type_str).span(&expr.span)?; - } - returns_actual = match &returns_str { - Some(s) => table.get_type(s).span(&expr.span)?, - None => Type::Nothing, - }; - expr.kind = e::Function { - params: params.clone(), - returns_str, - returns_actual: returns_actual.clone(), - body, - }; - Type::Function { - params: params.into_iter().map(|p| p.type_actual).collect(), - returns: returns_actual.into(), - } - } - e::Call { callee, mut args } => { - let callee = expression(callee, table)?; - // Check that this is actually a function - let Type::Function { - ref params, - ref returns, - } = callee.type_ - else { - return error() - .reason(format!("Cannot call type {:?}", callee.type_)) - .span(&callee.span); - }; - // Check for correct number of args - if params.len() != args.len() { - return error() - .reason(format!( - "Wrong number of arguments, function expects {}, found {}", - params.len(), - args.len() - )) - .span(&callee.span); - } - // Check for correct arg types - for (expect, actual) in params.iter().zip(args.iter_mut()) { - *actual = *expression(actual.clone().into(), table)?; - let coerced_type = Type::coerce(expect, &actual.type_); - println!("{:?}, {:?}, {coerced_type:?}", expect, actual.type_); - if let Some(t) = coerced_type { - actual.type_ = t; - } else { - return error() - .reason(format!( - "Expected type {expect:?}, found {:?}", - actual.type_ - )) - .span(&actual.span); - } - } - let returns = *returns.clone(); - expr.kind = e::Call { callee, args }; - returns - } - e::Struct(mut params) => { - for p in &mut params { - p.type_actual = table.get_type(&p.type_str).span(&expr.span)?; - } - expr.kind = e::Struct(params.clone()); - Type::Struct(params) - } - e::StructLiteral { name, args } => { - let type_ = table.get_type(&name).span(&expr.span)?; - let Type::Struct(params) = type_ else { - return error().reason(format!( - "Cannot construct type {:?} as struct literal", - type_ - )); - }; - if args.len() != params.len() { - return error().reason(format!( - "Incorrect number of parameters for struct '{}'; expected {}, \ - found {}", - name, - params.len(), - args.len() - )); - } - // TODO out of order params - let mut new_args = vec![]; - for ( - (argname, argexpr), - Parameter { - name: pname, - type_actual: ptype, - .. - }, - ) in args.iter().zip(params.iter()) - { - if argname != pname { - return error() - .reason(format!( - "In struct literal, expected parameter '{pname}', found \ - '{argname}'" - )) - .span(&argexpr.span); - } - let argspan = argexpr.span; - let mut arg = *expression(argexpr.clone().into(), table) - .trace_span(expr.span, "while parsing struct literal")?; - let coerced_type = Type::coerce(ptype, &arg.type_); - if let Some(t) = coerced_type { - arg.type_ = t; - } else { - return error() - .reason(format!( - "In struct literal, expected type '{ptype:?}', found '{:?}", - arg.type_, - )) - .span(&argspan); - } - new_args.push((argname.clone(), arg)); - } - expr.kind = e::StructLiteral { - name, - args: new_args, - }; - Type::Struct(params) - } - e::Field { namespace, field } => { - let namespace = expression(namespace, table)?; - // Check that namespace is struct - // TODO: fields in other types - let Type::Struct(ref params) = namespace.type_ else { - return error() - .reason(format!("Type {:?} does not have fields", namespace.type_)) - .span(&namespace.span); - }; - // Check that field is identifier - // TODO: tuple fields? - let e::Identifier(ref name) = field.kind else { - return error() - .reason("Field must be an identifier") - .span(&field.span); - }; - let mut type_ = None; - for p in params { - if &p.name == name { - type_ = Some(p.type_actual.clone()); - break; - } - } - let type_ = type_ - .reason(format!( - "Type {:?} does not contain field {}", - namespace, name - )) - .span(&field.span)?; - expr.kind = e::Field { namespace, field }; - type_ - } - }; - expr.type_ = type_; - Ok(expr) }