From 4db8bd0b2e70b9593926a73318ec167e9cc6c5a2 Mon Sep 17 00:00:00 2001 From: Christophe de Dinechin Date: Tue, 26 Mar 2024 23:23:44 +0100 Subject: [PATCH] expressions: Add code for n-ary functions Add code that deals with n-ary expressions, such as `comb`, `perm` or `xroot`. Signed-off-by: Christophe de Dinechin --- src/expression.cc | 61 ++++++++++++- src/expression.h | 17 +++- src/functions.cc | 224 ++++++++++++++++++++++++---------------------- src/functions.h | 61 +++++++++++-- src/tests.cc | 9 ++ 5 files changed, 251 insertions(+), 121 deletions(-) diff --git a/src/expression.cc b/src/expression.cc index 8a1c2f97..8f7576ce 100644 --- a/src/expression.cc +++ b/src/expression.cc @@ -153,7 +153,8 @@ symbol_p expression::render(uint depth, int &precedence, bool editing) default: break; } - if (argp >= precedence::FUNCTION && argp != precedence::FUNCTION_POWER) + if (argp >= precedence::FUNCTION && + argp != precedence::FUNCTION_POWER) arg = space(arg); return fn + arg; } @@ -431,13 +432,69 @@ size_t expression::required_memory(id type, id op, algebraic_r x, algebraic_r y) // Size of an equation object with one argument // ---------------------------------------------------------------------------- { - size_t size = leb128size(op) + size_in_expression(x) + size_in_expression(y); + size_t size = leb128size(op)+size_in_expression(x)+size_in_expression(y); size += leb128size(size); size += leb128size(type); return size; } +expression::expression(id type, id op, algebraic_g args[], uint arity) +// ---------------------------------------------------------------------------- +// Build an equation with 'arity' arguments +// ---------------------------------------------------------------------------- + : program(type, nullptr, 0) +{ + byte *p = (byte *) payload(); + + // Compute the size of the program + size_t size = leb128size(op); + for (uint a = 0; a < arity; a++) + size += size_in_expression(args[a]); + + // Write the size of the program + p = leb128(p, size); + + // Write the arguments + size_t objsize = 0; + byte_p objptr = nullptr; + for (uint a = 0; a < arity; a++) + { + algebraic_p arg = args[arity + ~a]; + if (expression_p eq = arg->as()) + { + objptr = eq->value(&objsize); + } + else + { + objsize = arg->size(); + objptr = byte_p(arg); + } + memmove(p, objptr, objsize); + p += objsize; + } + + // Write the opcode + p = leb128(p, op); +} + + +size_t expression::required_memory(id type, id op, + algebraic_g args[], uint arity) +// ---------------------------------------------------------------------------- +// Size of an equation object with 'arity' arguments +// ---------------------------------------------------------------------------- +{ + size_t size = leb128size(op); + for (uint a = 0; a < arity; a++) + size += size_in_expression(args[a]); + size += leb128size(size); + size += leb128size(type); + return size; +} + + + // ============================================================================ // // Equation rewrite engine diff --git a/src/expression.h b/src/expression.h index bd6148bd..24aced7c 100644 --- a/src/expression.h +++ b/src/expression.h @@ -60,6 +60,10 @@ struct expression : program expression(id type, id op, algebraic_r x, algebraic_r y); static size_t required_memory(id i, id op, algebraic_r x, algebraic_r y); + // Building expressions from an array of arguments + expression(id type, id op, algebraic_g arg[], uint arity); + static size_t required_memory(id i, id op, algebraic_g arg[], uint arity); + object_p quoted(id type = ID_object) const; static size_t size_in_expression(object_p obj); @@ -85,6 +89,15 @@ struct expression : program return rt.make(type, op, x, y); } + static expression_p make(id op, algebraic_g args[], uint arity, + id type = ID_expression) + { + for (uint a = 0; a < arity; a++) + if (!args[a]) + return nullptr; + return rt.make(type, op, args, arity); + } + expression_p rewrite(expression_r from, expression_r to) const; expression_p rewrite(expression_p from, expression_p to) const { @@ -93,7 +106,9 @@ struct expression : program expression_p rewrite(size_t size, const byte_p rewrites[]) const; expression_p rewrite_all(size_t size, const byte_p rewrites[]) const; - static expression_p rewrite(expression_r eq, expression_r from, expression_r to) + static expression_p rewrite(expression_r eq, + expression_r from, + expression_r to) { return eq->rewrite(from, to); } diff --git a/src/functions.cc b/src/functions.cc index 6cf83902..55cb5dcf 100644 --- a/src/functions.cc +++ b/src/functions.cc @@ -274,6 +274,52 @@ object::result function::evaluate(algebraic_fn op, bool mat) } +object::result function::evaluate(id op, nfunction_fn fn, uint arity) +// ---------------------------------------------------------------------------- +// Perform the operation from the stack for n-ary functions +// ---------------------------------------------------------------------------- +{ + if (!rt.args(arity)) + return ERROR; + + bool is_symbolic = false; + algebraic_g args[arity]; + for (uint a = 0; a < arity; a++) + { + algebraic_p arg = rt.stack(a)->as_algebraic(); + if (!arg) + { + rt.type_error(); + return ERROR; + } + args[a] = arg; + if (arg->is_symbolic()) + is_symbolic = true; + + // Conversion to numerical if needed (may fail silently) + if (Settings.NumericalResults()) + { + (void) to_decimal(args[a], true); + if (!args[a]) + return ERROR; + } + } + + + algebraic_g result; + + // Check the symbolic case + if (is_symbolic) + result = expression::make(op, args, arity); + else + result = fn(op, args, arity); + + if (result && rt.drop(arity) && rt.push(+result)) + return OK; + return ERROR; +} + + FUNCTION_BODY(neg) // ---------------------------------------------------------------------------- // Implementation of 'neg' @@ -704,63 +750,49 @@ FUNCTION_BODY(cubed) } -COMMAND_BODY(xroot) +NFUNCTION_BODY(xroot) // ---------------------------------------------------------------------------- // Compute the x-th root // ---------------------------------------------------------------------------- { - if (rt.args(2)) + if (args[0]->is_zero()) { - if (object_p x = rt.stack(0)) + rt.domain_error(); + } + else + { + algebraic_g &x = args[0]; + algebraic_g &y = args[1]; + bool is_int = x->is_integer(); + bool is_neg = false; + if (!is_int && x->is_decimal()) { - if (object_p y = rt.stack(1)) + decimal_g ip, fp; + decimal_p xd = decimal_p(+x); + if (!xd->split(ip, fp)) + return nullptr; + if (fp->is_zero()) + is_int = true; + } + if (is_int) + { + bool is_odd = x->as_int32(0, false) & 1; + is_neg = y->is_negative(); + if (is_neg && !is_odd) { - algebraic_g xa = x->as_algebraic(); - algebraic_g ya = y->as_algebraic(); - if (!xa.Safe() || !ya.Safe()) - { - rt.type_error(); - } - else if (xa->is_zero()) - { - rt.domain_error(); - } - else - { - bool is_int = xa->is_integer(); - bool is_neg = false; - if (!is_int && xa->is_decimal()) - { - decimal_g ip, fp; - decimal_p xd = decimal_p(+xa); - if (!xd->split(ip, fp)) - return ERROR; - if (fp->is_zero()) - is_int = true; - } - if (is_int) - { - bool is_odd = xa->as_int32(0, false) & 1; - is_neg = ya->is_negative(); - if (is_neg && !is_odd) - { - // Root of a negative number - rt.domain_error(); - return ERROR; - } - } - - if (is_neg) - xa = -pow(-ya, integer::make(1) / xa); - else - xa = pow(ya, integer::make(1) / xa); - if (+xa && rt.drop() && rt.top(xa)) - return OK; - } + // Root of a negative number + rt.domain_error(); + return nullptr; } } + + if (is_neg) + x = -pow(-y, integer::make(1) / x); + else + x = pow(y, integer::make(1) / x); + return x; } - return ERROR; + return nullptr; } @@ -825,87 +857,61 @@ INSERT_BODY(fact) } -COMMAND_BODY(comb) +NFUNCTION_BODY(comb) // ---------------------------------------------------------------------------- // Compute number of combinations // ---------------------------------------------------------------------------- { - if (rt.args(2)) + algebraic_g &n = args[1]; + algebraic_g &m = args[0]; + if (integer_g nval = n->as()) { - algebraic_g n = rt.stack(1)->as_algebraic(); - algebraic_g m = rt.stack(0)->as_algebraic(); - if (n->is_symbolic() || m->is_symbolic()) + if (integer_g mval = m->as()) { - algebraic_g result = expression::make(ID_comb, n, m); - if (!result || !rt.drop() || !rt.top(result)) - return ERROR; - return OK; + ularge ni = nval->value(); + ularge mi = mval->value(); + n = integer::make(ni < mi ? 0 : 1); + for (ularge i = ni - mi + 1; i <= ni && n; i++) + n = n * algebraic_g(integer::make(i)); + for (ularge i = 2; i <= mi && n; i++) + n = n / algebraic_g(integer::make(i)); + return n; } - - if (integer_g nval = n->as()) - { - if (integer_g mval = m->as()) - { - ularge ni = nval->value(); - ularge mi = mval->value(); - algebraic_g result = integer::make(ni < mi ? 0 : 1); - for (ularge i = ni - mi + 1; i <= ni && result; i++) - result = result * algebraic_g(integer::make(i)); - for (ularge i = 2; i <= mi; i++) - result = result / algebraic_g(integer::make(i)); - if (!result || !rt.drop() || !rt.top(result)) - return ERROR; - return OK; - } - } - - if (n->is_real() && m->is_real()) - rt.value_error(); - else - rt.type_error(); } - return ERROR; + + if (n->is_real() && m->is_real()) + rt.value_error(); + else + rt.type_error(); + return nullptr; } -COMMAND_BODY(perm) +NFUNCTION_BODY(perm) // ---------------------------------------------------------------------------- // Compute number of permutations (n! / (n - m)!) // ---------------------------------------------------------------------------- { - if (rt.args(2)) + algebraic_g &n = args[1]; + algebraic_g &m = args[0]; + if (integer_g nval = n->as()) { - algebraic_g n = rt.stack(1)->as_algebraic(); - algebraic_g m = rt.stack(0)->as_algebraic(); - if (n->is_symbolic() || m->is_symbolic()) + if (integer_g mval = m->as()) { - algebraic_g result = expression::make(ID_perm, n, m); - if (!result || !rt.drop() || !rt.top(result)) - return ERROR; - return OK; + ularge ni = nval->value(); + ularge mi = mval->value(); + n = integer::make(ni < mi ? 0 : 1); + for (ularge i = ni - mi + 1; i <= ni && n; i++) + n = n * algebraic_g(integer::make(i)); + return n; } - - if (integer_g nval = n->as()) - { - if (integer_g mval = m->as()) - { - ularge ni = nval->value(); - ularge mi = mval->value(); - algebraic_g result = integer::make(ni < mi ? 0 : 1); - for (ularge i = ni - mi + 1; i <= ni && result; i++) - result = result * algebraic_g(integer::make(i)); - if (!result || !rt.drop() || !rt.top(result)) - return ERROR; - return OK; - } - } - - if (n->is_real() && m->is_real()) - rt.value_error(); - else - rt.type_error(); } - return ERROR; + + if (n->is_real() && m->is_real()) + rt.value_error(); + else + rt.type_error(); + return nullptr; } diff --git a/src/functions.h b/src/functions.h index f757b4d3..0c4edce0 100644 --- a/src/functions.h +++ b/src/functions.h @@ -93,6 +93,14 @@ public: static const bool does_matrices = false; + + typedef algebraic_p (*nfunction_fn)(id op, algebraic_g args[], uint arity); + static result evaluate(id op, nfunction_fn fn, uint arity); + // ------------------------------------------------------------------------ + // Evaluate a function with n arguments + // ------------------------------------------------------------------------ + + }; @@ -164,7 +172,7 @@ STANDARD_FUNCTION(tgamma); STANDARD_FUNCTION(lgamma); -#define FUNCTION_EXT(derived, arity, extra) \ +#define FUNCTION_EXT(derived, extra) \ struct derived : function \ /* ----------------------------------------------------------------- */ \ /* Macro to define a mathematical function not from the library */ \ @@ -174,7 +182,7 @@ struct derived : function \ \ public: \ OBJECT_DECL(derived); \ - ARITY_DECL(arity); \ + ARITY_DECL(1); \ PREC_DECL(FUNCTION); \ EVAL_DECL(derived) \ { \ @@ -191,15 +199,15 @@ public: \ static algebraic_p evaluate(algebraic_r x); \ }; -#define FUNCTION(derived) FUNCTION_EXT(derived, 1, ) +#define FUNCTION(derived) FUNCTION_EXT(derived, ) #define FUNCTION_FANCY(derived) \ - FUNCTION_EXT(derived, 1, INSERT_DECL(derived);) + FUNCTION_EXT(derived, INSERT_DECL(derived);) #define FUNCTION_MAT(derived) \ - FUNCTION_EXT(derived, 1, \ + FUNCTION_EXT(derived, \ static const bool does_matrices = true;) #define FUNCTION_FANCY_MAT(derived) \ - FUNCTION_EXT(derived, 1, \ + FUNCTION_EXT(derived, \ INSERT_DECL(derived); \ static const bool does_matrices = true;) @@ -216,10 +224,7 @@ FUNCTION_FANCY_MAT(inv); FUNCTION(neg); FUNCTION_FANCY_MAT(sq); FUNCTION_FANCY_MAT(cubed); -COMMAND_DECLARE(xroot); FUNCTION_FANCY(fact); -COMMAND_DECLARE(comb); -COMMAND_DECLARE(perm); FUNCTION(re); FUNCTION(im); @@ -235,4 +240,42 @@ FUNCTION(ToFraction); FUNCTION(RadiansToDegrees); FUNCTION(DegreesToRadians); + + + +#define NFUNCTION(derived, fnarity, extra) \ +struct derived : function \ +/* ----------------------------------------------------------------- */ \ +/* Macro to define a mathematical function with more than 1 arg */ \ +/* ----------------------------------------------------------------- */ \ +{ \ + derived(id i = ID_##derived) : function(i) {} \ + \ +public: \ + OBJECT_DECL(derived); \ + ARITY_DECL(fnarity); \ + PREC_DECL(FUNCTION); \ + EVAL_DECL(derived) \ + { \ + rt.command(o); \ + return evaluate(); \ + } \ + extra \ +public: \ + static result evaluate() \ + { \ + return function::evaluate(derived::static_id, \ + derived::evaluate, fnarity); \ + } \ + static algebraic_p evaluate(id op, algebraic_g args[], uint arity); \ +} + + +#define NFUNCTION_BODY(derived) \ + algebraic_p derived::evaluate(id op, algebraic_g args[], uint arity) + +NFUNCTION(xroot, 2, ); +NFUNCTION(comb, 2, ); +NFUNCTION(perm, 2, ); + #endif // FUNCTIONS_H diff --git a/src/tests.cc b/src/tests.cc index de388b36..bce7d647 100644 --- a/src/tests.cc +++ b/src/tests.cc @@ -6383,6 +6383,15 @@ void tests::probabilities() .test(CLEAR, "37 42", NOSHIFT, F2).expect("0"); step("Permutations in menu") .test(CLEAR, "42 37", NOSHIFT, F2).expect("11 708 384 314 607 332 487 859 521 718 704 263 082 803 200 000 000"); + + step("Symbolic combinations") + .test(CLEAR, "n m", NOSHIFT, F1).expect("'Combinations(n;m)'") + .test(CLEAR, "n 1", NOSHIFT, F1).expect("'Combinations(n;1)'") + .test(CLEAR, "1 z", NOSHIFT, F1).expect("'Combinations(1;z)'"); + step("Symbolic permutations") + .test(CLEAR, "n m", NOSHIFT, F2).expect("'Permutations(n;m)'") + .test(CLEAR, "n 1", NOSHIFT, F2).expect("'Permutations(n;1)'") + .test(CLEAR, "1 z", NOSHIFT, F2).expect("'Permutations(1;z)'"); }