NEURON
matexp_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 David McDougall
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: BSD-3-Clause
6  */
7 
9 
10 #include "ast/all.hpp"
12 #include "utils/logger.hpp"
14 
15 #include <algorithm>
16 
17 namespace nmodl {
18 namespace visitor {
19 
20 
/// Report whether any element of `vec` compares equal to `value`.
///
/// \param vec   sequence to search (may be empty)
/// \param value value to look for, compared with `operator==`
/// \return true if at least one element equals `value`, false otherwise
template <typename T>
static bool vector_contains(const std::vector<T>& vec, const T& value) {
    return std::any_of(vec.cbegin(), vec.cend(), [&value](const T& element) {
        return element == value;
    });
}
25 
26 
28  // Make lists of all KineticBlock's and SolveBlock's in the program
29  node.visit_children(*this);
30 
31  states = node.get_symbol_table()->get_variables(symtab::syminfo::NmodlType::state_var);
32 
33  // Replace solve-steadystate-matexp statements with their MatexpBlock solution
34  for (const auto& solve_block: steadystate_blocks) {
35  replace_solve_block(*solve_block, true);
36  }
37 
38  // Get the MatexpBlock solutions and append them to the end of the file.
39  for (const auto& solve_block: solve_blocks) {
40  node.emplace_back_node(get_solve_block(*solve_block, false));
41  }
42 
43  // Remove solved KINETIC blocks
44  const auto& blocks = node.get_blocks();
45  for (auto iter = blocks.begin(); iter != blocks.end(); iter++) {
46  if ((*iter)->is_kinetic_block()) {
47  const auto kinetic_block = std::dynamic_pointer_cast<ast::KineticBlock>(*iter);
48  const auto& block_name = kinetic_block->get_name()->get_node_name();
49  bool keep = vector_contains(keep_blocks, block_name);
50  if (!keep) {
51  node.erase_node(iter--);
52  logger->debug("MatexpVisitor :: Removing solved KINETIC block \"{}\"", block_name);
53  }
54  }
55  }
56 }
57 
58 
59 /// Populate the lists of solve-block statements in the program
61  // Find solve statements that use the matexp solver method
62  const auto& solve_method = node.get_method();
63  const auto& steadystate_method = node.get_steadystate();
64  const auto& matexp_method = codegen::naming::MATEXP_METHOD;
65  const auto is_method_matexp = [](auto method) {
66  return method && method->get_value()->eval() == matexp_method;
67  };
68  const bool solve = is_method_matexp(solve_method);
69  const bool steadystate = is_method_matexp(steadystate_method);
70  // Save the block for later reference
71  if (solve) {
72  solve_blocks.push_back(&node);
73  } else if (steadystate) {
74  steadystate_blocks.push_back(&node);
75  } else {
76  keep_blocks.push_back(node.get_block_name()->get_node_name());
77  }
78 }
79 
80 
81 /// Populate the list of kinetic blocks in the program
83  kinetic_blocks.push_back(&node);
84 }
85 
86 
87 // Helper class for finding, checking, and removing CONSERVE statements
89  public:
90  std::vector<std::shared_ptr<ast::Conserve>> conserve_statements;
91 
93  this->conserve_statements.clear();
94  }
96  node.visit_children(*this);
97  const auto& statements = node.get_statements();
98  for (auto iter = statements.begin(); iter != statements.end(); iter++) {
99  if ((*iter)->is_conserve()) {
100  node.erase_statement(iter--);
101  }
102  }
103  }
105  // Unpack the conserve statement
106  const auto expr = node.get_expr();
107  const auto react = node.get_react();
108  // Check the CONSERVE statement is usable.
109  const bool primes = node_exists(*react,
110  {
112  });
113  if (primes) {
114  logger->error("MatexpVisitor :: Error : CONSERVE uses derivative");
115  throw std::invalid_argument("CONSERVE uses derivative");
116  }
117  const auto vars = collect_nodes(*react,
118  {
120  });
121  const auto num_vars = vars.size();
122  std::vector<std::string> var_names;
123  for (const auto& var: vars) {
124  var_names.push_back(to_nmodl(var));
125  }
126  std::sort(var_names.begin(), var_names.end());
127  const int num_unqiue = std::unique(var_names.begin(), var_names.end()) - var_names.begin();
128  if (num_vars != num_unqiue) {
129  logger->error("MatexpVisitor :: Error : CONSERVE is non-linear");
130  throw std::invalid_argument("CONSERVE is non-linear");
131  }
132  conserve_statements.push_back(std::make_shared<ast::Conserve>(node));
133  }
134 };
135 
136 
137 // Replace the given solve statement with a MatexpBlock
139  const auto& name = node.get_block_name()->get_node_name();
140  const auto& block = find_kinetic_block(name);
141  const auto& solution = solve_kinetic_block(*block, steadystate);
142  ast::Ast* parent = node.get_parent();
143  assert(parent->is_expression_statement());
144  ((ast::ExpressionStatement*) parent)->set_expression(solution);
145 }
146 
147 
148 // Return the MatexpBlock solution for the given solve-block statement
149 std::shared_ptr<ast::MatexpBlock> MatexpVisitor::get_solve_block(const ast::SolveBlock& node,
150  bool steadystate) {
151  const auto& name = node.get_block_name()->get_node_name();
152  const auto& block = find_kinetic_block(name);
153  const auto& solution = solve_kinetic_block(*block, steadystate);
154  return solution;
155 }
156 
157 
158 // Search the kinetic_blocks vector for the given block
159 ast::KineticBlock* MatexpVisitor::find_kinetic_block(const std::string& block_name) {
160  for (const auto& block: kinetic_blocks) {
161  if (block->get_node_name() == block_name) {
162  return block;
163  }
164  }
165  throw std::runtime_error("cannot find the block '{" + block_name + "}' to solve it");
166 }
167 
168 
169 // Convert a KineticBlock into a MatexpBlock
170 std::shared_ptr<ast::MatexpBlock> MatexpVisitor::solve_kinetic_block(const ast::KineticBlock& node,
171  bool steadystate) {
172  // Make a copy of the statement block, do not modify original.
173  const auto& jacobian_block = std::make_shared<ast::StatementBlock>(*node.get_statement_block());
174  // Convert the reaction statements into assignments to the Jacobian matrix
175  in_jacobian_block = true;
176  this->visit_statement_block(*jacobian_block);
177  in_jacobian_block = false;
178 
179  CollectConserveVisitor conserve_visitor;
180  conserve_visitor.visit_statement_block(*jacobian_block);
181 
182  return std::make_shared<ast::MatexpBlock>(std::make_shared<ast::Boolean>(steadystate),
183  jacobian_block,
184  conserve_visitor.conserve_statements);
185 }
186 
187 
189  // Iterate through the children using indices instead of pointers,
190  // so that more statements can be added to the list
191  for (int index = 0; index < node.get_statements().size(); index++) {
192  node.get_statements()[index]->accept(*this);
193  }
194 }
195 
196 
198  logger->error("MatexpVisitor :: Error : Reaction equation is non-linear");
199  throw std::invalid_argument("Reaction equation is non-linear");
200 }
201 
202 
203 // Argument node is a reactant or product expression
204 static std::string get_state_var_name(const std::shared_ptr<ast::Ast>& node) {
205  const auto reaction_var_name = std::dynamic_pointer_cast<ast::ReactVarName>(node);
206  if (!reaction_var_name) {
208  }
209  const auto coefficient = reaction_var_name->get_value();
210  if (coefficient && coefficient->eval() != 1) {
212  }
213  return reaction_var_name->get_node_name();
214 }
215 
216 
217 // Returns an index into the "state" vector
218 int MatexpVisitor::get_state_index(const std::string& state_name) {
219  for (int index = 0; index < states.size(); index++) {
220  if (state_name == states[index]->get_name()) {
221  return index;
222  }
223  }
224  logger->error(
225  "MatexpVisitor :: Error : Reaction equation contains invalid state variable: \"{}\"",
226  state_name);
227  throw std::invalid_argument("Reaction equation contains invalid state variable");
228 }
229 
230 
231 // Get the index of a statement in a statement-vector
232 static int find_node(const nmodl::ast::StatementVector& statements, const nmodl::ast::Node* node) {
233  for (int index = 0; index < statements.size(); index++) {
234  const nmodl::ast::Node* cursor = statements[index].get();
235  if (cursor == node) {
236  return index;
237  }
238  }
239  assert(false); // unreachable
240 }
241 
242 
243 // Convert a decay reaction statement "->" into equivalent assignments to the Jacobian matrix
244 std::shared_ptr<ast::Statement> MatexpVisitor::transform_decay_statement(
245  std::shared_ptr<ast::Expression> lhs,
246  std::shared_ptr<ast::Expression> kf) {
247  const auto lhs_name = get_state_var_name(lhs);
248  const auto lhs_idx = get_state_index(lhs_name);
249  // Calculate the Jacobian matrix indices
250  const int jf_src_idx = lhs_idx + states.size() * lhs_idx;
251  // Write NMODL to assign to the Jacobian matrix
252  const std::string jf_src = "nmodl_eigen_j[" + std::to_string(jf_src_idx) + "]";
253  const std::string kf_nmodl = to_nmodl(kf);
254  return create_statement(jf_src + " = " + jf_src + " - (" + kf_nmodl + ") * nmodl_dt");
255 }
256 
257 
258 // Convert a reaction statement "<->" into equivalent assignments to the Jacobian matrix
259 std::vector<std::shared_ptr<ast::Statement>> MatexpVisitor::transform_reaction_statement(
260  std::shared_ptr<ast::Expression> lhs,
261  std::shared_ptr<ast::Expression> rhs,
262  std::shared_ptr<ast::Expression> kf,
263  std::shared_ptr<ast::Expression> kb) {
264  // Find the state vector indices
265  const auto lhs_name = get_state_var_name(lhs);
266  const auto rhs_name = get_state_var_name(rhs);
267  const auto lhs_idx = get_state_index(lhs_name);
268  const auto rhs_idx = get_state_index(rhs_name);
269  // Calculate the Jacobian matrix indices
270  const int jf_src_idx = lhs_idx + states.size() * lhs_idx;
271  const int jf_dst_idx = lhs_idx * states.size() + rhs_idx;
272  const int jb_src_idx = rhs_idx * states.size() + rhs_idx;
273  const int jb_dst_idx = lhs_idx + states.size() * rhs_idx;
274  // Create four new statements assigning to the Jacobian matrix
275  const std::string jf_src = "nmodl_eigen_j[" + std::to_string(jf_src_idx) + "]";
276  const std::string jf_dst = "nmodl_eigen_j[" + std::to_string(jf_dst_idx) + "]";
277  const std::string jb_src = "nmodl_eigen_j[" + std::to_string(jb_src_idx) + "]";
278  const std::string jb_dst = "nmodl_eigen_j[" + std::to_string(jb_dst_idx) + "]";
279  const std::string kf_nmodl = to_nmodl(kf);
280  const std::string kb_nmodl = to_nmodl(kb);
281  const std::string jf_n_string = jf_src + " = " + jf_src + " - (" + kf_nmodl + ") * nmodl_dt";
282  const std::string jf_p_string = jf_dst + " = " + jf_dst + " + (" + kf_nmodl + ") * nmodl_dt";
283  const std::string jb_n_string = jb_src + " = " + jb_src + " - (" + kb_nmodl + ") * nmodl_dt";
284  const std::string jb_p_string = jb_dst + " = " + jb_dst + " + (" + kb_nmodl + ") * nmodl_dt";
285  const auto& jf_n = create_statement(jf_n_string);
286  const auto& jf_p = create_statement(jf_p_string);
287  const auto& jb_n = create_statement(jb_n_string);
288  const auto& jb_p = create_statement(jb_p_string);
289  return {jf_n, jf_p, jb_n, jb_p};
290 }
291 
292 
293 // Visit reaction statements inside of kinetic blocks which we are actively solving
295  if (!in_jacobian_block) {
296  return;
297  }
298  // Unpack the reaction data
299  const auto& op = node.get_op().get_value();
300  const auto& lhs = node.get_reaction1();
301  const auto& rhs = node.get_reaction2();
302  const auto& kf = node.get_expression1(); // forwards reaction rate
303  const auto& kb = node.get_expression2(); // backwards reaction rate
304  // Get the parent statement block
305  assert(node.get_parent()->is_statement_block());
306  const auto statement_block = (ast::StatementBlock*) node.get_parent();
307  const auto& statements = statement_block->get_statements();
308  // Find and remove this reaction statement
309  int statement_index = find_node(statements, &node);
310  statement_block->erase_statement(std::begin(statements) + statement_index);
311  // Check for invalid kinetic models
312  if (op == ast::ReactionOp::LTLT) {
314  }
315  // Replace reaction statements with assignments to the Jacobian
316  else if (op == ast::ReactionOp::MINUSGT) {
317  const auto jf_n = transform_decay_statement(lhs, kf);
318  statement_block->insert_statement(std::begin(statements) + statement_index, jf_n);
319  } else if (op == ast::ReactionOp::LTMINUSGT) {
320  for (const auto& stmt: transform_reaction_statement(lhs, rhs, kf, kb)) {
321  statement_block->insert_statement(std::begin(statements) + statement_index++, stmt);
322  }
323  }
324 }
325 
326 } // namespace visitor
327 } // namespace nmodl
Auto generated AST classes declaration.
Represent CONSERVE statement in NMODL.
Definition: conserve.hpp:38
Base class for all AST node.
Definition: node.hpp:40
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
Represents block encapsulating list of statements.
Concrete visitor for all AST classes.
Definition: ast_visitor.hpp:37
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
std::vector< std::shared_ptr< ast::Conserve > > conserve_statements
void visit_conserve(ast::Conserve &node) override
visit node of type ast::Conserve
void visit_program(ast::Program &node) override
visit node of type ast::Program
std::vector< ast::SolveBlock * > solve_blocks
blocks to be solved
std::shared_ptr< ast::MatexpBlock > get_solve_block(const ast::SolveBlock &node, bool steadystate)
return the MatexpBlock solution for the given solve-block statement
std::vector< std::shared_ptr< ast::Statement > > transform_reaction_statement(std::shared_ptr< ast::Expression > lhs, std::shared_ptr< ast::Expression > rhs, std::shared_ptr< ast::Expression > kf, std::shared_ptr< ast::Expression > kb)
convert a reaction statement "<->" into equivalent assignments to the Jacobian matrix
std::vector< ast::SolveBlock * > steadystate_blocks
blocks to be solved
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
void visit_reaction_statement(ast::ReactionStatement &node) override
visit node of type ast::ReactionStatement
std::vector< std::shared_ptr< symtab::Symbol > > states
ordered list of state variables
std::shared_ptr< ast::Statement > transform_decay_statement(std::shared_ptr< ast::Expression > lhs, std::shared_ptr< ast::Expression > kf)
convert a decay reaction statement "->" into equivalent assignments to the Jacobian matrix
std::vector< ast::KineticBlock * > kinetic_blocks
all kinetic blocks in the program
std::vector< std::string > keep_blocks
blocks to be solved by a different solver method
std::shared_ptr< ast::MatexpBlock > solve_kinetic_block(const ast::KineticBlock &node, bool steadystate)
convert a KineticBlock into a MatexpBlock
ast::KineticBlock * find_kinetic_block(const std::string &block_name)
search the "kinetic_blocks" vector for the given block
void replace_solve_block(const ast::SolveBlock &node, bool steadystate)
replace the given solve-block statement with a MatexpBlock
int get_state_index(const std::string &state_name)
returns an index into the "states" vector
bool in_jacobian_block
currently visiting kinetic block that is being solved
void visit_solve_block(ast::SolveBlock &node) override
Populate the lists of solve-block statements in the program.
void visit_kinetic_block(ast::KineticBlock &node) override
Populate the list of kinetic blocks in the program.
static double solve(void *v)
Definition: cvodeobj.cpp:87
virtual bool is_expression_statement() const noexcept
Check if the ast node is an instance of ast::ExpressionStatement.
Definition: ast.cpp:226
@ NAME
type of ast::Name
@ PRIME_NAME
type of ast::PrimeName
std::vector< std::shared_ptr< Statement > > StatementVector
Definition: ast_decl.hpp:304
#define assert(ex)
Definition: hocassrt.h:24
double var(InputIterator begin, InputIterator end)
Definition: ivocvect.h:108
#define rhs
Definition: lineq.h:6
Visitor used for generating the necessary AST nodes for matexp solver.
const char * name
Definition: init.cpp:16
std::string to_string(EnumT e, const std::array< std::pair< EnumT, std::string_view >, N > &mapping, const std::string_view enum_name)
Converts an enum value to its corresponding string representation.
Definition: nrnreport.hpp:102
auto get_name(Tag const &tag, int field_index)
Get the nicest available name for the field_index-th instance of Tag.
static constexpr char MATEXP_METHOD[]
matexp method in nmodl
static void nonlinear_reaction_error()
static int find_node(const nmodl::ast::StatementVector &statements, const nmodl::ast::Node *node)
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
static std::string get_state_var_name(const std::shared_ptr< ast::Ast > &node)
static bool vector_contains(const std::vector< T > &vec, const T &value)
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
logger_type logger
Definition: logger.cpp:34
bool node_exists(const ast::Ast &node, ast::AstNodeType ast_type)
Whether a node of type ast_type exists as a subnode of node.
static Node * node(Object *)
Definition: netcvode.cpp:291
short index
Definition: cabvars.h:11
int find(const int, const int, const int, const int, const int)
static uint32_t value
Definition: scoprand.cpp:25
static double unique(void *v)
Definition: seclist.cpp:193
Base class for all Abstract Syntax Tree node types.
Definition: ast.hpp:52
Utility functions for visitors implementation.