C++ 微积分 - 求导 - 解析法(符号计算)

时间:2025-04-12 14:29:29
#include <iostream> #include <string> #include <memory> // 表示一个符号表达式的基类 class Expression { public: virtual std::string str() const = 0; // 表示表达式的字符串形式 virtual std::unique_ptr<Expression> derivative() const = 0; // 计算表达式的导数 virtual std::unique_ptr<Expression> clone() const = 0; // 克隆当前表达式 virtual ~Expression() {} }; // 表示常数的类 class Constant : public Expression { double value; public: Constant(double val) : value(val) {} std::string str() const override { return std::to_string(value); } std::unique_ptr<Expression> derivative() const override { return std::make_unique<Constant>(0); // 常数的导数为 0 } std::unique_ptr<Expression> clone() const override { return std::make_unique<Constant>(value); // 克隆常数 } double getValue() const { return value; } // 获取常数值 }; // 表示变量 x 的类 class Variable : public Expression { public: std::string str() const override { return "x"; } std::unique_ptr<Expression> derivative() const override { return std::make_unique<Constant>(1); // x 的导数为 1 } std::unique_ptr<Expression> clone() const override { return std::make_unique<Variable>(); // 克隆变量 } }; // 表示加法运算的类 class Add : public Expression { std::unique_ptr<Expression> left, right; public: Add(std::unique_ptr<Expression> l, std::unique_ptr<Expression> r) : left(std::move(l)), right(std::move(r)) {} std::string str() const override { std::string leftStr = left->str(); std::string rightStr = right->str(); return "(" + leftStr + " + " + rightStr + ")"; } std::unique_ptr<Expression> derivative() const override { // 加法导数是各自导数的和 return std::make_unique<Add>(left->derivative(), right->derivative()); } std::unique_ptr<Expression> clone() const override { return std::make_unique<Add>(left->clone(), right->clone()); // 克隆加法表达式 } }; // 表示乘法运算的类 class Multiply : public Expression { std::unique_ptr<Expression> left, right; public: Multiply(std::unique_ptr<Expression> l, std::unique_ptr<Expression> r) : left(std::move(l)), right(std::move(r)) {} std::string str() const override { // 处理乘法表达式中的优先级和括号 std::string leftStr = left->str(); std::string rightStr = right->str(); // 在乘法的左右操作数前加括号(如果必要) if (dynamic_cast<Add*>(left.get()) || dynamic_cast<Multiply*>(left.get())) { leftStr = "(" + leftStr + ")"; } if (dynamic_cast<Add*>(right.get()) || dynamic_cast<Multiply*>(right.get())) { rightStr = "(" + rightStr + ")"; } return leftStr + " * " + rightStr; } std::unique_ptr<Expression> derivative() const override { // 乘法导数根据乘积法则计算 // (u * v)' = u' * v + u * v' return std::make_unique<Add>( std::make_unique<Multiply>(left->derivative()->clone(), right->clone()), std::make_unique<Multiply>(left->clone(), right->derivative()->clone()) ); } std::unique_ptr<Expression> clone() const override { return std::make_unique<Multiply>(left->clone(), right->clone()); // 克隆乘法表达式 } }; int main() { // 构造函数 f(x) = x^3 + 2x auto x = std::make_unique<Variable>(); auto x_cubed = std::make_unique<Multiply>( std::make_unique<Variable>(), std::make_unique<Multiply>(std::make_unique<Variable>(), std::make_unique<Variable>()) ); // x^3 auto two_x = std::make_unique<Multiply>(std::make_unique<Constant>(2), std::make_unique<Variable>()); // 2x auto f = std::make_unique<Add>(std::move(x_cubed), std::move(two_x)); // 计算 f'(x) auto df = f->derivative(); // 输出函数和导数的表达式 std::cout << "f(x) = " << f->str() << std::endl; std::cout << "f'(x) = " << df->str() << std::endl; return 0; }