C++ 微积分 - 求导 - 解析法(符号计算)
#include <iostream>
#include <string>
#include <memory>
// 表示一个符号表达式的基类
class Expression {
public:
virtual std::string str() const = 0; // 表示表达式的字符串形式
virtual std::unique_ptr<Expression> derivative() const = 0; // 计算表达式的导数
virtual std::unique_ptr<Expression> clone() const = 0; // 克隆当前表达式
virtual ~Expression() {}
};
// 表示常数的类
class Constant : public Expression {
double value;
public:
Constant(double val) : value(val) {}
std::string str() const override {
return std::to_string(value);
}
std::unique_ptr<Expression> derivative() const override {
return std::make_unique<Constant>(0); // 常数的导数为 0
}
std::unique_ptr<Expression> clone() const override {
return std::make_unique<Constant>(value); // 克隆常数
}
double getValue() const { return value; } // 获取常数值
};
// 表示变量 x 的类
class Variable : public Expression {
public:
std::string str() const override {
return "x";
}
std::unique_ptr<Expression> derivative() const override {
return std::make_unique<Constant>(1); // x 的导数为 1
}
std::unique_ptr<Expression> clone() const override {
return std::make_unique<Variable>(); // 克隆变量
}
};
// 表示加法运算的类
class Add : public Expression {
std::unique_ptr<Expression> left, right;
public:
Add(std::unique_ptr<Expression> l, std::unique_ptr<Expression> r)
: left(std::move(l)), right(std::move(r)) {}
std::string str() const override {
std::string leftStr = left->str();
std::string rightStr = right->str();
return "(" + leftStr + " + " + rightStr + ")";
}
std::unique_ptr<Expression> derivative() const override {
// 加法导数是各自导数的和
return std::make_unique<Add>(left->derivative(), right->derivative());
}
std::unique_ptr<Expression> clone() const override {
return std::make_unique<Add>(left->clone(), right->clone()); // 克隆加法表达式
}
};
// 表示乘法运算的类
class Multiply : public Expression {
std::unique_ptr<Expression> left, right;
public:
Multiply(std::unique_ptr<Expression> l, std::unique_ptr<Expression> r)
: left(std::move(l)), right(std::move(r)) {}
std::string str() const override {
// 处理乘法表达式中的优先级和括号
std::string leftStr = left->str();
std::string rightStr = right->str();
// 在乘法的左右操作数前加括号(如果必要)
if (dynamic_cast<Add*>(left.get()) || dynamic_cast<Multiply*>(left.get())) {
leftStr = "(" + leftStr + ")";
}
if (dynamic_cast<Add*>(right.get()) || dynamic_cast<Multiply*>(right.get())) {
rightStr = "(" + rightStr + ")";
}
return leftStr + " * " + rightStr;
}
std::unique_ptr<Expression> derivative() const override {
// 乘法导数根据乘积法则计算
// (u * v)' = u' * v + u * v'
return std::make_unique<Add>(
std::make_unique<Multiply>(left->derivative()->clone(), right->clone()),
std::make_unique<Multiply>(left->clone(), right->derivative()->clone())
);
}
std::unique_ptr<Expression> clone() const override {
return std::make_unique<Multiply>(left->clone(), right->clone()); // 克隆乘法表达式
}
};
int main() {
// 构造函数 f(x) = x^3 + 2x
auto x = std::make_unique<Variable>();
auto x_cubed = std::make_unique<Multiply>(
std::make_unique<Variable>(),
std::make_unique<Multiply>(std::make_unique<Variable>(), std::make_unique<Variable>())
); // x^3
auto two_x = std::make_unique<Multiply>(std::make_unique<Constant>(2), std::make_unique<Variable>()); // 2x
auto f = std::make_unique<Add>(std::move(x_cubed), std::move(two_x));
// 计算 f'(x)
auto df = f->derivative();
// 输出函数和导数的表达式
std::cout << "f(x) = " << f->str() << std::endl;
std::cout << "f'(x) = " << df->str() << std::endl;
return 0;
}