12 #include "utils/logger.hpp"
29 node.visit_children(*
this);
44 const auto& blocks =
node.get_blocks();
45 for (
auto iter = blocks.begin(); iter != blocks.end(); iter++) {
46 if ((*iter)->is_kinetic_block()) {
47 const auto kinetic_block = std::dynamic_pointer_cast<ast::KineticBlock>(*iter);
48 const auto& block_name = kinetic_block->get_name()->get_node_name();
51 node.erase_node(iter--);
52 logger->debug(
"MatexpVisitor :: Removing solved KINETIC block \"{}\"", block_name);
62 const auto& solve_method =
node.get_method();
63 const auto& steadystate_method =
node.get_steadystate();
65 const auto is_method_matexp = [](
auto method) {
66 return method && method->get_value()->eval() == matexp_method;
68 const bool solve = is_method_matexp(solve_method);
69 const bool steadystate = is_method_matexp(steadystate_method);
73 }
else if (steadystate) {
93 this->conserve_statements.clear();
96 node.visit_children(*
this);
97 const auto& statements =
node.get_statements();
98 for (
auto iter = statements.begin(); iter != statements.end(); iter++) {
99 if ((*iter)->is_conserve()) {
100 node.erase_statement(iter--);
106 const auto expr =
node.get_expr();
107 const auto react =
node.get_react();
114 logger->error(
"MatexpVisitor :: Error : CONSERVE uses derivative");
115 throw std::invalid_argument(
"CONSERVE uses derivative");
121 const auto num_vars = vars.size();
122 std::vector<std::string> var_names;
123 for (
const auto&
var: vars) {
126 std::sort(var_names.begin(), var_names.end());
127 const int num_unqiue =
std::unique(var_names.begin(), var_names.end()) - var_names.begin();
128 if (num_vars != num_unqiue) {
129 logger->error(
"MatexpVisitor :: Error : CONSERVE is non-linear");
130 throw std::invalid_argument(
"CONSERVE is non-linear");
139 const auto&
name =
node.get_block_name()->get_node_name();
151 const auto&
name =
node.get_block_name()->get_node_name();
161 if (block->get_node_name() == block_name) {
165 throw std::runtime_error(
"cannot find the block '{" + block_name +
"}' to solve it");
173 const auto& jacobian_block = std::make_shared<ast::StatementBlock>(*
node.get_statement_block());
182 return std::make_shared<ast::MatexpBlock>(std::make_shared<ast::Boolean>(steadystate),
192 node.get_statements()[
index]->accept(*
this);
198 logger->error(
"MatexpVisitor :: Error : Reaction equation is non-linear");
199 throw std::invalid_argument(
"Reaction equation is non-linear");
205 const auto reaction_var_name = std::dynamic_pointer_cast<ast::ReactVarName>(
node);
206 if (!reaction_var_name) {
209 const auto coefficient = reaction_var_name->get_value();
210 if (coefficient && coefficient->eval() != 1) {
213 return reaction_var_name->get_node_name();
225 "MatexpVisitor :: Error : Reaction equation contains invalid state variable: \"{}\"",
227 throw std::invalid_argument(
"Reaction equation contains invalid state variable");
235 if (cursor ==
node) {
245 std::shared_ptr<ast::Expression> lhs,
246 std::shared_ptr<ast::Expression> kf) {
250 const int jf_src_idx = lhs_idx +
states.size() * lhs_idx;
252 const std::string jf_src =
"nmodl_eigen_j[" +
std::to_string(jf_src_idx) +
"]";
253 const std::string kf_nmodl =
to_nmodl(kf);
254 return create_statement(jf_src +
" = " + jf_src +
" - (" + kf_nmodl +
") * nmodl_dt");
260 std::shared_ptr<ast::Expression> lhs,
261 std::shared_ptr<ast::Expression>
rhs,
262 std::shared_ptr<ast::Expression> kf,
263 std::shared_ptr<ast::Expression> kb) {
270 const int jf_src_idx = lhs_idx +
states.size() * lhs_idx;
271 const int jf_dst_idx = lhs_idx *
states.size() + rhs_idx;
272 const int jb_src_idx = rhs_idx *
states.size() + rhs_idx;
273 const int jb_dst_idx = lhs_idx +
states.size() * rhs_idx;
275 const std::string jf_src =
"nmodl_eigen_j[" +
std::to_string(jf_src_idx) +
"]";
276 const std::string jf_dst =
"nmodl_eigen_j[" +
std::to_string(jf_dst_idx) +
"]";
277 const std::string jb_src =
"nmodl_eigen_j[" +
std::to_string(jb_src_idx) +
"]";
278 const std::string jb_dst =
"nmodl_eigen_j[" +
std::to_string(jb_dst_idx) +
"]";
279 const std::string kf_nmodl =
to_nmodl(kf);
280 const std::string kb_nmodl =
to_nmodl(kb);
281 const std::string jf_n_string = jf_src +
" = " + jf_src +
" - (" + kf_nmodl +
") * nmodl_dt";
282 const std::string jf_p_string = jf_dst +
" = " + jf_dst +
" + (" + kf_nmodl +
") * nmodl_dt";
283 const std::string jb_n_string = jb_src +
" = " + jb_src +
" - (" + kb_nmodl +
") * nmodl_dt";
284 const std::string jb_p_string = jb_dst +
" = " + jb_dst +
" + (" + kb_nmodl +
") * nmodl_dt";
289 return {jf_n, jf_p, jb_n, jb_p};
299 const auto& op =
node.get_op().get_value();
300 const auto& lhs =
node.get_reaction1();
301 const auto&
rhs =
node.get_reaction2();
302 const auto& kf =
node.get_expression1();
303 const auto& kb =
node.get_expression2();
305 assert(
node.get_parent()->is_statement_block());
307 const auto& statements = statement_block->get_statements();
310 statement_block->erase_statement(std::begin(statements) + statement_index);
318 statement_block->insert_statement(std::begin(statements) + statement_index, jf_n);
321 statement_block->insert_statement(std::begin(statements) + statement_index++, stmt);
Auto generated AST classes declaration.
Represent CONSERVE statement in NMODL.
Base class for all AST node.
Represents top level AST node for whole NMODL input.
Represents block encapsulating list of statements.
Concrete visitor for all AST classes.
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
std::vector< std::shared_ptr< ast::Conserve > > conserve_statements
void visit_conserve(ast::Conserve &node) override
visit node of type ast::Conserve
void visit_program(ast::Program &node) override
visit node of type ast::Program
std::vector< ast::SolveBlock * > solve_blocks
blocks to be solved
std::shared_ptr< ast::MatexpBlock > get_solve_block(const ast::SolveBlock &node, bool steadystate)
return the MatexpBlock solution for the given solve-block statement
std::vector< std::shared_ptr< ast::Statement > > transform_reaction_statement(std::shared_ptr< ast::Expression > lhs, std::shared_ptr< ast::Expression > rhs, std::shared_ptr< ast::Expression > kf, std::shared_ptr< ast::Expression > kb)
convert a reaction statement "<->" into equivalent assignments to the Jacobian matrix
std::vector< ast::SolveBlock * > steadystate_blocks
blocks to be solved
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
void visit_reaction_statement(ast::ReactionStatement &node) override
visit node of type ast::ReactionStatement
std::vector< std::shared_ptr< symtab::Symbol > > states
ordered list of state variables
std::shared_ptr< ast::Statement > transform_decay_statement(std::shared_ptr< ast::Expression > lhs, std::shared_ptr< ast::Expression > kf)
convert a decay reaction statement "->" into equivalent assignments to the Jacobian matrix
std::vector< ast::KineticBlock * > kinetic_blocks
all kinetic blocks in the program
std::vector< std::string > keep_blocks
blocks to be solved by a different solver method
std::shared_ptr< ast::MatexpBlock > solve_kinetic_block(const ast::KineticBlock &node, bool steadystate)
convert a KineticBlock into a MatexpBlock
ast::KineticBlock * find_kinetic_block(const std::string &block_name)
search the "kinetic_blocks" vector for the given block
void replace_solve_block(const ast::SolveBlock &node, bool steadystate)
replace the given solve-block statement with a MatexpBlock
int get_state_index(const std::string &state_name)
returns an index into the "states" vector
bool in_jacobian_block
currently visiting kinetic block that is being solved
void visit_solve_block(ast::SolveBlock &node) override
Populate the lists of solve-block statements in the program.
void visit_kinetic_block(ast::KineticBlock &node) override
Populate the list of kinetic blocks in the program.
static double solve(void *v)
virtual bool is_expression_statement() const noexcept
Check if the ast node is an instance of ast::ExpressionStatement.
@ PRIME_NAME
type of ast::PrimeName
std::vector< std::shared_ptr< Statement > > StatementVector
double var(InputIterator begin, InputIterator end)
Visitor used for generating the necessary AST nodes for matexp solver.
std::string to_string(EnumT e, const std::array< std::pair< EnumT, std::string_view >, N > &mapping, const std::string_view enum_name)
Converts an enum value to its corresponding string representation.
auto get_name(Tag const &tag, int field_index)
Get the nicest available name for the field_index-th instance of Tag.
static constexpr char MATEXP_METHOD[]
matexp method in nmodl
@ state_var
state variable
static void nonlinear_reaction_error()
static int find_node(const nmodl::ast::StatementVector &statements, const nmodl::ast::Node *node)
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
static std::string get_state_var_name(const std::shared_ptr< ast::Ast > &node)
static bool vector_contains(const std::vector< T > &vec, const T &value)
encapsulates code generation backend implementations
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
bool node_exists(const ast::Ast &node, ast::AstNodeType ast_type)
Whether a node of type ast_type exists as a subnode of node.
static Node * node(Object *)
int find(const int, const int, const int, const int, const int)
static double unique(void *v)
Base class for all Abstract Syntax Tree node types.
Utility functions for visitors implementation.