expressions: Add code for n-ary functions

Add code that deals with n-ary expressions, such as `comb`, `perm` or
`xroot`.

Signed-off-by: Christophe de Dinechin <christophe@dinechin.org>
This commit is contained in:
Christophe de Dinechin 2024-03-26 23:23:44 +01:00
parent a2ce949b23
commit 4db8bd0b2e
5 changed files with 251 additions and 121 deletions

View file

@ -153,7 +153,8 @@ symbol_p expression::render(uint depth, int &precedence, bool editing)
default:
break;
}
if (argp >= precedence::FUNCTION && argp != precedence::FUNCTION_POWER)
if (argp >= precedence::FUNCTION &&
argp != precedence::FUNCTION_POWER)
arg = space(arg);
return fn + arg;
}
@ -438,6 +439,62 @@ size_t expression::required_memory(id type, id op, algebraic_r x, algebraic_r y)
}
expression::expression(id type, id op, algebraic_g args[], uint arity)
// ----------------------------------------------------------------------------
// Build an equation with 'arity' arguments
// ----------------------------------------------------------------------------
: program(type, nullptr, 0)
{
byte *p = (byte *) payload();
// Compute the size of the program
size_t size = leb128size(op);
for (uint a = 0; a < arity; a++)
size += size_in_expression(args[a]);
// Write the size of the program
p = leb128(p, size);
// Write the arguments
size_t objsize = 0;
byte_p objptr = nullptr;
for (uint a = 0; a < arity; a++)
{
algebraic_p arg = args[arity + ~a];
if (expression_p eq = arg->as<expression>())
{
objptr = eq->value(&objsize);
}
else
{
objsize = arg->size();
objptr = byte_p(arg);
}
memmove(p, objptr, objsize);
p += objsize;
}
// Write the opcode
p = leb128(p, op);
}
size_t expression::required_memory(id type, id op,
algebraic_g args[], uint arity)
// ----------------------------------------------------------------------------
// Size of an equation object with 'arity' arguments
// ----------------------------------------------------------------------------
{
size_t size = leb128size(op);
for (uint a = 0; a < arity; a++)
size += size_in_expression(args[a]);
size += leb128size(size);
size += leb128size(type);
return size;
}
// ============================================================================
//
// Equation rewrite engine

View file

@ -60,6 +60,10 @@ struct expression : program
expression(id type, id op, algebraic_r x, algebraic_r y);
static size_t required_memory(id i, id op, algebraic_r x, algebraic_r y);
// Building expressions from an array of arguments
expression(id type, id op, algebraic_g arg[], uint arity);
static size_t required_memory(id i, id op, algebraic_g arg[], uint arity);
object_p quoted(id type = ID_object) const;
static size_t size_in_expression(object_p obj);
@ -85,6 +89,15 @@ struct expression : program
return rt.make<expression>(type, op, x, y);
}
static expression_p make(id op, algebraic_g args[], uint arity,
id type = ID_expression)
{
for (uint a = 0; a < arity; a++)
if (!args[a])
return nullptr;
return rt.make<expression>(type, op, args, arity);
}
expression_p rewrite(expression_r from, expression_r to) const;
expression_p rewrite(expression_p from, expression_p to) const
{
@ -93,7 +106,9 @@ struct expression : program
expression_p rewrite(size_t size, const byte_p rewrites[]) const;
expression_p rewrite_all(size_t size, const byte_p rewrites[]) const;
static expression_p rewrite(expression_r eq, expression_r from, expression_r to)
static expression_p rewrite(expression_r eq,
expression_r from,
expression_r to)
{
return eq->rewrite(from, to);
}

View file

@ -274,6 +274,52 @@ object::result function::evaluate(algebraic_fn op, bool mat)
}
object::result function::evaluate(id op, nfunction_fn fn, uint arity)
// ----------------------------------------------------------------------------
// Perform the operation from the stack for n-ary functions
// ----------------------------------------------------------------------------
{
if (!rt.args(arity))
return ERROR;
bool is_symbolic = false;
algebraic_g args[arity];
for (uint a = 0; a < arity; a++)
{
algebraic_p arg = rt.stack(a)->as_algebraic();
if (!arg)
{
rt.type_error();
return ERROR;
}
args[a] = arg;
if (arg->is_symbolic())
is_symbolic = true;
// Conversion to numerical if needed (may fail silently)
if (Settings.NumericalResults())
{
(void) to_decimal(args[a], true);
if (!args[a])
return ERROR;
}
}
algebraic_g result;
// Check the symbolic case
if (is_symbolic)
result = expression::make(op, args, arity);
else
result = fn(op, args, arity);
if (result && rt.drop(arity) && rt.push(+result))
return OK;
return ERROR;
}
FUNCTION_BODY(neg)
// ----------------------------------------------------------------------------
// Implementation of 'neg'
@ -704,63 +750,49 @@ FUNCTION_BODY(cubed)
}
COMMAND_BODY(xroot)
NFUNCTION_BODY(xroot)
// ----------------------------------------------------------------------------
// Compute the x-th root
// ----------------------------------------------------------------------------
{
if (rt.args(2))
{
if (object_p x = rt.stack(0))
{
if (object_p y = rt.stack(1))
{
algebraic_g xa = x->as_algebraic();
algebraic_g ya = y->as_algebraic();
if (!xa.Safe() || !ya.Safe())
{
rt.type_error();
}
else if (xa->is_zero())
if (args[0]->is_zero())
{
rt.domain_error();
}
else
{
bool is_int = xa->is_integer();
algebraic_g &x = args[0];
algebraic_g &y = args[1];
bool is_int = x->is_integer();
bool is_neg = false;
if (!is_int && xa->is_decimal())
if (!is_int && x->is_decimal())
{
decimal_g ip, fp;
decimal_p xd = decimal_p(+xa);
decimal_p xd = decimal_p(+x);
if (!xd->split(ip, fp))
return ERROR;
return nullptr;
if (fp->is_zero())
is_int = true;
}
if (is_int)
{
bool is_odd = xa->as_int32(0, false) & 1;
is_neg = ya->is_negative();
bool is_odd = x->as_int32(0, false) & 1;
is_neg = y->is_negative();
if (is_neg && !is_odd)
{
// Root of a negative number
rt.domain_error();
return ERROR;
return nullptr;
}
}
if (is_neg)
xa = -pow(-ya, integer::make(1) / xa);
x = -pow(-y, integer::make(1) / x);
else
xa = pow(ya, integer::make(1) / xa);
if (+xa && rt.drop() && rt.top(xa))
return OK;
x = pow(y, integer::make(1) / x);
return x;
}
}
}
}
return ERROR;
return nullptr;
}
@ -825,37 +857,25 @@ INSERT_BODY(fact)
}
COMMAND_BODY(comb)
NFUNCTION_BODY(comb)
// ----------------------------------------------------------------------------
// Compute number of combinations
// ----------------------------------------------------------------------------
{
if (rt.args(2))
{
algebraic_g n = rt.stack(1)->as_algebraic();
algebraic_g m = rt.stack(0)->as_algebraic();
if (n->is_symbolic() || m->is_symbolic())
{
algebraic_g result = expression::make(ID_comb, n, m);
if (!result || !rt.drop() || !rt.top(result))
return ERROR;
return OK;
}
algebraic_g &n = args[1];
algebraic_g &m = args[0];
if (integer_g nval = n->as<integer>())
{
if (integer_g mval = m->as<integer>())
{
ularge ni = nval->value<ularge>();
ularge mi = mval->value<ularge>();
algebraic_g result = integer::make(ni < mi ? 0 : 1);
for (ularge i = ni - mi + 1; i <= ni && result; i++)
result = result * algebraic_g(integer::make(i));
for (ularge i = 2; i <= mi; i++)
result = result / algebraic_g(integer::make(i));
if (!result || !rt.drop() || !rt.top(result))
return ERROR;
return OK;
n = integer::make(ni < mi ? 0 : 1);
for (ularge i = ni - mi + 1; i <= ni && n; i++)
n = n * algebraic_g(integer::make(i));
for (ularge i = 2; i <= mi && n; i++)
n = n / algebraic_g(integer::make(i));
return n;
}
}
@ -863,40 +883,27 @@ COMMAND_BODY(comb)
rt.value_error();
else
rt.type_error();
}
return ERROR;
return nullptr;
}
COMMAND_BODY(perm)
NFUNCTION_BODY(perm)
// ----------------------------------------------------------------------------
// Compute number of permutations (n! / (n - m)!)
// ----------------------------------------------------------------------------
{
if (rt.args(2))
{
algebraic_g n = rt.stack(1)->as_algebraic();
algebraic_g m = rt.stack(0)->as_algebraic();
if (n->is_symbolic() || m->is_symbolic())
{
algebraic_g result = expression::make(ID_perm, n, m);
if (!result || !rt.drop() || !rt.top(result))
return ERROR;
return OK;
}
algebraic_g &n = args[1];
algebraic_g &m = args[0];
if (integer_g nval = n->as<integer>())
{
if (integer_g mval = m->as<integer>())
{
ularge ni = nval->value<ularge>();
ularge mi = mval->value<ularge>();
algebraic_g result = integer::make(ni < mi ? 0 : 1);
for (ularge i = ni - mi + 1; i <= ni && result; i++)
result = result * algebraic_g(integer::make(i));
if (!result || !rt.drop() || !rt.top(result))
return ERROR;
return OK;
n = integer::make(ni < mi ? 0 : 1);
for (ularge i = ni - mi + 1; i <= ni && n; i++)
n = n * algebraic_g(integer::make(i));
return n;
}
}
@ -904,8 +911,7 @@ COMMAND_BODY(perm)
rt.value_error();
else
rt.type_error();
}
return ERROR;
return nullptr;
}

View file

@ -93,6 +93,14 @@ public:
static const bool does_matrices = false;
typedef algebraic_p (*nfunction_fn)(id op, algebraic_g args[], uint arity);
static result evaluate(id op, nfunction_fn fn, uint arity);
// ------------------------------------------------------------------------
// Evaluate a function with n arguments
// ------------------------------------------------------------------------
};
@ -164,7 +172,7 @@ STANDARD_FUNCTION(tgamma);
STANDARD_FUNCTION(lgamma);
#define FUNCTION_EXT(derived, arity, extra) \
#define FUNCTION_EXT(derived, extra) \
struct derived : function \
/* ----------------------------------------------------------------- */ \
/* Macro to define a mathematical function not from the library */ \
@ -174,7 +182,7 @@ struct derived : function \
\
public: \
OBJECT_DECL(derived); \
ARITY_DECL(arity); \
ARITY_DECL(1); \
PREC_DECL(FUNCTION); \
EVAL_DECL(derived) \
{ \
@ -191,15 +199,15 @@ public: \
static algebraic_p evaluate(algebraic_r x); \
};
#define FUNCTION(derived) FUNCTION_EXT(derived, 1, )
#define FUNCTION(derived) FUNCTION_EXT(derived, )
#define FUNCTION_FANCY(derived) \
FUNCTION_EXT(derived, 1, INSERT_DECL(derived);)
FUNCTION_EXT(derived, INSERT_DECL(derived);)
#define FUNCTION_MAT(derived) \
FUNCTION_EXT(derived, 1, \
FUNCTION_EXT(derived, \
static const bool does_matrices = true;)
#define FUNCTION_FANCY_MAT(derived) \
FUNCTION_EXT(derived, 1, \
FUNCTION_EXT(derived, \
INSERT_DECL(derived); \
static const bool does_matrices = true;)
@ -216,10 +224,7 @@ FUNCTION_FANCY_MAT(inv);
FUNCTION(neg);
FUNCTION_FANCY_MAT(sq);
FUNCTION_FANCY_MAT(cubed);
COMMAND_DECLARE(xroot);
FUNCTION_FANCY(fact);
COMMAND_DECLARE(comb);
COMMAND_DECLARE(perm);
FUNCTION(re);
FUNCTION(im);
@ -235,4 +240,42 @@ FUNCTION(ToFraction);
FUNCTION(RadiansToDegrees);
FUNCTION(DegreesToRadians);
#define NFUNCTION(derived, fnarity, extra) \
struct derived : function \
/* ----------------------------------------------------------------- */ \
/* Macro to define a mathematical function with more than 1 arg */ \
/* ----------------------------------------------------------------- */ \
{ \
derived(id i = ID_##derived) : function(i) {} \
\
public: \
OBJECT_DECL(derived); \
ARITY_DECL(fnarity); \
PREC_DECL(FUNCTION); \
EVAL_DECL(derived) \
{ \
rt.command(o); \
return evaluate(); \
} \
extra \
public: \
static result evaluate() \
{ \
return function::evaluate(derived::static_id, \
derived::evaluate, fnarity); \
} \
static algebraic_p evaluate(id op, algebraic_g args[], uint arity); \
}
#define NFUNCTION_BODY(derived) \
algebraic_p derived::evaluate(id op, algebraic_g args[], uint arity)
NFUNCTION(xroot, 2, );
NFUNCTION(comb, 2, );
NFUNCTION(perm, 2, );
#endif // FUNCTIONS_H

View file

@ -6383,6 +6383,15 @@ void tests::probabilities()
.test(CLEAR, "37 42", NOSHIFT, F2).expect("0");
step("Permutations in menu")
.test(CLEAR, "42 37", NOSHIFT, F2).expect("11708384314607332487859521718704263082803200000000");
step("Symbolic combinations")
.test(CLEAR, "n m", NOSHIFT, F1).expect("'Combinations(n;m)'")
.test(CLEAR, "n 1", NOSHIFT, F1).expect("'Combinations(n;1)'")
.test(CLEAR, "1 z", NOSHIFT, F1).expect("'Combinations(1;z)'");
step("Symbolic permutations")
.test(CLEAR, "n m", NOSHIFT, F2).expect("'Permutations(n;m)'")
.test(CLEAR, "n 1", NOSHIFT, F2).expect("'Permutations(n;1)'")
.test(CLEAR, "1 z", NOSHIFT, F2).expect("'Permutations(1;z)'");
}