diff --git a/Cargo.lock b/Cargo.lock index c39cb7752b..463dd7e585 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -82,6 +82,12 @@ dependencies = [ "alloc-no-stdlib", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android-activity" version = "0.5.2" @@ -737,6 +743,19 @@ dependencies = [ "windows-link", ] +[[package]] +name = "chumsky" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14377e276b2c8300513dff55ba4cc4142b44e5d6de6d00eb5b2307d650bb4ec1" +dependencies = [ + "hashbrown 0.15.2", + "regex-automata 0.3.9", + "serde", + "unicode-ident", + "unicode-segmentation", +] + [[package]] name = "ciborium" version = "0.2.2" @@ -2418,6 +2437,8 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" dependencies = [ + "allocator-api2", + "equivalent", "foldhash", ] @@ -3225,7 +3246,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -3339,11 +3360,10 @@ checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" name = "math-parser" version = "0.0.0" dependencies = [ + "chumsky", "criterion", "lazy_static", "num-complex", - "pest", - "pest_derive", "thiserror 2.0.12", ] @@ -4134,51 +4154,6 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" -[[package]] -name = "pest" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" -dependencies = [ - "memchr", - "thiserror 2.0.12", - "ucd-trie", -] - -[[package]] -name = "pest_derive" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816518421cfc6887a0d62bf441b6ffb4536fcc926395a69e1a85852d4363f57e" -dependencies = [ - "pest", - "pest_generator", -] - -[[package]] -name = "pest_generator" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d1396fd3a870fc7838768d171b4616d5c91f6cc25e377b673d714567d99377b" -dependencies = [ - "pest", - "pest_meta", - "proc-macro2", - "quote", - "syn 2.0.99", -] - -[[package]] -name = "pest_meta" -version = "2.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1e58089ea25d717bfd31fb534e4f3afcc2cc569c70de3e239778991ea3b7dea" -dependencies = [ - "once_cell", - "pest", - "sha2", -] - [[package]] name = "petgraph" version = "0.6.5" @@ -4942,8 +4917,19 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.7.5", ] [[package]] @@ -4954,9 +4940,15 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.5", ] +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + [[package]] name = "regex-syntax" version = "0.8.5" @@ -6670,12 +6662,6 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" -[[package]] -name = "ucd-trie" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" - [[package]] name = "unic-char-property" version = "0.9.0" diff --git a/frontend/wasm/src/editor_api.rs b/frontend/wasm/src/editor_api.rs index 7284cc3fce..8834cb5ab6 100644 --- a/frontend/wasm/src/editor_api.rs +++ b/frontend/wasm/src/editor_api.rs @@ -964,7 +964,6 @@ pub fn evaluate_math_expression(expression: &str) -> Option { let value = math_parser::evaluate(expression) .inspect_err(|err| error!("Math parser error on \"{expression}\": {err}")) .ok()? - .0 .inspect_err(|err| error!("Math evaluate error on \"{expression}\": {err} ")) .ok()?; let Some(real) = value.as_real() else { diff --git a/libraries/math-parser/Cargo.toml b/libraries/math-parser/Cargo.toml index b84da885f6..efedc5ede0 100644 --- a/libraries/math-parser/Cargo.toml +++ b/libraries/math-parser/Cargo.toml @@ -8,11 +8,10 @@ description = "Parser for Graphite style mathematics expressions" license = "MIT OR Apache-2.0" [dependencies] -pest = "2.7" -pest_derive = "2.7.11" thiserror = "2.0" lazy_static = "1.5" num-complex = "0.4" +chumsky = { version = "0.10", default-features = false, features = ["std"] } [dev-dependencies] criterion = "0.5" diff --git a/libraries/math-parser/benches/bench.rs b/libraries/math-parser/benches/bench.rs index fd1824c9a0..fc27144491 100644 --- a/libraries/math-parser/benches/bench.rs +++ b/libraries/math-parser/benches/bench.rs @@ -16,7 +16,7 @@ macro_rules! generate_benchmarks { fn evaluation_bench(c: &mut Criterion) { $( - let expr = ast::Node::try_parse_from_str($input).unwrap().0; + let expr = ast::Node::try_parse_from_str($input).unwrap(); let context = EvalContext::default(); c.bench_function(concat!("eval ", $input), |b| { diff --git a/libraries/math-parser/src/ast.rs b/libraries/math-parser/src/ast.rs index 4c42fc4b1d..051db8466a 100644 --- a/libraries/math-parser/src/ast.rs +++ b/libraries/math-parser/src/ast.rs @@ -37,7 +37,7 @@ impl Unit { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum Literal { Float(f64), Complex(Complex), @@ -56,6 +56,11 @@ pub enum BinaryOp { Mul, Div, Pow, + Leq, + Lt, + Geq, + Gt, + Eq, } #[derive(Debug, PartialEq, Clone, Copy)] @@ -72,4 +77,5 @@ pub enum Node { FnCall { name: String, expr: Vec }, BinOp { lhs: Box, op: BinaryOp, rhs: Box }, UnaryOp { expr: Box, op: UnaryOp }, + Conditional { condition: Box, if_block: Box, else_block: Box }, } diff --git a/libraries/math-parser/src/constants.rs b/libraries/math-parser/src/constants.rs index c010d13253..cdf3b2439e 100644 --- a/libraries/math-parser/src/constants.rs +++ b/libraries/math-parser/src/constants.rs @@ -2,13 +2,21 @@ use crate::value::{Number, Value}; use lazy_static::lazy_static; use num_complex::{Complex, ComplexFloat}; use std::collections::HashMap; -use std::f64::consts::PI; +use std::f64::consts::{LN_2, PI}; type FunctionImplementation = Box Option + Send + Sync>; lazy_static! { pub static ref DEFAULT_FUNCTIONS: HashMap<&'static str, FunctionImplementation> = { let mut map: HashMap<&'static str, FunctionImplementation> = HashMap::new(); + map.insert( + "sqrt", + Box::new(|values| match values{ + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sqrt()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sqrt()))), + _ => None, + }) + ); map.insert( "sin", Box::new(|values| match values { @@ -116,6 +124,227 @@ lazy_static! { _ => None, }), ); + // Hyperbolic Functions + map.insert( + "sinh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sinh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sinh()))), + _ => None, + }), + ); + + map.insert( + "cosh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.cosh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.cosh()))), + _ => None, + }), + ); + + map.insert( + "tanh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.tanh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.tanh()))), + _ => None, + }), + ); + + // Inverse Hyperbolic Functions + map.insert( + "asinh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.asinh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.asinh()))), + _ => None, + }), + ); + + map.insert( + "acosh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.acosh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.acosh()))), + _ => None, + }), + ); + + map.insert( + "atanh", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.atanh()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.atanh()))), + _ => None, + }), + ); + + // Logarithm Functions + map.insert( + "ln", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.ln()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.ln()))), + _ => None, + }), + ); + + map.insert( + "log", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.log10()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.log10()))), + [Value::Number(n), Value::Number(base)] => { + // Custom base logarithm using change of base formula + let compute_log = |x: f64, b: f64| -> f64 { x.ln() / b.ln() }; + match (n, base) { + (Number::Real(x), Number::Real(b)) => Some(Value::Number(Number::Real(compute_log(*x, *b)))), + _ => None, + } + } + _ => None, + }), + ); + + map.insert( + "log2", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.log2()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex / LN_2))), + _ => None, + }), + ); + + // Root Functions + map.insert( + "sqrt", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sqrt()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sqrt()))), + _ => None, + }), + ); + + map.insert( + "cbrt", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.cbrt()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.powf(1.0/3.0)))), + _ => None, + }), + ); + + // Geometry Functions + map.insert( + "hypot", + Box::new(|values| match values { + [Value::Number(Number::Real(a)), Value::Number(Number::Real(b))] => { + Some(Value::Number(Number::Real(a.hypot(*b)))) + }, + _ => None, + }), + ); + + // Mapping Functions + map.insert( + "abs", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.abs()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Real(complex.abs()))), + _ => None, + }), + ); + + map.insert( + "floor", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.floor()))), + _ => None, + }), + ); + + map.insert( + "ceil", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.ceil()))), + _ => None, + }), + ); + + map.insert( + "round", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.round()))), + _ => None, + }), + ); + + map.insert( + "clamp", + Box::new(|values| match values { + [Value::Number(Number::Real(x)), Value::Number(Number::Real(min)), Value::Number(Number::Real(max))] => { + Some(Value::Number(Number::Real(x.clamp(*min, *max)))) + }, + _ => None, + }), + ); + + map.insert( + "lerp", + Box::new(|values| match values { + [Value::Number(Number::Real(a)), Value::Number(Number::Real(b)), Value::Number(Number::Real(t))] => { + Some(Value::Number(Number::Real(a + (b - a) * t))) + }, + _ => None, + }), + ); + + // Complex Number Functions + map.insert( + "real", + Box::new(|values| match values { + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Real(complex.re))), + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(*real))), + _ => None, + }), + ); + + map.insert( + "imag", + Box::new(|values| match values { + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Real(complex.im))), + [Value::Number(Number::Real(_))] => Some(Value::Number(Number::Real(0.0))), + _ => None, + }), + ); + + // Logical Functions + map.insert( + "isnan", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(if real.is_nan() { 1.0 } else { 0.0 }))), + _ => None, + }), + ); + + map.insert( + "eq", + Box::new(|values| match values { + [Value::Number(a), Value::Number(b)] => Some(Value::Number(Number::Real(if a == b { 1.0 } else { 0.0 }))), + _ => None, + }), + ); + + map.insert( + "greater", + Box::new(|values| match values { + [Value::Number(Number::Real(a)), Value::Number(Number::Real(b))] => { + Some(Value::Number(Number::Real(if a > b { 1.0 } else { 0.0 }))) + }, + _ => None, + }), + ); map }; diff --git a/libraries/math-parser/src/executer.rs b/libraries/math-parser/src/executer.rs index 9d6180f1ab..85871d769c 100644 --- a/libraries/math-parser/src/executer.rs +++ b/libraries/math-parser/src/executer.rs @@ -2,6 +2,7 @@ use crate::ast::{Literal, Node}; use crate::constants::DEFAULT_FUNCTIONS; use crate::context::{EvalContext, FunctionProvider, ValueProvider}; use crate::value::{Number, Value}; +use num_complex::Complex; use thiserror::Error; #[derive(Debug, Error)] @@ -24,7 +25,7 @@ impl Node { }, Node::BinOp { lhs, op, rhs } => match (lhs.eval(context)?, rhs.eval(context)?) { - (Value::Number(lhs), Value::Number(rhs)) => Ok(Value::Number(lhs.binary_op(*op, rhs))), + (Value::Number(lhs), Value::Number(rhs)) => Ok(Value::Number(lhs.binary_op(*op, rhs).ok_or(EvalError::TypeError)?)), }, Node::UnaryOp { expr, op } => match expr.eval(context)? { Value::Number(num) => Ok(Value::Number(num.unary_op(*op))), @@ -40,6 +41,14 @@ impl Node { context.get_value(name).ok_or_else(|| EvalError::MissingFunction(name.to_string())) } } + Node::Conditional { condition, if_block, else_block } => { + let condition = match condition.eval(context)? { + Value::Number(Number::Real(number)) => number != 0.0, + Value::Number(Number::Complex(number)) => number != Complex::ZERO, + }; + + if condition { if_block.eval(context) } else { else_block.eval(context) } + } } } } diff --git a/libraries/math-parser/src/grammer.pest b/libraries/math-parser/src/grammer.pest deleted file mode 100644 index d7a61939df..0000000000 --- a/libraries/math-parser/src/grammer.pest +++ /dev/null @@ -1,60 +0,0 @@ -WHITESPACE = _{ " " | "\t" } - -// TODO: Proper indentation and formatting -program = _{ SOI ~ expr ~ EOI } - -expr = { atom ~ (infix ~ atom)* } -atom = _{ prefix? ~ primary ~ postfix? } -infix = _{ add | sub | mul | div | pow | paren } -add = { "+" } // Addition -sub = { "-" } // Subtraction -mul = { "*" } // Multiplication -div = { "/" } // Division -mod = { "%" } // Modulo -pow = { "^" } // Exponentiation -paren = { "" } // Implicit multiplication operator - -prefix = _{ neg | sqrt } -neg = { "-" } // Negation -sqrt = { "sqrt" } - -postfix = _{ fac } -fac = { "!" } // Factorial - -primary = _{ ("(" ~ expr ~ ")") | lit | constant | fn_call | ident } -fn_call = { ident ~ "(" ~ expr ~ ("," ~ expr)* ~ ")" } -ident = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } -lit = { unit | ((float | int) ~ unit?) } - -float = @{ int ~ "." ~ int? ~ exp? | int ~ exp } -exp = _{ ^"e" ~ ("+" | "-")? ~ int } -int = @{ ASCII_DIGIT+ } - -unit = ${ (scale ~ base_unit) | base_unit ~ !ident} -base_unit = _{ meter | second | gram } -meter = { "m" } -second = { "s" } -gram = { "g" } - -scale = _{ nano | micro | milli | centi | deci | deca | hecto | kilo | mega | giga | tera } -nano = { "n" } -micro = { "µ" | "u" } -milli = { "m" } -centi = { "c" } -deci = { "d" } -deca = { "da" } -hecto = { "h" } -kilo = { "k" } -mega = { "M" } -giga = { "G" } -tera = { "T" } - -// Constants -constant = { infinity | imaginary_unit | pi | tau | euler_number | golden_ratio | gravity_acceleration } -infinity = { "inf" | "INF" | "infinity" | "INFINITY" | "∞" } -imaginary_unit = { "i" | "I" } -pi = { "pi" | "PI" | "π" } -tau = { "tau" | "TAU" | "τ" } -euler_number = { "e" } -golden_ratio = { "phi" | "PHI" | "φ" } -gravity_acceleration = { "G" } diff --git a/libraries/math-parser/src/lexer.rs b/libraries/math-parser/src/lexer.rs new file mode 100644 index 0000000000..9bce2a5656 --- /dev/null +++ b/libraries/math-parser/src/lexer.rs @@ -0,0 +1,253 @@ +// ── lexer.rs ─────────────────────────────────────────────────────────── +use crate::ast::Literal; +use chumsky::input::{Input, ValueInput}; +use chumsky::prelude::*; +use chumsky::span::SimpleSpan; +use chumsky::text::{ident, int}; +use core::f64; +use num_complex::Complex64; +use std::iter::Peekable; +use std::ops::Range; +use std::str::Chars; + +pub type Span = SimpleSpan; + +#[derive(Clone, Debug, PartialEq)] +pub enum Token<'src> { + // literals ---------------------------------------------------------------- + Const(Literal), // numeric or complex constants recognised at lex‑time + Ident(&'src str), + // punctuation ------------------------------------------------------------- + LParen, + RParen, + Comma, + Plus, + Minus, + Star, + Slash, + Caret, + // comparison -------------------------------------------------------------- + Lt, + Le, + Gt, + Ge, + EqEq, + // keywords ---------------------------------------------------------------- + If, +} + +fn const_lit(name: &str) -> Option { + use std::f64::consts::*; + + Some(match name { + "pi" | "π" => Literal::Float(PI), + "tau" | "τ" => Literal::Float(TAU), + "e" => Literal::Float(E), + "phi" | "φ" => Literal::Float(1.618_033_988_75), + "inf" | "∞" => Literal::Float(f64::INFINITY), + "i" => Literal::Complex(Complex64::new(0.0, 1.0)), + "G" => Literal::Float(9.80665), + _ => return None, + }) +} + +pub struct Lexer<'a> { + input: &'a str, + pos: usize, +} + +impl<'a> Lexer<'a> { + pub fn new(input: &'a str) -> Self { + Self { input, pos: 0 } + } + + fn peek(&self) -> Option { + self.input[self.pos..].chars().next() + } + + fn bump(&mut self) -> Option { + let c = self.peek()?; + self.pos += c.len_utf8(); + Some(c) + } + + fn consume_while(&mut self, cond: F) -> &'a str + where + F: Fn(char) -> bool, + { + let start = self.pos; + while self.peek().is_some_and(&cond) { + self.bump(); + } + &self.input[start..self.pos] + } + + fn lex_ident(&mut self) -> &'a str { + self.consume_while(|c| c.is_alphanumeric() || c == '_') + } + + fn lex_uint(&mut self) -> Option<(u64, usize)> { + let mut v = 0u64; + let mut digits = 0; + while let Some(d) = self.peek().and_then(|c| c.to_digit(10)) { + v = v * 10 + d as u64; + digits += 1; + self.bump(); + } + (digits > 0).then_some((v, digits)) + } + + fn lex_number(&mut self) -> Option { + let start_pos = self.pos; + let (int_val, int_digits) = self.lex_uint().unwrap_or((0, 0)); + let mut got_digit = int_digits > 0; + let mut num = int_val as f64; + + if self.peek() == Some('.') { + self.bump(); + if let Some((frac_val, frac_digits)) = self.lex_uint() { + num += (frac_val as f64) / 10f64.powi(frac_digits as i32); + got_digit = true; + } + } + + if matches!(self.peek(), Some('e' | 'E')) { + self.bump(); + let sign = match self.peek() { + Some('+') => { + self.bump(); + 1 + } + Some('-') => { + self.bump(); + -1 + } + _ => 1, + }; + if let Some((exp_val, _)) = self.lex_uint() { + num *= 10f64.powi(sign * exp_val as i32); + } else { + self.pos = start_pos; + return None; + } + } + + got_digit.then_some(num) + } + + fn skip_ws(&mut self) { + self.consume_while(char::is_whitespace); + } + + pub fn next_token(&mut self) -> Option> { + self.skip_ws(); + let start = self.pos; + let ch = self.bump()?; + + use Token::*; + let tok = match ch { + '(' => LParen, + ')' => RParen, + ',' => Comma, + '+' => Plus, + '-' => Minus, + '*' => Star, + '/' => Slash, + '^' => Caret, + + '<' => { + if self.peek() == Some('=') { + self.bump(); + Le + } else { + Lt + } + } + '>' => { + if self.peek() == Some('=') { + self.bump(); + Ge + } else { + Gt + } + } + '=' => { + if self.peek() == Some('=') { + self.bump(); + EqEq + } else { + return None; + } + } + + c if c.is_ascii_digit() || (c == '.' && self.peek().is_some_and(|c| c.is_ascii_digit())) => { + self.pos = start; + Const(Literal::Float(self.lex_number()?)) + } + + _ => { + self.consume_while(|c| c.is_alphanumeric() || c == '_'); + let ident = &self.input[start..self.pos]; + + if ident == "if" { + If + } else if let Some(lit) = const_lit(ident) { + Const(lit) + } else if ch.is_alphanumeric() { + Ident(ident) + } else { + return None; + } + } + }; + + Some(tok) + } +} + +impl<'src> Input<'src> for Lexer<'src> { + type Token = Token<'src>; + type Span = Span; + type Cursor = usize; // byte offset inside `input` + type MaybeToken = Token<'src>; + type Cache = Self; + + #[inline] + fn begin(self) -> (Self::Cursor, Self::Cache) { + (0, self) + } + + #[inline] + fn cursor_location(cursor: &Self::Cursor) -> usize { + *cursor + } + + #[inline] + unsafe fn next_maybe(this: &mut Self::Cache, cursor: &mut Self::Cursor) -> Option { + this.pos = *cursor; + if let Some(tok) = this.next_token() { + *cursor = this.pos; + Some(tok) + } else { + None + } + } + + #[inline] + unsafe fn span(_this: &mut Self::Cache, range: Range<&Self::Cursor>) -> Self::Span { + (*range.start..*range.end).into() + } +} + +impl<'src> ValueInput<'src> for Lexer<'src> { + #[inline] + unsafe fn next(this: &mut Self::Cache, cursor: &mut Self::Cursor) -> Option { + this.pos = *cursor; + if let Some(tok) = this.next_token() { + *cursor = this.pos; + Some(tok) + } else { + None + } + } +} diff --git a/libraries/math-parser/src/lib.rs b/libraries/math-parser/src/lib.rs index e596d78dff..e360ef0cce 100644 --- a/libraries/math-parser/src/lib.rs +++ b/libraries/math-parser/src/lib.rs @@ -4,6 +4,7 @@ pub mod ast; mod constants; pub mod context; pub mod executer; +pub mod lexer; pub mod parser; pub mod value; @@ -13,10 +14,10 @@ use executer::EvalError; use parser::ParseError; use value::Value; -pub fn evaluate(expression: &str) -> Result<(Result, Unit), ParseError> { +pub fn evaluate(expression: &str) -> Result, ParseError<'_>> { let expr = ast::Node::try_parse_from_str(expression); let context = EvalContext::default(); - expr.map(|(node, unit)| (node.eval(&context), unit)) + expr.map(|node| node.eval(&context)) } #[cfg(test)] @@ -28,21 +29,21 @@ mod tests { const EPSILON: f64 = 1e-10_f64; macro_rules! test_end_to_end{ - ($($name:ident: $input:expr_2021 => ($expected_value:expr_2021, $expected_unit:expr_2021)),* $(,)?) => { + ($($name:ident: $input:expr_2021 => $expected_value:expr_2021),* $(,)?) => { $( #[test] fn $name() { let expected_value = $expected_value; - let expected_unit = $expected_unit; let expr = ast::Node::try_parse_from_str($input); let context = EvalContext::default(); - let (actual_value, actual_unit) = expr.map(|(node, unit)| (node.eval(&context), unit)).unwrap(); + dbg!(&expr); + + let actual_value = expr.map(|node| node.eval(&context)).unwrap(); let actual_value = actual_value.unwrap(); - assert!(actual_unit == expected_unit, "Expected unit {:?} but found unit {:?}", expected_unit, actual_unit); let expected_value = expected_value.into(); @@ -85,65 +86,97 @@ mod tests { } test_end_to_end! { - // Basic arithmetic and units - infix_addition: "5 + 5" => (10., Unit::BASE_UNIT), - infix_subtraction_units: "5m - 3m" => (2., Unit::LENGTH), - infix_multiplication_units: "4s * 4s" => (16., Unit { length: 0, mass: 0, time: 2 }), - infix_division_units: "8m/2s" => (4., Unit::VELOCITY), + // Basic arithmetic + infix_addition: "5 + 5" => 10., + infix_subtraction_units: "5 - 3" => 2., + infix_multiplication_units: "4 * 4" => 16., + infix_division_units: "8/2" => 4., // Order of operations - order_of_operations_negative_prefix: "-10 + 5" => (-5., Unit::BASE_UNIT), - order_of_operations_add_multiply: "5+1*1+5" => (11., Unit::BASE_UNIT), - order_of_operations_add_negative_multiply: "5+(-1)*1+5" => (9., Unit::BASE_UNIT), - order_of_operations_sqrt: "sqrt25 + 11" => (16., Unit::BASE_UNIT), - order_of_operations_sqrt_expression: "sqrt(25+11)" => (6., Unit::BASE_UNIT), + order_of_operations_negative_prefix: "-10 + 5" => -5., + order_of_operations_add_multiply: "5+1*1+5" => 11., + order_of_operations_add_negative_multiply: "5+(-1)*1+5" => 9., + order_of_operations_sqrt: "sqrt(25) + 11" => 16., + order_of_operations_sqrt_expression: "sqrt(25+11)" => 6., // Parentheses and nested expressions - parentheses_nested_multiply: "(5 + 3) * (2 + 6)" => (64., Unit::BASE_UNIT), - parentheses_mixed_operations: "2 * (3 + 5 * (2 + 1))" => (36., Unit::BASE_UNIT), - parentheses_divide_add_multiply: "10 / (2 + 3) + (7 * 2)" => (16., Unit::BASE_UNIT), + parentheses_nested_multiply: "(5 + 3) * (2 + 6)" => 64., + parentheses_mixed_operations: "2 * (3 + 5 * (2 + 1))" => 36., + parentheses_divide_add_multiply: "10 / (2 + 3) + (7 * 2)" => 16., // Square root and nested square root - sqrt_chain_operations: "sqrt(16) + sqrt(9) * sqrt(4)" => (10., Unit::BASE_UNIT), - sqrt_nested: "sqrt(sqrt(81))" => (3., Unit::BASE_UNIT), - sqrt_divide_expression: "sqrt((25 + 11) / 9)" => (2., Unit::BASE_UNIT), + sqrt_chain_operations: "sqrt(16) + sqrt(9) * sqrt(4)" => 10., + sqrt_nested: "sqrt(sqrt(81))" => 3., + sqrt_divide_expression: "sqrt((25 + 11) / 9)" => 2., // Mixed square root and units - sqrt_multiply_units: "sqrt(16) * 2g + 5g" => (13., Unit::MASS), - sqrt_add_multiply: "sqrt(49) - 1 + 2 * 3" => (12., Unit::BASE_UNIT), - sqrt_addition_multiply: "(sqrt(36) + 2) * 2" => (16., Unit::BASE_UNIT), + sqrt_add_multiply: "sqrt(49) - 1 + 2 * 3" => 12., + sqrt_addition_multiply: "(sqrt(36) + 2) * 2" => 16., // Exponentiation - exponent_single: "2^3" => (8., Unit::BASE_UNIT), - exponent_mixed_operations: "2^3 + 4^2" => (24., Unit::BASE_UNIT), - exponent_nested: "2^(3+1)" => (16., Unit::BASE_UNIT), + exponent_single: "2^3" => 8., + exponent_mixed_operations: "2^3 + 4^2" => 24., + exponent_nested: "2^(3+1)" => 16., // Operations with negative values - negative_units_add_multiply: "-5s + (-3 * 2)s" => (-11., Unit::TIME), - negative_nested_parentheses: "-(5 + 3 * (2 - 1))" => (-8., Unit::BASE_UNIT), - negative_sqrt_addition: "-(sqrt(16) + sqrt(9))" => (-7., Unit::BASE_UNIT), - multiply_sqrt_subtract: "5 * 2 + sqrt(16) / 2 - 3" => (9., Unit::BASE_UNIT), - add_multiply_subtract_sqrt: "4 + 3 * (2 + 1) - sqrt(25)" => (8., Unit::BASE_UNIT), - add_sqrt_subtract_nested_multiply: "10 + sqrt(64) - (5 * (2 + 1))" => (3., Unit::BASE_UNIT), + negative_nested_parentheses: "-(5 + 3 * (2 - 1))" => -8., + negative_sqrt_addition: "-(sqrt(16) + sqrt(9))" => -7., + multiply_sqrt_subtract: "5 * 2 + sqrt(16) / 2 - 3" => 9., + add_multiply_subtract_sqrt: "4 + 3 * (2 + 1) - sqrt(25)" => 8., + add_sqrt_subtract_nested_multiply: "10 + sqrt(64) - (5 * (2 + 1))" => 3., // Mathematical constants - constant_pi: "pi" => (std::f64::consts::PI, Unit::BASE_UNIT), - constant_e: "e" => (std::f64::consts::E, Unit::BASE_UNIT), - constant_phi: "phi" => (1.61803398875, Unit::BASE_UNIT), - constant_tau: "tau" => (2.0 * std::f64::consts::PI, Unit::BASE_UNIT), - constant_infinity: "inf" => (f64::INFINITY, Unit::BASE_UNIT), - constant_infinity_symbol: "∞" => (f64::INFINITY, Unit::BASE_UNIT), - multiply_pi: "2 * pi" => (2.0 * std::f64::consts::PI, Unit::BASE_UNIT), - add_e_constant: "e + 1" => (std::f64::consts::E + 1.0, Unit::BASE_UNIT), - multiply_phi_constant: "phi * 2" => (1.61803398875 * 2.0, Unit::BASE_UNIT), - exponent_tau: "2^tau" => (2f64.powf(2.0 * std::f64::consts::PI), Unit::BASE_UNIT), - infinity_subtract_large_number: "inf - 1000" => (f64::INFINITY, Unit::BASE_UNIT), + constant_pi: "pi" => std::f64::consts::PI, + constant_e: "e" => std::f64::consts::E, + constant_phi: "phi" => 1.61803398875, + constant_tau: "tau" => 2.0 * std::f64::consts::PI, + constant_infinity: "inf" => f64::INFINITY, + constant_infinity_symbol: "∞" => f64::INFINITY, + multiply_pi: "2 * pi" => 2.0 * std::f64::consts::PI, + add_e_constant: "e + 1" => std::f64::consts::E + 1.0, + multiply_phi_constant: "phi * 2" => 1.61803398875 * 2.0, + exponent_tau: "2^tau" => 2f64.powf(2.0 * std::f64::consts::PI), + infinity_subtract_large_number: "inf - 1000" => f64::INFINITY, // Trigonometric functions - trig_sin_pi: "sin(pi)" => (0.0, Unit::BASE_UNIT), - trig_cos_zero: "cos(0)" => (1.0, Unit::BASE_UNIT), - trig_tan_pi_div_four: "tan(pi/4)" => (1.0, Unit::BASE_UNIT), - trig_sin_tau: "sin(tau)" => (0.0, Unit::BASE_UNIT), - trig_cos_tau_div_two: "cos(tau/2)" => (-1.0, Unit::BASE_UNIT), + trig_sin_pi: "sin(pi)" => 0.0, + trig_cos_zero: "cos(0)" => 1.0, + trig_tan_pi_div_four: "tan(pi/4)" => 1.0, + trig_sin_tau: "sin(tau)" => 0.0, + trig_cos_tau_div_two: "cos(tau/2)" => -1.0, + + // Basic if statements + if_true_condition: "if(1,5,3)" => 5., + if_false_condition: "if(0, 5, 3)" => 3., + + // Arithmetic conditions + if_arithmetic_true: "if(2+2-4, 1 , 0)" => 0., + if_arithmetic_false: "if(3*2-5, 1, 0)" => 1., + + // Nested arithmetic + if_complex_arithmetic: "if((5+3)*(2-1), 10, 20)" => 10., + if_with_division: "if(8/4-2 == 0,15, 25)" => 15., + + // Constants in conditions + if_with_pi: "if(pi > 3, 1, 0)" => 1., + if_with_e: "if(e < 3, 1, 0)" => 1., + + // Functions in conditions + if_with_sqrt: "if(sqrt(16) == 4, 1, 0)" => 1., + if_with_sin: "if(sin(pi) == 0.0, 1, 0)" => 0., + + // Nested if statements + nested_if: "if(1, if(0, 1, 2), 3)" => 2., + nested_if_complex: "if(2-2 == 0, if(1, 5, 6), if(1, 7, 8))" => 5., + + // Mixed operations in conditions and blocks + if_complex_condition: "if(sqrt(16) + sin(pi) < 5, 2*pi, 3*e)" => 2. * std::f64::consts::PI, + if_complex_blocks: "if(1, 2*sqrt(16) + sin(pi/2), 3*cos(0) + 4)" => 9., + + // Edge cases + if_zero: "if(0.0, 1, 2)" => 2., + + // Complex nested expressions + if_nested_expr: "if((sqrt(16) + 2) * (sin(pi) + 1), 3 + 4 * 2, 5 - 2 / 1)" => 11., } } diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs index 101995d87b..225d0d7931 100644 --- a/libraries/math-parser/src/parser.rs +++ b/libraries/math-parser/src/parser.rs @@ -1,316 +1,118 @@ use crate::ast::{BinaryOp, Literal, Node, UnaryOp, Unit}; use crate::context::EvalContext; +use crate::lexer::{Lexer, Span, Token}; use crate::value::{Complex, Number, Value}; +use chumsky::container::Seq; +use chumsky::input::{BorrowInput, ValueInput}; +use chumsky::{Parser, prelude::*}; use lazy_static::lazy_static; use num_complex::ComplexFloat; -use pest::Parser; -use pest::iterators::{Pair, Pairs}; -use pest::pratt_parser::{Assoc, Op, PrattParser}; -use pest_derive::Parser; use std::num::{ParseFloatError, ParseIntError}; use thiserror::Error; -#[derive(Parser)] -#[grammar = "./grammer.pest"] // Point to the grammar file -struct ExprParser; - -lazy_static! { - static ref PRATT_PARSER: PrattParser = { - PrattParser::new() - .op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::sub, Assoc::Left)) - .op(Op::infix(Rule::mul, Assoc::Left) | Op::infix(Rule::div, Assoc::Left) | Op::infix(Rule::paren, Assoc::Left)) - .op(Op::infix(Rule::pow, Assoc::Right)) - .op(Op::postfix(Rule::fac) | Op::postfix(Rule::EOI)) - .op(Op::prefix(Rule::sqrt)) - .op(Op::prefix(Rule::neg)) - }; -} - -#[derive(Error, Debug)] -pub enum TypeError { - #[error("Invalid BinOp: {0:?} {1:?} {2:?}")] - InvalidBinaryOp(Unit, BinaryOp, Unit), - - #[error("Invalid UnaryOp: {0:?}")] - InvalidUnaryOp(Unit, UnaryOp), -} - #[derive(Error, Debug)] -pub enum ParseError { - #[error("ParseIntError: {0}")] - ParseInt(#[from] ParseIntError), - #[error("ParseFloatError: {0}")] - ParseFloat(#[from] ParseFloatError), - - #[error("TypeError: {0}")] - Type(#[from] TypeError), - - #[error("PestError: {0}")] - Pest(#[from] Box>), +pub enum ParseError<'src> { + #[error("syntax error(s): {0:#?}")] + Parse(Vec, Span>>), } impl Node { - pub fn try_parse_from_str(s: &str) -> Result<(Node, Unit), ParseError> { - let pairs = ExprParser::parse(Rule::program, s).map_err(Box::new)?; - let (node, metadata) = parse_expr(pairs)?; - Ok((node, metadata.unit)) - } -} - -struct NodeMetadata { - pub unit: Unit, -} - -impl NodeMetadata { - pub fn new(unit: Unit) -> Self { - Self { unit } - } -} - -fn parse_unit(pairs: Pairs) -> Result<(Unit, f64), ParseError> { - let mut scale = 1.0; - let mut length = 0; - let mut mass = 0; - let mut time = 0; - - for pair in pairs { - println!("found rule: {:?}", pair.as_rule()); - match pair.as_rule() { - Rule::nano => scale *= 1e-9, - Rule::micro => scale *= 1e-6, - Rule::milli => scale *= 1e-3, - Rule::centi => scale *= 1e-2, - Rule::deci => scale *= 1e-1, - Rule::deca => scale *= 1e1, - Rule::hecto => scale *= 1e2, - Rule::kilo => scale *= 1e3, - Rule::mega => scale *= 1e6, - Rule::giga => scale *= 1e9, - Rule::tera => scale *= 1e12, - - Rule::meter => length = 1, - Rule::gram => mass = 1, - Rule::second => time = 1, - - _ => unreachable!(), // All possible rules should be covered + /// Lex + parse the source and either return an AST `Node` + /// or a typed `ParseError`. + pub fn try_parse_from_str(src: &str) -> Result> { + let tokens = Lexer::new(src); + + match parser().parse(tokens).into_result() { + Ok(ast) => Ok(ast), + Err(errs) => Err(ParseError::Parse(errs)), } } - - Ok((Unit { length, mass, time }, scale)) -} - -fn parse_const(pair: Pair) -> Literal { - match pair.as_rule() { - Rule::infinity => Literal::Float(f64::INFINITY), - Rule::imaginary_unit => Literal::Complex(Complex::new(0.0, 1.0)), - Rule::pi => Literal::Float(std::f64::consts::PI), - Rule::tau => Literal::Float(2.0 * std::f64::consts::PI), - Rule::euler_number => Literal::Float(std::f64::consts::E), - Rule::golden_ratio => Literal::Float(1.61803398875), - _ => unreachable!("Unexpected constant: {:?}", pair), - } } -fn parse_lit(mut pairs: Pairs) -> Result<(Literal, Unit), ParseError> { - let literal = match pairs.next() { - Some(lit) => match lit.as_rule() { - Rule::int => { - let value = lit.as_str().parse::()? as f64; - Literal::Float(value) - } - Rule::float => { - let value = lit.as_str().parse::()?; - Literal::Float(value) - } - Rule::unit => { - let (unit, scale) = parse_unit(lit.into_inner())?; - return Ok((Literal::Float(scale), unit)); - } - rule => unreachable!("unexpected rule: {:?}", rule), - }, - None => unreachable!("expected rule"), // No literal found - }; - - if let Some(unit_pair) = pairs.next() { - let unit_pairs = unit_pair.into_inner(); // Get the inner pairs for the unit - let (unit, scale) = parse_unit(unit_pairs)?; - - println!("found unit: {:?}", unit); - - Ok(( - match literal { - Literal::Float(num) => Literal::Float(num * scale), - Literal::Complex(num) => Literal::Complex(num * scale), - }, - unit, - )) - } else { - Ok((literal, Unit::BASE_UNIT)) - } -} - -fn parse_expr(pairs: Pairs) -> Result<(Node, NodeMetadata), ParseError> { - PRATT_PARSER - .map_primary(|primary| { - Ok(match primary.as_rule() { - Rule::lit => { - let (lit, unit) = parse_lit(primary.into_inner())?; - - (Node::Lit(lit), NodeMetadata { unit }) - } - Rule::fn_call => { - let mut pairs = primary.into_inner(); - let name = pairs.next().expect("fn_call always has 2 children").as_str().to_string(); - - ( - Node::FnCall { - name, - expr: pairs.map(|p| parse_expr(p.into_inner()).map(|expr| expr.0)).collect::, ParseError>>()?, - }, - NodeMetadata::new(Unit::BASE_UNIT), - ) - } - Rule::constant => { - let lit = parse_const(primary.into_inner().next().expect("constant should have atleast 1 child")); - - (Node::Lit(lit), NodeMetadata::new(Unit::BASE_UNIT)) - } - Rule::ident => { - let name = primary.as_str().to_string(); - - (Node::Var(name), NodeMetadata::new(Unit::BASE_UNIT)) - } - Rule::expr => parse_expr(primary.into_inner())?, - Rule::float => { - let value = primary.as_str().parse::()?; - (Node::Lit(Literal::Float(value)), NodeMetadata::new(Unit::BASE_UNIT)) - } - rule => unreachable!("unexpected rule: {:?}", rule), - }) - }) - .map_prefix(|op, rhs| { - let (rhs, rhs_metadata) = rhs?; - let op = match op.as_rule() { - Rule::neg => UnaryOp::Neg, - Rule::sqrt => UnaryOp::Sqrt, - - rule => unreachable!("unexpected rule: {:?}", rule), - }; - - let node = Node::UnaryOp { expr: Box::new(rhs), op }; - let unit = rhs_metadata.unit; - - let unit = if !unit.is_base() { - match op { - UnaryOp::Sqrt if unit.length % 2 == 0 && unit.mass % 2 == 0 && unit.time % 2 == 0 => Unit { - length: unit.length / 2, - mass: unit.mass / 2, - time: unit.time / 2, - }, - UnaryOp::Neg => unit, - op => return Err(ParseError::Type(TypeError::InvalidUnaryOp(unit, op))), - } - } else { - Unit::BASE_UNIT - }; - - Ok((node, NodeMetadata::new(unit))) - }) - .map_postfix(|lhs, op| { - let (lhs_node, lhs_metadata) = lhs?; - - let op = match op.as_rule() { - Rule::EOI => return Ok((lhs_node, lhs_metadata)), - Rule::fac => UnaryOp::Fac, - rule => unreachable!("unexpected rule: {:?}", rule), - }; - - if !lhs_metadata.unit.is_base() { - return Err(ParseError::Type(TypeError::InvalidUnaryOp(lhs_metadata.unit, op))); - } - - Ok((Node::UnaryOp { expr: Box::new(lhs_node), op }, lhs_metadata)) - }) - .map_infix(|lhs, op, rhs| { - let (lhs, lhs_metadata) = lhs?; - let (rhs, rhs_metadata) = rhs?; - - let op = match op.as_rule() { - Rule::add => BinaryOp::Add, - Rule::sub => BinaryOp::Sub, - Rule::mul => BinaryOp::Mul, - Rule::div => BinaryOp::Div, - Rule::pow => BinaryOp::Pow, - Rule::paren => BinaryOp::Mul, - rule => unreachable!("unexpected rule: {:?}", rule), - }; - - let (lhs_unit, rhs_unit) = (lhs_metadata.unit, rhs_metadata.unit); - - let unit = match (!lhs_unit.is_base(), !rhs_unit.is_base()) { - (true, true) => match op { - BinaryOp::Mul => Unit { - length: lhs_unit.length + rhs_unit.length, - mass: lhs_unit.mass + rhs_unit.mass, - time: lhs_unit.time + rhs_unit.time, - }, - BinaryOp::Div => Unit { - length: lhs_unit.length - rhs_unit.length, - mass: lhs_unit.mass - rhs_unit.mass, - time: lhs_unit.time - rhs_unit.time, - }, - BinaryOp::Add | BinaryOp::Sub => { - if lhs_unit == rhs_unit { - lhs_unit - } else { - return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); - } - } - BinaryOp::Pow => { - return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); - } - }, - - (true, false) => match op { - BinaryOp::Add | BinaryOp::Sub => return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))), - BinaryOp::Pow => { - //TODO: improve error type - //TODO: support 1 / int - if let Ok(Value::Number(Number::Real(val))) = rhs.eval(&EvalContext::default()) { - if (val - val as i32 as f64).abs() <= f64::EPSILON { - Unit { - length: lhs_unit.length * val as i32, - mass: lhs_unit.mass * val as i32, - time: lhs_unit.time * val as i32, - } - } else { - return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))); - } - } else { - return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))); - } - } - _ => lhs_unit, - }, - (false, true) => match op { - BinaryOp::Add | BinaryOp::Sub | BinaryOp::Pow => return Err(ParseError::Type(TypeError::InvalidBinaryOp(Unit::BASE_UNIT, op, rhs_unit))), - _ => rhs_unit, - }, - (false, false) => Unit::BASE_UNIT, - }; - - let node = Node::BinOp { +pub fn parser<'src, I>() -> impl Parser<'src, I, Node, extra::Err, Span>>> +where + I: ValueInput<'src, Token = Token<'src>, Span = Span>, +{ + recursive(|expr| { + let constant = select! {Token::Const(x) => Node::Lit(x)}; + + let args = expr.clone().separated_by(just(Token::Comma)).collect::>().delimited_by(just(Token::LParen), just(Token::RParen)); + + let if_expr = just(Token::If) + .ignore_then(args.clone()) // Parses (cond, a, b) + .try_map(|args: Vec, span| { + if args.len() != 3 { + return Err(Rich::custom(span, "Expected 3 arguments in if(cond, a, b)")); + } + let mut iter = args.into_iter(); + let cond = iter.next().unwrap(); + let if_b = iter.next().unwrap(); + let else_b = iter.next().unwrap(); + Ok(Node::Conditional { + condition: Box::new(cond), + if_block: Box::new(if_b), + else_block: Box::new(else_b), + }) + }); + + let call = select! {Token::Ident(s) => s} + .then(args) + .try_map(|(name, args): (&str, Vec), span| Ok(Node::FnCall { name: name.to_string(), expr: args })); + + let parens = expr.clone().clone().delimited_by(just(Token::LParen), just(Token::RParen)); + let var = select! { Token::Ident(name) => Node::Var(name.to_string()) }; + + let atom = choice((constant, if_expr, call, parens, var)).boxed(); + + let add_op = choice((just(Token::Plus).to(BinaryOp::Add), just(Token::Minus).to(BinaryOp::Sub))); + let mul_op = choice((just(Token::Star).to(BinaryOp::Mul), just(Token::Slash).to(BinaryOp::Div))); + let pow_op = just(Token::Caret).to(BinaryOp::Pow); + let unary_op = just(Token::Minus).to(UnaryOp::Neg); + let cmp_op = choice(( + just(Token::Lt).to(BinaryOp::Lt), + just(Token::Le).to(BinaryOp::Leq), + just(Token::Gt).to(BinaryOp::Gt), + just(Token::Ge).to(BinaryOp::Geq), + just(Token::EqEq).to(BinaryOp::Eq), + )); + + let unary = unary_op.repeated().foldr(atom, |op, expr| Node::UnaryOp { op, expr: Box::new(expr) }).boxed(); + + let cmp = unary.clone().foldl(cmp_op.then(unary).repeated(), |lhs: Node, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); + + let pow = cmp.clone().foldl(pow_op.then(cmp).repeated(), |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); + + let product = pow + .clone() + .foldl(mul_op.then(pow).repeated(), |lhs, (op, rhs)| Node::BinOp { lhs: Box::new(lhs), op, rhs: Box::new(rhs), - }; - - Ok((node, NodeMetadata::new(unit))) + }) + .boxed(); + + let add = product.clone().foldl(add_op.then(product).repeated(), |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }); + + add.clone().foldl(add.map(|rhs| (BinaryOp::Mul, rhs)).repeated(), |lhs, (op, rhs)| Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), }) - .parse(pairs) + }) } -//TODO: set up Unit test for Units #[cfg(test)] mod tests { use super::*; @@ -320,7 +122,7 @@ mod tests { #[test] fn $name() { let result = Node::try_parse_from_str($input).unwrap(); - assert_eq!(result.0, $expected); + assert_eq!(result, $expected); } )* }; @@ -349,16 +151,20 @@ mod tests { op: BinaryOp::Pow, rhs: Box::new(Node::Lit(Literal::Float(3.0))), }, - test_parse_unary_sqrt: "sqrt(16)" => Node::UnaryOp { - expr: Box::new(Node::Lit(Literal::Float(16.0))), - op: UnaryOp::Sqrt, + test_parse_unary_sqrt: "sqrt(16)" => Node::FnCall { + name: "sqrt".to_string(), + expr: vec![Node::Lit(Literal::Float(16.0))], }, - test_parse_sqr_ident: "sqr(16)" => Node::FnCall { - name:"sqr".to_string(), + test_parse_ii_call: "ii(16)" => Node::FnCall { + name:"ii".to_string(), expr: vec![Node::Lit(Literal::Float(16.0))] }, - - test_parse_complex_expr: "(1 + 2) 3 - 4 ^ 2" => Node::BinOp { + test_parse_i_mul: "i(16)" => Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Complex(Complex::new(0.0, 1.0)))), + op: BinaryOp::Mul, + rhs: Box::new(Node::Lit(Literal::Float(16.0))), + }, + test_parse_complex_expr: "(1 + 2) * 3 - 4 ^ 2" => Node::BinOp { lhs: Box::new(Node::BinOp { lhs: Box::new(Node::BinOp { lhs: Box::new(Node::Lit(Literal::Float(1.0))), @@ -374,6 +180,15 @@ mod tests { op: BinaryOp::Pow, rhs: Box::new(Node::Lit(Literal::Float(2.0))), }), + }, + test_conditional_expr: "if (x+3, 0, 1)" => Node::Conditional{ + condition: Box::new(Node::BinOp{ + lhs: Box::new(Node::Var("x".to_string())), + op: BinaryOp::Add, + rhs: Box::new(Node::Lit(Literal::Float(3.0))), + }), + if_block: Box::new(Node::Lit(Literal::Float(0.0))), + else_block: Box::new(Node::Lit(Literal::Float(1.0))), } } } diff --git a/libraries/math-parser/src/value.rs b/libraries/math-parser/src/value.rs index 3577f3ea60..c4e5217ba8 100644 --- a/libraries/math-parser/src/value.rs +++ b/libraries/math-parser/src/value.rs @@ -52,7 +52,7 @@ impl std::fmt::Display for Number { } impl Number { - pub fn binary_op(self, op: BinaryOp, other: Number) -> Number { + pub fn binary_op(self, op: BinaryOp, other: Number) -> Option { match (self, other) { (Number::Real(lhs), Number::Real(rhs)) => { let result = match op { @@ -61,8 +61,14 @@ impl Number { BinaryOp::Mul => lhs * rhs, BinaryOp::Div => lhs / rhs, BinaryOp::Pow => lhs.powf(rhs), + BinaryOp::Leq => (lhs <= rhs) as u8 as f64, + BinaryOp::Lt => (lhs < rhs) as u8 as f64, + BinaryOp::Geq => (lhs >= rhs) as u8 as f64, + BinaryOp::Gt => (lhs > rhs) as u8 as f64, + BinaryOp::Eq => (lhs == rhs) as u8 as f64, }; - Number::Real(result) + + Some(Number::Real(result)) } (Number::Complex(lhs), Number::Complex(rhs)) => { @@ -72,8 +78,18 @@ impl Number { BinaryOp::Mul => lhs * rhs, BinaryOp::Div => lhs / rhs, BinaryOp::Pow => lhs.powc(rhs), + BinaryOp::Leq | BinaryOp::Lt | BinaryOp::Geq | BinaryOp::Gt => { + return None; + } + BinaryOp::Eq => { + if lhs == rhs { + return Some(Number::Real(1.0)); + } else { + return Some(Number::Real(0.0)); + } + } }; - Number::Complex(result) + Some(Number::Complex(result)) } (Number::Real(lhs), Number::Complex(rhs)) => { @@ -84,8 +100,9 @@ impl Number { BinaryOp::Mul => lhs_complex * rhs, BinaryOp::Div => lhs_complex / rhs, BinaryOp::Pow => lhs_complex.powc(rhs), + _ => return None, }; - Number::Complex(result) + Some(Number::Complex(result)) } (Number::Complex(lhs), Number::Real(rhs)) => { @@ -96,8 +113,9 @@ impl Number { BinaryOp::Mul => lhs * rhs_complex, BinaryOp::Div => lhs / rhs_complex, BinaryOp::Pow => lhs.powf(rhs), + _ => return None, }; - Number::Complex(result) + Some(Number::Complex(result)) } } } diff --git a/node-graph/gcore/src/ops.rs b/node-graph/gcore/src/ops.rs index 99d6b291d3..3ea14373f0 100644 --- a/node-graph/gcore/src/ops.rs +++ b/node-graph/gcore/src/ops.rs @@ -51,7 +51,7 @@ fn math( #[default(1.)] operand_b: U, ) -> U { - let (node, _unit) = match ast::Node::try_parse_from_str(&expression) { + let node = match ast::Node::try_parse_from_str(&expression) { Ok(expr) => expr, Err(e) => { warn!("Invalid expression: `{expression}`\n{e:?}");