NEURON
codegen_helper_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include <algorithm>
11 #include <cmath>
12 #include <memory>
13 
14 #include "ast/all.hpp"
15 #include "ast/constant_var.hpp"
17 #include "parser/c11_driver.hpp"
19 
20 #include "utils/logger.hpp"
21 
22 namespace nmodl {
23 namespace codegen {
24 
25 using namespace ast;
26 
29 
30 /**
31  * Check whether a given SOLVE block solves a PROCEDURE with any of the CVode methods
32  */
33 static bool check_procedure_has_cvode(const std::shared_ptr<const ast::Ast>& solve_node,
34  const std::shared_ptr<const ast::Ast>& procedure_node) {
35  const auto& solve_block = std::dynamic_pointer_cast<const ast::SolveBlock>(solve_node);
36  const auto& method = solve_block->get_method();
37  if (!method) {
38  return false;
39  }
40  const auto& method_name = method->get_node_name();
41 
42  return procedure_node->get_node_name() == solve_block->get_block_name()->get_node_name() &&
43  (method_name == codegen::naming::AFTER_CVODE_METHOD ||
44  method_name == codegen::naming::CVODE_T_METHOD ||
45  method_name == codegen::naming::CVODE_T_V_METHOD);
46 }
47 
48 /**
49  * How symbols are stored in NEURON? See notes written in markdown file.
50  *
51  * Some variables get printed by iterating over symbol table in mod2c.
52  * The example of this is thread variables (and also ions?). In this
53  * case we must have to arrange order if we are going keep compatibility
54  * with NEURON.
55  *
56  * Suppose there are three global variables: bcd, abc, abd, abe
57  * They will be in the 'a' bucket in order:
58  * abe, abd, abc
59  * and in 'b' bucket
60  * bcd
61  * So when we print thread variables, we first have to sort in the opposite
62  * order in which they come and then again order by first character in increasing
63  * order.
64  *
65  * Note that variables in double array do not need this transformation
66  * and it seems like they should just follow definition order.
67  */
68 void CodegenHelperVisitor::sort_with_mod2c_symbol_order(std::vector<SymbolType>& symbols) {
69  /// first sort by global id to get in reverse order
70  std::sort(symbols.begin(),
71  symbols.end(),
72  [](const SymbolType& first, const SymbolType& second) -> bool {
73  return first->get_id() > second->get_id();
74  });
75 
76  /// now order by name (to be same as neuron's bucket)
77  std::sort(symbols.begin(),
78  symbols.end(),
79  [](const SymbolType& first, const SymbolType& second) -> bool {
80  return first->get_name()[0] < second->get_name()[0];
81  });
82 }
83 
84 
85 /**
86  * Find all ions used in mod file
87  */
88 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
90  // collect all use ion statements
91  const auto& ion_nodes = collect_nodes(node, {AstNodeType::USEION});
92 
93  // ion names, read ion variables and write ion variables
94  std::vector<std::string> ion_vars;
95  std::vector<std::string> read_ion_vars;
96  std::vector<std::string> write_ion_vars;
97  std::map<std::string, double> valences;
98 
99  for (const auto& ion_node: ion_nodes) {
100  const auto& ion = std::dynamic_pointer_cast<const ast::Useion>(ion_node);
101  auto ion_name = ion->get_node_name();
102  ion_vars.push_back(ion_name);
103  for (const auto& var: ion->get_readlist()) {
104  read_ion_vars.push_back(var->get_node_name());
105  }
106  for (const auto& var: ion->get_writelist()) {
107  write_ion_vars.push_back(var->get_node_name());
108  }
109 
110  if (ion->get_valence() != nullptr) {
111  valences[ion_name] = ion->get_valence()->get_value()->to_double();
112  }
113  }
114 
115  /**
116  * Check if given variable belongs to given ion.
117  * For example, eca belongs to ca ion, nai belongs to na ion.
118  * We just check if we exclude first/last char, if that is ion name.
119  */
120  auto ion_variable = [](const std::string& var, const std::string& ion) -> bool {
121  auto len = var.size() - 1;
122  return (var.substr(1, len) == ion || var.substr(0, len) == ion);
123  };
124 
125  /// iterate over all ion types and construct the Ion objects
126  for (auto& ion_name: ion_vars) {
127  Ion ion(ion_name);
128  for (auto& read_var: read_ion_vars) {
129  if (ion_variable(read_var, ion_name)) {
130  ion.reads.push_back(read_var);
131  }
132  }
133  for (auto& write_var: write_ion_vars) {
134  if (ion_variable(write_var, ion_name)) {
135  ion.writes.push_back(write_var);
136  if (ion.is_intra_cell_conc(write_var) || ion.is_extra_cell_conc(write_var)) {
137  ion.need_style = true;
138  info.write_concentration = true;
139  }
140  }
141  }
142  if (auto it = valences.find(ion_name); it != valences.end()) {
143  ion.valence = it->second;
144  }
145 
146  info.ions.push_back(std::move(ion));
147  }
148 
149  /// once ions are populated, we can find all currents
150  auto vars = psymtab->get_variables_with_properties(NmodlType::nonspecific_cur_var);
151  for (auto& var: vars) {
152  info.currents.push_back(var->get_name());
153  }
154  vars = psymtab->get_variables_with_properties(NmodlType::electrode_cur_var);
155  for (auto& var: vars) {
156  info.currents.push_back(var->get_name());
157  }
158  for (auto& ion: info.ions) {
159  for (auto& var: ion.writes) {
160  if (ion.is_ionic_current(var)) {
161  info.currents.push_back(var);
162  }
163  }
164  }
165 
166  /// check if write_conc(...) will be needed
167  for (const auto& ion: info.ions) {
168  for (const auto& var: ion.writes) {
169  if (!ion.is_ionic_current(var) && !ion.is_rev_potential(var)) {
170  info.require_wrote_conc = true;
171  }
172  }
173  }
174 }
175 
176 /**
177  * Find whether or not we need to emit CVODE-related code for NEURON
178  * Notes: we generate CVODE-related code if and only if:
179  * - there is exactly ONE block being SOLVEd
180  * - the block is one of the following types:
181  * - DERIVATIVE
182  * - KINETIC
183  * - PROCEDURE being solved with the `after_cvode`, `cvode_t`, or `cvode_t_v` methods
184  */
186  // find the breakpoint block
187  const auto& breakpoint_nodes = collect_nodes(node, {AstNodeType::BREAKPOINT_BLOCK});
188 
189  // do nothing if there are no BREAKPOINT nodes
190  if (breakpoint_nodes.empty()) {
191  return;
192  }
193 
194  // there can only be one BREAKPOINT block in the entire program
195  assert(breakpoint_nodes.size() == 1);
196 
197  const auto& breakpoint_node = std::dynamic_pointer_cast<const ast::BreakpointBlock>(
198  breakpoint_nodes[0]);
199 
200  // all (global) kinetic/derivative nodes
201  const auto& kinetic_or_derivative_nodes =
202  collect_nodes(node, {AstNodeType::KINETIC_BLOCK, AstNodeType::DERIVATIVE_BLOCK});
203 
204  // all (global) procedure nodes
205  const auto& procedure_nodes = collect_nodes(node, {AstNodeType::PROCEDURE_BLOCK});
206 
207  // find all SOLVE blocks in that BREAKPOINT block
208  const auto& solve_nodes = collect_nodes(*breakpoint_node, {AstNodeType::SOLVE_BLOCK});
209 
210  // check whether any of the SOLVE blocks are solving any PROCEDURE with `after_cvode`,
211  // `cvode_t`, or `cvode_t_v` methods
212  const auto using_cvode = std::any_of(
213  solve_nodes.begin(), solve_nodes.end(), [&procedure_nodes](const auto& solve_node) {
214  return std::any_of(procedure_nodes.begin(),
215  procedure_nodes.end(),
216  [&solve_node](const auto& procedure_node) {
217  return check_procedure_has_cvode(solve_node, procedure_node);
218  });
219  });
220 
221  // only case when we emit CVODE code is if we have exactly one block, and
222  // that block is either a KINETIC/DERIVATIVE with any method, or a
223  // PROCEDURE with `after_cvode` method
224  if (solve_nodes.size() == 1 && (kinetic_or_derivative_nodes.size() || using_cvode)) {
225  logger->debug("Will emit code for CVODE");
226  info.emit_cvode = enable_cvode;
227  }
228 }
229 
230 // check that a given AST node `node`, which can be casted to `T`, has a node name which matches
231 // `match`
232 template <typename T>
233 static bool node_name_matches(const std::shared_ptr<const ast::Ast>& node,
234  const std::string& match) {
235  const auto& cast_node = std::dynamic_pointer_cast<const T>(node);
236  return cast_node && cast_node->get_node_name() == match;
237 };
238 
239 
240 /**
241  * Find non-range variables, i.e. ones that do not belong to per-instance allocation
242  *
243  * Certain variables like pointers, globals and parameters do not need to be per-
244  * instance variables. NEURON applies certain rules to determine which variables become
245  * thread, static or global variables. Here we construct those variables.
246  */
248  /**
249  * Top local variables are local variables appear in global scope. All local
250  * variables in program symbol table are in global scope.
251  */
252  info.constant_variables = psymtab->get_variables_with_properties(NmodlType::constant_var);
255 
256  /**
257  * All global variables remain global if mod file is not marked thread safe.
258  * Otherwise, global variables written at least once gets promoted to thread variables.
259  */
260 
261  std::string variables;
262 
263  auto vars = psymtab->get_variables_with_properties(NmodlType::global_var);
264  for (auto& var: vars) {
265  if (info.vectorize && info.declared_thread_safe && var->get_write_count() > 0) {
266  var->mark_thread_safe();
267  info.thread_variables.push_back(var);
268  info.thread_var_data_size += var->get_length();
269  variables += " " + var->get_name();
270  } else {
271  info.global_variables.push_back(var);
272  }
273  }
274 
275  /**
276  * If parameter is not a range and used only as read variable then it becomes global
277  * variable. To qualify it as thread variable it must be be written at least once and
278  * mod file must be marked as thread safe.
279  * To exclusively get parameters only, we exclude all other variables (in without)
280  * and then sort them with neuron/mod2c order.
281  */
282  // clang-format off
283  auto with = NmodlType::param_assign;
284  auto without = NmodlType::range_var
285  | NmodlType::assigned_definition
286  | NmodlType::global_var
287  | NmodlType::pointer_var
288  | NmodlType::bbcore_pointer_var
289  | NmodlType::read_ion_var
290  | NmodlType::write_ion_var;
291  // clang-format on
292  vars = psymtab->get_variables(with, without);
293  for (auto& var: vars) {
294  // some variables like area and diam are declared in parameter
295  // block but they are not global
296  if (var->get_name() == naming::DIAM_VARIABLE || var->get_name() == naming::AREA_VARIABLE ||
297  var->has_any_property(NmodlType::extern_neuron_variable)) {
298  continue;
299  }
300 
301  // if model is thread safe and if parameter is being written then
302  // those variables should be promoted to thread safe variable
303  if (info.vectorize && info.declared_thread_safe && var->get_write_count() > 0) {
304  var->mark_thread_safe();
305  info.thread_variables.push_back(var);
306  info.thread_var_data_size += var->get_length();
307  } else {
308  info.global_variables.push_back(var);
309  }
310  }
312 
313  /**
314  * \todo Below we calculate thread related id and sizes. This will
315  * need to do from global analysis pass as here we are handling
316  * top local variables, global variables, derivimplicit method.
317  * There might be more use cases with other solver methods.
318  */
319 
320  /**
321  * If derivimplicit is used, then first three thread ids get assigned to:
322  * 1st thread is used for: deriv_advance
323  * 2nd thread is used for: dith
324  * 3rd thread is used for: newtonspace
325  *
326  * slist and dlist represent the offsets for prime variables used. For
327  * euler or derivimplicit methods its always first number.
328  */
329  if (info.derivimplicit_used()) {
334  }
335 
336  /// next thread id is allocated for top local variables
337  if (info.vectorize && !info.top_local_variables.empty()) {
340  }
341 
342  /// next thread id is allocated for thread promoted variables
343  if (info.vectorize && !info.thread_variables.empty()) {
346  }
347 
348  /// find total size of local variables in global scope
349  for (auto& var: info.top_local_variables) {
350  info.top_local_thread_size += var->get_length();
351  }
352 
353  /// find number of prime variables and total size
354  auto primes = psymtab->get_variables_with_properties(NmodlType::prime_name);
355  info.num_primes = static_cast<int>(primes.size());
356  for (auto& variable: primes) {
357  info.primes_size += variable->get_length();
358  }
359 
360  /// find pointer or bbcore pointer variables
361  auto properties = NmodlType::pointer_var | NmodlType::bbcore_pointer_var;
363 
364  /// find RANDOM variables
365  properties = NmodlType::random_var;
367 
368  // find special variables like diam, area
369  const auto& special_variables_usage = collect_nodes(node, {ast::AstNodeType::VAR_NAME});
370  const auto& special_variables_declaration =
373  // If the special variable is actually used, it should show up somewhere as a VarName.
374  // Note that it can appear in VERBATIM, in which case we generate initialization code only
375  // if it has been set as ASSIGNED or PARAMETER.
376  auto predicate_used = [](const auto& var, const auto& name) {
377  return node_name_matches<ast::VarName>(var, name);
378  };
379  auto predicate_declared = [](const auto& var, const auto& name) {
380  return node_name_matches<ast::AssignedDefinition>(var, name) ||
381  node_name_matches<ast::ParamAssign>(var, name);
382  };
383  // map between the name of the variable and whether or not it's used (by ref so changes persist)
384  std::unordered_map<std::string, bool&> special_variables = {
386  for (auto& [name, value]: special_variables) {
387  const auto& used = std::any_of(
388  special_variables_usage.begin(),
389  special_variables_usage.end(),
390  // According to C++17, we cannot pass just `name` because:
391  // "If a lambda-expression explicitly captures an entity that is not odr-usable or
392  // captures a structured binding (explicitly or implicitly), the program is ill-formed."
393  // This means we need to use an "init-capture", i.e. the form `<internal>=<external>`,
394  // where `<internal>` is the variable visible inside the lambda, and `<external>` is the
395  // variable outside of it. Note that just passing `name` is valid C++20, but there is a
396  // bug in some versions of Clang which prevents even that.
397  [&predicate_used, name = name](const auto& var) { return predicate_used(var, name); });
398  const auto& declared = std::any_of(special_variables_declaration.begin(),
399  special_variables_declaration.end(),
400  [&predicate_declared, name = name](const auto& var) {
401  return predicate_declared(var, name);
402  });
403  if (declared && !used) {
404  logger->warn(
405  "Variable {} not used anywhere (except possibly VERBATIM block), but declared in a "
406  "PARAMETER or ASSIGNED block; will generate initialization code for it anyway",
407  name);
408  }
409  value = declared || used;
410  }
411 }
412 
413 /**
414  * Find range variables i.e. ones that are belong to per instance allocation
415  *
416  * In order to be compatible with NEURON, we need to print all variables in
417  * exact order as NEURON/MOD2C implementation. This is important because memory
418  * for all variables is allocated in single 1-D array with certain offset
419  * for each variable. The order of variables determine the offset and hence
420  * they must be in same order as NEURON.
421  *
422  * Here is how order is determined into NEURON/MOD2C implementation:
423  *
424  * First, following three lists are created
425  * - variables with parameter and range property (List 1)
426  * - variables with state and range property (List 2)
427  * - variables with assigned and range property (List 3)
428  *
429  * Once created, we remove some variables due to the following criteria:
430  * - In NEURON/MOD2C implementation, we remove variables with NRNPRANGEIN
431  * or NRNPRANGEOUT type
432  * - So who has NRNPRANGEIN and NRNPRANGEOUT type? these are USEION read
433  * or write variables that are not ionic currents.
434  * - This is the reason for mod files CaDynamics_E2.mod or cal_mig.mod, ica variable
435  * is printed earlier in the list but other variables like cai, cao don't appear
436  * in same order.
437  *
438  * Finally we create 4th list:
439  * - variables with assigned property and not in the previous 3 lists
440  *
441  * We now print the variables in following order:
442  *
443  * - List 1 i.e. range + parameter variables are printed first
444  * - List 3 i.e. range + assigned variables are printed next
445  * - List 2 i.e. range + state variables are printed next
446  * - List 4 i.e. assigned and ion variables not present in the previous 3 lists
447  *
448  * NOTE:
449  * - State variables also have the property `assigned_definition` but these variables
450  * are not from ASSIGNED block.
451  * - Variable can not be range as well as state, it's redeclaration error
452  * - Variable can be parameter as well as range. Without range, parameter
453  * is considered as global variable i.e. one value for all instances.
454  * - If a variable is only defined as RANGE and not in assigned or parameter
455  * or state block then it's not printed.
456  * - Note that a variable property is different than the variable type. For example,
457  * if variable has range property, it doesn't mean the variable is declared as RANGE.
458  * Other variables like STATE and ASSIGNED block variables also get range property
459  * without being explicitly declared as RANGE in the mod file.
460  * - Also, there is difference between declaration order vs. definition order. For
461  * example, POINTER variable in NEURON block is just declaration and doesn't
462  * determine the order in which they will get printed. Below we query symbol table
463  * and order all instance variables into certain order.
464  */
466  /// comparator to decide the order based on definition
467  auto comparator = [](const SymbolType& first, const SymbolType& second) -> bool {
468  return first->get_definition_order() < second->get_definition_order();
469  };
470 
471  /// from symbols vector `vars`, remove all ion variables which are not ionic currents
472  auto remove_non_ioncur_vars = [](SymbolVectorType& vars, const CodegenInfo& info) -> void {
473  vars.erase(std::remove_if(vars.begin(),
474  vars.end(),
475  [&](SymbolType& s) {
476  return info.is_ion_variable(s->get_name()) &&
477  !info.is_ionic_current(s->get_name());
478  }),
479  vars.end());
480  };
481 
482  /// if `secondary` vector contains any symbol that exist in the `primary` then remove it
483  auto remove_var_exist = [](SymbolVectorType& primary, SymbolVectorType& secondary) -> void {
484  secondary.erase(std::remove_if(secondary.begin(),
485  secondary.end(),
486  [&primary](const SymbolType& tosearch) {
487  return std::find_if(primary.begin(),
488  primary.end(),
489  // compare by symbol name
490  [&tosearch](
491  const SymbolType& symbol) {
492  return tosearch->get_name() ==
493  symbol->get_name();
494  }) != primary.end();
495  }),
496  secondary.end());
497  };
498 
499  /**
500  * First come parameters which are range variables.
501  */
502  // clang-format off
503  auto with = NmodlType::range_var
504  | NmodlType::param_assign;
505  auto without = NmodlType::global_var
506  | NmodlType::pointer_var
507  | NmodlType::bbcore_pointer_var
508  | NmodlType::state_var;
509 
510  // clang-format on
512  remove_non_ioncur_vars(info.range_parameter_vars, info);
513  std::sort(info.range_parameter_vars.begin(), info.range_parameter_vars.end(), comparator);
514 
515  /**
516  * Second come assigned variables which are range variables.
517  */
518  // clang-format off
519  with = NmodlType::range_var
520  | NmodlType::assigned_definition;
521  without = NmodlType::global_var
522  | NmodlType::pointer_var
523  | NmodlType::bbcore_pointer_var
524  | NmodlType::state_var
525  | NmodlType::param_assign;
526 
527  // clang-format on
528  info.range_assigned_vars = psymtab->get_variables(with, without);
529  remove_non_ioncur_vars(info.range_assigned_vars, info);
530  std::sort(info.range_assigned_vars.begin(), info.range_assigned_vars.end(), comparator);
531 
532 
533  /**
534  * Third come state variables. All state variables are kind of range by default.
535  * Note that some mod files like CaDynamics_E2.mod use cai as state variable which
536  * appear in USEION read/write list. These variables are not considered in this
537  * variables because non ionic-current variables are removed and printed later.
538  */
539  // clang-format off
540  with = NmodlType::state_var;
541  without = NmodlType::global_var
542  | NmodlType::pointer_var
543  | NmodlType::bbcore_pointer_var;
544 
545  // clang-format on
546  info.state_vars = psymtab->get_variables(with, without);
547  std::sort(info.state_vars.begin(), info.state_vars.end(), comparator);
548 
549  /// range_state_vars is copy of state variables but without non ionic current variables
551  remove_non_ioncur_vars(info.range_state_vars, info);
552 
553  /// Remaining variables are assigned and ion variables which are not in the previous 3 lists
554 
555  // clang-format off
556  with = NmodlType::assigned_definition
557  | NmodlType::read_ion_var
558  | NmodlType::write_ion_var;
559  without = NmodlType::global_var
560  | NmodlType::pointer_var
561  | NmodlType::bbcore_pointer_var
562  | NmodlType::extern_neuron_variable;
563  // clang-format on
564  const auto& variables = psymtab->get_variables_with_properties(with, false);
565  for (const auto& variable: variables) {
566  if (!variable->has_any_property(without)) {
567  info.assigned_vars.push_back(variable);
568  }
569  }
570 
571  /// make sure that variables already present in previous lists
572  /// are removed to avoid any duplication
573  remove_var_exist(info.range_parameter_vars, info.assigned_vars);
574  remove_var_exist(info.range_assigned_vars, info.assigned_vars);
575  remove_var_exist(info.range_state_vars, info.assigned_vars);
576 
577  /// sort variables with their definition order
578  std::sort(info.assigned_vars.begin(), info.assigned_vars.end(), comparator);
579 }
580 
581 
583  auto property = NmodlType::table_statement_var;
585  property = NmodlType::table_assigned_var;
587 }
588 
590  // TODO: it would be nicer not to have this hardcoded list
591  using pair = std::pair<const char*, const char*>;
592  for (const auto& [var, type]: {pair{naming::CELSIUS_VARIABLE, "double"},
593  pair{"secondorder", "int"},
594  pair{"pi", "double"}}) {
595  auto sym = psymtab->lookup(var);
596  if (sym && (sym->get_read_count() || sym->get_write_count() ||
598  info.neuron_global_variables.emplace_back(std::move(sym), type);
599  }
600  }
601 }
602 
603 
605  const auto& type = node.get_type()->get_node_name();
606  if (type == naming::POINT_PROCESS) {
607  info.point_process = true;
608  }
609  if (type == naming::ARTIFICIAL_CELL) {
610  info.artificial_cell = true;
611  info.point_process = true;
612  }
613  info.mod_suffix = node.get_node_name();
614 }
615 
617  info.electrode_current = true;
618 }
619 
620 
624  } else {
626  }
627  node.visit_children(*this);
628 }
629 
630 
633  node.visit_children(*this);
634 }
635 
636 
639  node.visit_children(*this);
640 }
641 
642 
646  info.num_net_receive_parameters = static_cast<int>(node.get_parameters().size());
647  node.visit_children(*this);
648  under_net_receive_block = false;
649 }
650 
651 
653  under_derivative_block = true;
654  node.visit_children(*this);
655  under_derivative_block = false;
656 }
657 
659  info.derivimplicit_callbacks.push_back(&node);
660 }
661 
662 
664  under_breakpoint_block = true;
666  node.visit_children(*this);
667  under_breakpoint_block = false;
668 }
669 
670 
673  node.visit_children(*this);
674 }
675 
677  info.cvode_block = &node;
678  node.visit_children(*this);
679 }
680 
681 
683  info.procedures.push_back(&node);
684  node.visit_children(*this);
685  if (table_statement_used) {
686  table_statement_used = false;
687  info.functions_with_table.push_back(&node);
688  }
689 }
690 
691 
693  info.functions.push_back(&node);
694  node.visit_children(*this);
695  if (table_statement_used) {
696  table_statement_used = false;
697  info.functions_with_table.push_back(&node);
698  }
699 }
700 
701 
703  info.function_tables.push_back(&node);
704 }
705 
706 
710  // Avoid extra declaration for `functor` corresponding to the DERIVATIVE block which is not
711  // printed to the generated CPP file
712  if (!under_derivative_block) {
713  const auto new_unique_functor_name = "functor_" + info.mod_suffix + "_" +
715  info.functor_names[&node] = new_unique_functor_name;
716  }
717  node.visit_children(*this);
718 }
719 
723  node.visit_children(*this);
724 }
725 
727  info.matexp_blocks.push_back(&node);
728  node.visit_children(*this);
729 }
730 
732  auto name = node.get_node_name();
733  if (name == naming::NET_SEND_METHOD) {
734  info.net_send_used = true;
735  }
737  info.net_event_used = true;
738  }
739 }
740 
741 
743  const auto& ion = node.get_ion();
744  const auto& variable = node.get_conductance();
745  std::string ion_name;
746  if (ion) {
747  ion_name = ion->get_node_name();
748  }
749  info.conductances.push_back({ion_name, variable->get_node_name()});
750 }
751 
752 
753 /**
754  * Visit statement block and find prime symbols appear in derivative block
755  *
756  * Equation statements in a derivative block have a prime on the lhs. The order
757  * of primes could differ from the declaration order in the STATE block. Also, not all
758  * state variables need to appear in the equation block. In this case, we want
759  * to find the primes in the order of equation definition. This is
760  * just to keep the same order as the NEURON implementation.
761  *
762  * The primes are already solved and replaced by Dstate or name. And hence
763  * we need to check if the lhs variable is derived from prime name. If it's
764  * Dstate then we have to lookup state to find out corresponding symbol. This
765  * is because prime_variables_by_order should contain state variable name and
766  * not the one replaced by solver pass.
767  *
768  * \todo AST can have duplicate DERIVATIVE blocks if a mod file uses SOLVE
769  * statements in its INITIAL block (e.g. in case of kinetic schemes using
770  * `STEADYSTATE sparse` solver). Such duplicated DERIVATIVE blocks could
771  * be removed by `SolveBlockVisitor`, or we have to avoid visiting them
772  * here. See e.g. SH_na8st.mod test and original reduced_dentate .mod.
773  */
775  const auto& statements = node.get_statements();
776  for (auto& statement: statements) {
777  statement->accept(*this);
779  (assign_lhs->is_name() || assign_lhs->is_var_name())) {
780  auto name = assign_lhs->get_node_name();
781  auto symbol = psymtab->lookup(name);
782  if (symbol != nullptr) {
783  auto is_prime = symbol->has_any_property(NmodlType::prime_name);
784  auto from_state = symbol->has_any_status(Status::from_state);
785  if (is_prime || from_state) {
786  if (from_state) {
787  symbol = psymtab->lookup(name.substr(1, name.size()));
788  }
789  // See the \todo note above.
790  if (std::find_if(info.prime_variables_by_order.begin(),
792  [&](auto const& sym) {
793  return sym->get_name() == symbol->get_name();
794  }) == info.prime_variables_by_order.end()) {
795  info.prime_variables_by_order.push_back(symbol);
797  }
798  }
799  }
800  }
801  assign_lhs = nullptr;
802  }
803 }
804 
806  info.factor_definitions.push_back(&node);
807 }
808 
809 
811  if (node.get_op().eval() == "=") {
812  assign_lhs = node.get_lhs();
813  }
814  node.get_lhs()->accept(*this);
815  node.get_rhs()->accept(*this);
816 }
817 
818 
820  info.bbcore_pointer_used = true;
821 }
822 
824  info.declared_thread_safe = true;
825 }
826 
827 
829  info.watch_count++;
830 }
831 
832 
834  info.watch_statements.push_back(&node);
835  node.visit_children(*this);
836 }
837 
838 
840  info.for_netcon_used = true;
841 }
842 
843 
845  info.table_count++;
846  table_statement_used = true;
847 }
848 
849 
851  psymtab = node.get_symbol_table();
852  auto blocks = node.get_blocks();
853  for (auto& block: blocks) {
854  info.top_blocks.push_back(block.get());
855  if (block->is_verbatim()) {
856  info.top_verbatim_blocks.push_back(block.get());
857  }
858  }
859  node.visit_children(*this);
860  find_ion_variables(node); // Keep this before find_*_range_variables()
866 }
867 
868 
870  node.accept(*this);
871  return info;
872 }
873 
875  info.vectorize = true;
876 }
877 
879  info.vectorize = true;
880 }
881 
883  info.vectorize = false;
884 }
885 
887  info.changed_dt = node.get_value()->eval();
888 }
889 
890 /// visit verbatim block and find all symbols used
892  const auto& text = node.get_statement()->eval();
893  // use C parser to get all tokens
895  driver.scan_string(text);
896  const auto& tokens = driver.all_tokens();
897 
898  // check if the token exist in the symbol table
899  for (auto& token: tokens) {
900  auto symbol = psymtab->lookup(token);
901  if (symbol != nullptr) {
902  info.variables_in_verbatim.insert(token);
903  }
904  }
905 }
906 
908  info.before_after_blocks.push_back(&node);
909 }
910 
912  info.before_after_blocks.push_back(&node);
913 }
914 
915 static std::shared_ptr<ast::Compartment> find_compartment(
917  const std::string& var_name) {
918  const auto& compartment_block = node.get_compartment_statements();
919  for (const auto& stmt: compartment_block->get_statements()) {
920  auto comp = std::dynamic_pointer_cast<ast::Compartment>(stmt);
921 
922  auto species = comp->get_species();
923  auto it = std::find_if(species.begin(), species.end(), [&var_name](auto var) {
924  return var->get_node_name() == var_name;
925  });
926 
927  if (it != species.end()) {
928  return comp;
929  }
930  }
931 
932  return nullptr;
933 }
934 
937  auto longitudinal_diffusion_block = node.get_longitudinal_diffusion_statements();
938  for (auto stmt: longitudinal_diffusion_block->get_statements()) {
939  auto diffusion = std::dynamic_pointer_cast<ast::LonDiffuse>(stmt);
940  auto rate_index_name = diffusion->get_index_name();
941  auto rate_expr = diffusion->get_rate();
942  auto species = diffusion->get_species();
943 
944  auto process_compartment = [](const std::shared_ptr<ast::Compartment>& compartment)
945  -> std::pair<std::shared_ptr<ast::Name>, std::shared_ptr<ast::Expression>> {
946  std::shared_ptr<ast::Expression> volume_expr;
947  std::shared_ptr<ast::Name> volume_index_name;
948  if (!compartment) {
949  volume_index_name = nullptr;
950  volume_expr = std::make_shared<ast::Double>("1.0");
951  } else {
952  volume_index_name = compartment->get_index_name();
953  volume_expr = std::shared_ptr<ast::Expression>(compartment->get_volume()->clone());
954  }
955  return {std::move(volume_index_name), std::move(volume_expr)};
956  };
957 
958  for (auto var: species) {
959  std::string state_name = var->get_value()->get_value();
960  auto compartment = find_compartment(node, state_name);
961  auto [volume_index_name, volume_expr] = process_compartment(compartment);
962 
964  {state_name,
965  LongitudinalDiffusionInfo(volume_index_name,
966  std::shared_ptr<ast::Expression>(volume_expr),
967  rate_index_name,
968  std::shared_ptr<ast::Expression>(rate_expr->clone()))});
969  }
970  }
971 }
972 
973 } // namespace codegen
974 } // namespace nmodl
Auto generated AST classes declaration.
Represents an AFTER block in NMODL.
Definition: after_block.hpp:51
Represents BBCOREPOINTER statement in NMODL.
Represents a BEFORE block in NMODL.
Represents binary expression in the NMODL.
Represents a BREAKPOINT block in NMODL.
Represents CONDUCTANCE statement in NMODL.
Represents a CONSTRUCTOR block in the NMODL.
Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks.
Definition: cvode_block.hpp:38
Represents DERIVATIVE block in the NMODL.
Represent a callback to NEURON's derivimplicit solver.
Represents a DESTRUCTOR block in the NMODL.
Represent linear solver solution block based on Eigen.
Represent newton solver solution block based on Eigen.
Represents ELECTRODE_CURRENT variables statement in NMODL.
Represents an INITIAL block in NMODL.
Represents LINEAR block in the NMODL.
Extracts information required for LONGITUDINAL_DIFFUSION for each KINETIC block.
Represent matexp solver solution block based on Eigen.
Represents NONLINEAR block in the NMODL.
Represents the coreneuron nrn_state callback function.
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
Represents block encapsulating list of statements.
Represents SUFFIX statement in NMODL.
Definition: suffix.hpp:38
Represents TABLE statement in NMODL.
Represents THREADSAFE statement in NMODL.
Definition: thread_safe.hpp:38
Statement to indicate a change in timestep in a given block.
Definition: update_dt.hpp:38
Represents a C code block.
Definition: verbatim.hpp:38
Represent WATCH statement in NMODL.
static void sort_with_mod2c_symbol_order(std::vector< SymbolType > &symbols)
How symbols are stored in NEURON? See notes written in markdown file.
void visit_derivimplicit_callback(const ast::DerivimplicitCallback &node) override
visit node of type ast::DerivimplicitCallback
void visit_derivative_block(const ast::DerivativeBlock &node) override
visit node of type ast::DerivativeBlock
void check_cvode_codegen(const ast::Program &node)
Find whether or not we need to emit CVODE-related code for NEURON Notes: we generate CVODE-related co...
void visit_non_linear_block(const ast::NonLinearBlock &node) override
visit node of type ast::NonLinearBlock
void visit_breakpoint_block(const ast::BreakpointBlock &node) override
visit node of type ast::BreakpointBlock
void visit_thread_safe(const ast::ThreadSafe &) override
visit node of type ast::ThreadSafe
bool table_statement_used
table statement found
void visit_before_block(const ast::BeforeBlock &node) override
visit node of type ast::BeforeBlock
void visit_update_dt(const ast::UpdateDt &node) override
visit node of type ast::UpdateDt
void visit_discrete_block(const ast::DiscreteBlock &node) override
visit node of type ast::DiscreteBlock
void visit_nrn_state_block(const ast::NrnStateBlock &node) override
visit node of type ast::NrnStateBlock
void visit_function_table_block(const ast::FunctionTableBlock &node) override
visit node of type ast::FunctionTableBlock
void visit_function_call(const ast::FunctionCall &node) override
visit node of type ast::FunctionCall
bool under_derivative_block
if visiting derivative block
void visit_linear_block(const ast::LinearBlock &node) override
visit node of type ast::LinearBlock
void visit_longitudinal_diffusion_block(const ast::LongitudinalDiffusionBlock &node) override
visit node of type ast::LongitudinalDiffusionBlock
void find_non_range_variables(const ast::Program &node)
Find non-range variables i.e.
void visit_eigen_linear_solver_block(const ast::EigenLinearSolverBlock &node) override
visit node of type ast::EigenLinearSolverBlock
void visit_conductance_hint(const ast::ConductanceHint &node) override
visit node of type ast::ConductanceHint
bool under_breakpoint_block
if visiting breakpoint block
void visit_after_block(const ast::AfterBlock &node) override
visit node of type ast::AfterBlock
std::shared_ptr< ast::Expression > assign_lhs
lhs of assignment in derivative block
void visit_matexp_block(const ast::MatexpBlock &node) override
visit node of type ast::MatexpBlock
void find_ion_variables(const ast::Program &node)
Find all ions used in mod file.
void visit_table_statement(const ast::TableStatement &node) override
visit node of type ast::TableStatement
void visit_suffix(const ast::Suffix &node) override
visit node of type ast::Suffix
void visit_verbatim(const ast::Verbatim &node) override
visit verbatim block and find all symbols used
void visit_constructor_block(const ast::ConstructorBlock &node) override
visit node of type ast::ConstructorBlock
void visit_eigen_newton_solver_block(const ast::EigenNewtonSolverBlock &node) override
visit node of type ast::EigenNewtonSolverBlock
codegen::CodegenInfo analyze(const ast::Program &node)
run visitor and return information for code generation
void visit_electrode_current(const ast::ElectrodeCurrent &node) override
visit node of type ast::ElectrodeCurrent
void visit_factor_def(const ast::FactorDef &node) override
visit node of type ast::FactorDef
void visit_statement_block(const ast::StatementBlock &node) override
Visit statement block and find prime symbols appear in derivative block.
void visit_net_receive_block(const ast::NetReceiveBlock &node) override
visit node of type ast::NetReceiveBlock
std::shared_ptr< symtab::Symbol > SymbolType
void visit_watch_statement(const ast::WatchStatement &node) override
visit node of type ast::WatchStatement
std::vector< std::shared_ptr< symtab::Symbol > > SymbolVectorType
void find_range_variables()
Find range variables i.e.
void visit_watch(const ast::Watch &node) override
visit node of type ast::Watch
symtab::SymbolTable * psymtab
symbol table for the program
void visit_program(const ast::Program &node) override
visit node of type ast::Program
codegen::CodegenInfo info
holds all codegen related information
void visit_function_block(const ast::FunctionBlock &node) override
visit node of type ast::FunctionBlock
void visit_binary_expression(const ast::BinaryExpression &node) override
visit node of type ast::BinaryExpression
void visit_for_netcon(const ast::ForNetcon &node) override
visit node of type ast::ForNetcon
void visit_procedure_block(const ast::ProcedureBlock &node) override
visit node of type ast::ProcedureBlock
bool under_net_receive_block
if visiting net receive block
void visit_bbcore_pointer(const ast::BbcorePointer &node) override
visit node of type ast::BbcorePointer
void visit_initial_block(const ast::InitialBlock &node) override
visit node of type ast::InitialBlock
void visit_destructor_block(const ast::DestructorBlock &node) override
visit node of type ast::DestructorBlock
void visit_cvode_block(const ast::CvodeBlock &node) override
visit node of type ast::CvodeBlock
Information required to print LONGITUDINAL_DIFFUSION callbacks.
Class that binds all pieces together for parsing C verbatim blocks.
Definition: c11_driver.hpp:37
std::vector< std::shared_ptr< Symbol > > get_variables(syminfo::NmodlType with=syminfo::NmodlType::empty, syminfo::NmodlType without=syminfo::NmodlType::empty) const
get variables
std::vector< std::shared_ptr< Symbol > > get_variables_with_properties(syminfo::NmodlType properties, bool all=false) const
get variables with properties
std::shared_ptr< Symbol > lookup(const std::string &name) const
check if symbol with given name exist in the current table (but not in parents)
Helper visitor to gather AST information to help code generation.
Auto generated AST classes declaration.
@ VAR_NAME
type of ast::VarName
@ ASSIGNED_DEFINITION
type of ast::AssignedDefinition
@ PARAM_ASSIGN
type of ast::ParamAssign
#define assert(ex)
Definition: hocassrt.h:24
double var(InputIterator begin, InputIterator end)
Definition: ivocvect.h:108
const char * name
Definition: init.cpp:16
void move(Item *q1, Item *q2, Item *q3)
Definition: list.cpp:200
std::string to_string(EnumT e, const std::array< std::pair< EnumT, std::string_view >, N > &mapping, const std::string_view enum_name)
Converts an enum value to its corresponding string representation.
Definition: nrnreport.hpp:102
static constexpr char AREA_VARIABLE[]
similar to node_area but user can explicitly declare it as area
static constexpr char POINT_PROCESS[]
point process keyword in nmodl
static constexpr char NET_EVENT_METHOD[]
net_event function call in nmodl
static constexpr char DIAM_VARIABLE[]
inbuilt neuron variable for diameter of the compartment
static constexpr char ARTIFICIAL_CELL[]
artificial cell keyword in nmodl
static constexpr char CVODE_T_METHOD[]
cvode_t method in nmodl
static constexpr char NET_SEND_METHOD[]
net_send function call in nmodl
static constexpr char CELSIUS_VARIABLE[]
global temperature variable
static constexpr char CVODE_T_V_METHOD[]
cvode_t_v method in nmodl
static constexpr char AFTER_CVODE_METHOD[]
cvode method in nmodl
static bool check_procedure_has_cvode(const std::shared_ptr< const ast::Ast > &solve_node, const std::shared_ptr< const ast::Ast > &procedure_node)
Check whether a given SOLVE block solves a PROCEDURE with any of the CVode methods.
static std::shared_ptr< ast::Compartment > find_compartment(const ast::LongitudinalDiffusionBlock &node, const std::string &var_name)
static bool node_name_matches(const std::shared_ptr< const ast::Ast > &node, const std::string &match)
Status
state during various compiler passes
NmodlType
NMODL variable properties.
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
logger_type logger
Definition: logger.cpp:34
static List * info
static int using_cvode
Definition: nocpout.cpp:2682
static Node * node(Object *)
Definition: netcvode.cpp:291
s
Definition: multisend.cpp:521
short type
Definition: cabvars.h:10
#define text
Definition: plot.cpp:60
unsigned char diffusion
Definition: rxd.cpp:52
static uint32_t value
Definition: scoprand.cpp:25
Represent information collected from AST for code generation.
int thread_var_data_size
sum of length of thread promoted variables
std::vector< SymbolType > range_assigned_vars
range variables which are assigned variables as well
std::vector< std::pair< SymbolType, std::string > > neuron_global_variables
[Core]NEURON global variables used (e.g. celsius) and their types
bool bbcore_pointer_used
if bbcore pointer is used
std::vector< const ast::FactorDef * > factor_definitions
all factors defined in the mod file
std::vector< SymbolType > assigned_vars
remaining assigned variables
const ast::CvodeBlock * cvode_block
the CVODE block
std::vector< SymbolType > range_state_vars
state variables excluding such useion read/write variables that are not ionic currents.
bool is_ion_variable(const std::string &name) const noexcept
if either read or write variable
int num_equations
number of equations (i.e.
bool artificial_cell
if mod file is artificial cell
std::vector< const ast::FunctionTableBlock * > function_tables
all functions tables defined in the mod file
std::vector< SymbolType > pointer_variables
pointer or bbcore pointer variables
bool point_process
if mod file is point process
bool electrode_current
if electrode current specified
std::vector< ast::Node * > top_blocks
all top level global blocks
bool diam_used
if diam is used
std::unordered_map< const ast::EigenNewtonSolverBlock *, std::string > functor_names
unique functor names for all the EigenNewtonSolverBlock s
std::vector< SymbolType > global_variables
global variables
int thread_data_index
thread_data_index indicates number of threads being allocated.
bool vectorize
true if mod file is vectorizable (which should be always true for coreneuron) But there are some bloc...
bool net_event_used
if net_event function is used
const ast::BreakpointBlock * breakpoint_node
breakpoint block
bool net_send_used
if net_send function is used
std::vector< const ast::WatchStatement * > watch_statements
all watch statements
bool declared_thread_safe
A mod file can be declared to be thread safe using the keyword THREADSAFE.
bool thread_callback_register
if thread callback routines need to be registered
int num_primes
number of primes (all state variables not necessary to be prime)
const ast::DestructorBlock * destructor_node
destructor block only for point process
std::vector< SymbolType > external_variables
external variables
std::vector< const ast::ProcedureBlock * > procedures
all procedures defined in the mod file
const ast::NrnStateBlock * nrn_state_block
nrn_state block
const ast::InitialBlock * net_receive_initial_node
initial block within net receive block
std::vector< const ast::MatexpBlock * > matexp_blocks
all matexp solver blocks
std::vector< SymbolType > thread_variables
thread variables (e.g. global variables promoted to thread)
bool eigen_newton_solver_exist
true if eigen newton solver is used
std::vector< SymbolType > constant_variables
constant variables
std::vector< ast::Node * > top_verbatim_blocks
all top level verbatim blocks
int num_net_receive_parameters
number of arguments to net_receive block
std::vector< const ast::DerivimplicitCallback * > derivimplicit_callbacks
derivimplicit callbacks that need to be emitted
bool is_ionic_current(const std::string &name) const noexcept
if given variable is an ionic current
std::vector< const ast::Block * > before_after_blocks
all before after blocks
bool derivimplicit_used() const
if legacy derivimplicit solver from coreneuron to be used
int table_count
number of table statements
std::vector< SymbolType > table_statement_variables
table variables
std::vector< const ast::Block * > functions_with_table
function or procedures with table statement
std::vector< SymbolType > random_variables
RANDOM variables.
std::vector< SymbolType > top_local_variables
local variables in the global scope
const ast::NetReceiveBlock * net_receive_node
net receive block for point process
std::vector< SymbolType > range_parameter_vars
range variables which are parameter as well
std::vector< const ast::FunctionBlock * > functions
all functions defined in the mod file
int top_local_thread_id
Top local variables are those local variables that appear in global scope.
int derivimplicit_var_thread_id
thread id for derivimplicit variables
bool eigen_linear_solver_exist
true if eigen linear solver is used
int thread_var_thread_id
thread id for thread promoted variables
int primes_size
sum of length of all prime variables
std::map< std::string, LongitudinalDiffusionInfo > longitudinal_diffusion_info
for each state, the information needed to print the callbacks.
int watch_count
number of watch expressions
std::vector< Conductance > conductances
represent conductance statements used in mod file
std::vector< SymbolType > prime_variables_by_order
this is the order in which they appear in derivative block this is required while printing them in in...
bool area_used
if area is used
std::string mod_suffix
name of the suffix
std::string changed_dt
updated dt to use with steadystate solver (in initial block); an empty string means no change in dt
int derivimplicit_list_num
slist/dlist id for derivimplicit block
bool for_netcon_used
if for_netcon is used
const ast::ConstructorBlock * constructor_node
constructor block
std::vector< SymbolType > state_vars
all state variables
std::unordered_set< std::string > variables_in_verbatim
all variables/symbols used in the verbatim block
std::vector< SymbolType > table_assigned_variables
const ast::InitialBlock * initial_node
initial block
int top_local_thread_size
total length of all top local variables
Represent ions used in mod file.
bool need_style
if style semantic needed
bool is_intra_cell_conc(const std::string &text) const
Check if variable name is internal cell concentration.
std::optional< double > valence
ion valence
bool is_extra_cell_conc(const std::string &text) const
Check if variable name is external cell concentration.
std::vector< std::string > reads
ion variables that are being read
std::vector< std::string > writes
ion variables that are being written
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28
Utility functions for visitors implementation.