11 #include "config/config.h"
27 using visitor::DefUseAnalyzeVisitor;
29 using visitor::RenameVisitor;
30 using visitor::SymtabVisitor;
40 return optimize_ion_variable_copies() &&
info.ion_has_write_variable();
44 std::vector<std::string> variables;
45 for (
const auto&
param: params) {
46 variables.push_back(std::get<3>(
param));
48 return fmt::format(
"{}", fmt::join(variables,
", "));
53 std::vector<std::string> variables;
54 for (
const auto&
param: params) {
55 variables.push_back(fmt::format(
"{}{} {}{}",
61 return fmt::format(
"{}", fmt::join(variables,
", "));
67 auto parameters =
node->get_parameters();
68 return std::any_of(parameters.begin(),
70 [&
name](
const decltype(*parameters.begin()) arg) {
71 return arg->get_node_name() == name;
77 return "update_table_" + method_name(block_name);
89 if (
node.is_unit_state()
90 ||
node.is_line_comment()
91 ||
node.is_block_comment()
92 ||
node.is_solve_block()
93 ||
node.is_conductance_hint()
94 ||
node.is_table_statement()) {
98 if (
node.is_expression_statement()) {
100 if (expression->is_solve_block()) {
103 if (expression->is_initial_block()) {
112 if (net_receive_required() && !
info.artificial_cell) {
113 if (
info.net_event_used ||
info.net_send_used ||
info.is_watch_used()) {
122 return info.point_process && !
info.artificial_cell &&
info.net_receive_node !=
nullptr;
127 if (
info.artificial_cell) {
130 return info.nrn_state_block !=
nullptr || breakpoint_exist();
135 return info.breakpoint_node !=
nullptr && !
info.currents.empty();
140 return info.net_receive_node !=
nullptr;
145 return info.breakpoint_node !=
nullptr;
150 return net_receive_exist();
165 const auto&
function = program_symtab->lookup(
name);
166 auto properties = NmodlType::function_block | NmodlType::procedure_block;
167 return function &&
function->has_any_property(properties);
171 auto it = std::find_if(
info.function_tables.begin(),
172 info.function_tables.end(),
173 [
name](
const auto&
node) { return node->get_node_name() == name; });
174 return it !=
info.function_tables.end();
179 for (
const auto&
var: codegen_float_variables) {
180 n_floats +=
var->get_length();
188 const auto count_semantics = [](
int sum,
const IndexSemantics& sem) {
return sum += sem.size; };
189 return std::accumulate(
info.semantics.begin(),
info.semantics.end(), 0, count_semantics);
219 if (
node.is_if_statement()
220 ||
node.is_else_if_statement()
221 ||
node.is_else_statement()
222 ||
node.is_from_statement()
223 ||
node.is_verbatim()
224 ||
node.is_conductance_hint()
225 ||
node.is_while_statement()
226 ||
node.is_protect_statement()
227 ||
node.is_mutex_lock()
228 ||
node.is_mutex_unlock()) {
231 if (
node.is_expression_statement()) {
233 if (expression->is_statement_block()
234 || expression->is_eigen_newton_solver_block()
235 || expression->is_eigen_linear_solver_block()
236 || expression->is_solution_expression()
237 || expression->is_for_netcon()) {
253 if (optimize_ion_variable_copies()) {
254 return ion_read_statements_optimized(
type);
256 std::vector<std::string> statements;
257 for (
const auto& ion:
info.ions) {
258 auto name = ion.name;
259 for (
const auto&
var: ion.reads) {
260 auto const iter =
std::find(ion.implicit_reads.begin(), ion.implicit_reads.end(),
var);
261 if (iter != ion.implicit_reads.end()) {
264 auto variable_names = read_ion_variable_name(
var);
265 auto first = get_variable_name(variable_names.first);
266 auto second = get_variable_name(variable_names.second);
267 statements.push_back(fmt::format(
"{} = {};", first, second));
269 for (
const auto&
var: ion.writes) {
270 if (ion.is_ionic_conc(
var)) {
271 auto variables = read_ion_variable_name(
var);
272 auto first = get_variable_name(variables.first);
273 auto second = get_variable_name(variables.second);
274 statements.push_back(fmt::format(
"{} = {};", first, second));
283 std::vector<std::string> statements;
284 for (
const auto& ion:
info.ions) {
285 for (
const auto&
var: ion.writes) {
286 if (ion.is_ionic_conc(
var)) {
287 auto variables = read_ion_variable_name(
var);
288 auto first =
"ionvar." + variables.first;
289 const auto& second = get_variable_name(variables.second);
290 statements.push_back(fmt::format(
"{} = {};", first, second));
299 std::vector<ShadowUseStatement> statements;
300 for (
const auto& ion:
info.ions) {
301 std::string concentration;
302 for (
const auto&
var: ion.writes) {
303 auto variable_names = write_ion_variable_name(
var);
304 if (ion.is_ionic_current(
var)) {
307 auto lhs = variable_names.first;
310 if (
info.point_process) {
312 rhs += fmt::format(
"*(1.e2/{})",
area);
317 if (!ion.is_rev_potential(
var)) {
320 auto lhs = variable_names.first;
322 auto rhs = get_variable_name(variable_names.second);
328 append_conc_write_statements(statements, ion, concentration);
343 if (statement.
op.empty() && statement.
rhs.empty()) {
344 auto text = statement.
lhs +
";";
349 auto lhs = get_variable_name(statement.
lhs);
350 auto text = fmt::format(
"{} {} {};", lhs, statement.
op, statement.
rhs);
364 auto breakpoint =
info.breakpoint_node;
365 if (breakpoint ==
nullptr) {
368 auto symtab = breakpoint->get_statement_block()->get_symbol_table();
369 auto variables = symtab->get_variables_with_properties(NmodlType::local_var);
370 for (
const auto&
var: variables) {
371 auto renamed_name =
var->get_name();
372 auto original_name =
var->get_original_name();
373 if (
current == original_name) {
402 std::vector<std::shared_ptr<const ast::Ast>> nodes;
410 printer->add_line(
"#pragma omp simd");
411 printer->add_line(
"#pragma ivdep");
422 if (ion_variable_struct_required()) {
423 if (
info.is_ion_read_variable(
name)) {
426 if (
info.is_ion_write_variable(
name)) {
438 const std::string&
name) {
444 const std::string&
name) {
462 printer->add_line(
"/*********************************************************");
463 printer->add_line(
"Model Name : ",
info.mod_suffix);
464 printer->add_line(
"Filename : ",
info.mod_file,
".mod");
465 printer->add_line(
"NMODL Version : ", nmodl_version());
466 printer->fmt_line(
"Vectorized : {}",
info.vectorize);
467 printer->fmt_line(
"Threadsafe : {}",
info.thread_safe);
468 printer->add_line(
"Simulator : ", simulator_name());
469 printer->add_line(
"Backend : ", backend_name());
470 printer->add_line(
"NMODL Compiler : ", version);
471 printer->add_line(
"*********************************************************/");
476 for (
const auto& f:
info.function_tables) {
477 printer->fmt_line(
"void* _ptable_{}{{}};", f->get_node_name());
478 codegen_global_variables.push_back(make_symbol(
"_ptable_" + f->get_node_name()));
486 printer->fmt_line(
"static_assert(std::is_trivially_copy_constructible_v<{}>);",
488 printer->fmt_line(
"static_assert(std::is_trivially_move_constructible_v<{}>);",
490 printer->fmt_line(
"static_assert(std::is_trivially_copy_assignable_v<{}>);", global_struct());
491 printer->fmt_line(
"static_assert(std::is_trivially_move_assignable_v<{}>);", global_struct());
492 printer->fmt_line(
"static_assert(std::is_trivially_destructible_v<{}>);", global_struct());
497 printer->fmt_line(
"static {} {};", global_struct(), global_struct_instance());
502 const auto&
name =
node.get_node_name();
506 auto get_renamed_random_function =
507 [&](
const std::string&
name) -> std::pair<std::string, bool> {
511 return {
name,
false};
513 auto [function_name, is_random_function] = get_renamed_random_function(
name);
515 if (defined_method(
name)) {
516 function_name = method_name(
name);
520 print_nrn_pointing(
node);
524 if (is_net_send(
name)) {
525 print_net_send_call(
node);
529 if (is_net_move(
name)) {
530 print_net_move_call(
node);
534 if (is_net_event(
name)) {
535 print_net_event_call(
node);
539 if (is_function_table_call(
name)) {
540 print_function_table_call(
node);
544 const auto& arguments =
node.get_arguments();
545 printer->add_text(function_name,
'(');
547 if (defined_method(
name)) {
548 auto internal_args = internal_method_arguments();
549 printer->add_text(internal_args);
550 if (!arguments.empty() && !internal_args.empty()) {
551 printer->add_text(
", ");
555 print_vector_elements(arguments,
", ");
556 printer->add_text(
')');
560 printer->add_text(
"nrn_pointing(&");
561 print_vector_elements(
node.get_arguments(),
", ");
562 printer->add_text(
")");
566 print_function_procedure_helper(
node);
574 std::string return_var;
575 if (
info.function_uses_table(
name)) {
576 return_var =
"ret_f_" +
name;
578 return_var =
"ret_" +
name;
582 auto block =
node.get_statement_block().get();
586 print_function_procedure_helper(
node);
592 const auto&
p =
node.get_parameters();
593 auto [params, table_params] = function_table_parameters(
node);
594 printer->fmt_push_block(
"double {}({})", method_name(
name), get_parameter_str(params));
595 printer->fmt_line(
"double _arg[{}];",
p.size());
596 for (
size_t i = 0;
i <
p.size(); ++
i) {
597 printer->fmt_line(
"_arg[{}] = {};",
i,
p[
i]->get_node_name());
599 printer->fmt_line(
"return hoc_func_table({}, {}, _arg);",
600 get_variable_name(std::string(
"_ptable_" +
name),
true),
602 printer->pop_block();
604 printer->fmt_push_block(
"double table_{}({})",
606 get_parameter_str(table_params));
607 printer->fmt_line(
"hoc_spec_table(&{}, {});",
608 get_variable_name(std::string(
"_ptable_" +
name)),
610 printer->add_line(
"return 0.;");
611 printer->pop_block();
616 printer->add_line(
"#ifndef NRN_PRCELLSTATE");
617 printer->add_line(
"#define NRN_PRCELLSTATE 0");
618 printer->add_line(
"#endif");
623 auto variable_printer = [&](
const std::vector<SymbolType>& variables) {
624 for (
const auto&
v: variables) {
625 auto name =
v->get_name();
626 if (!
info.point_process) {
630 name += fmt::format(
"[{}]",
v->get_length());
632 printer->add_line(add_escape_quote(
name),
",");
636 printer->add_newline(2);
637 printer->add_line(
"/** channel information */");
638 printer->fmt_line(
"static const char *{}[] = {{", get_channel_info_var_name());
639 printer->increase_indent();
640 printer->add_line(add_escape_quote(nmodl_version()),
",");
641 printer->add_line(add_escape_quote(
info.mod_suffix),
",");
642 variable_printer(
info.range_parameter_vars);
643 printer->add_line(
"0,");
644 variable_printer(
info.range_assigned_vars);
645 printer->add_line(
"0,");
646 variable_printer(
info.range_state_vars);
647 printer->add_line(
"0,");
648 variable_printer(
info.pointer_variables);
649 printer->add_line(
"0");
650 printer->decrease_indent();
651 printer->add_line(
"};");
655 printer->fmt_line(
"using namespace {};", namespace_name());
659 printer->add_newline(2);
660 printer->fmt_push_block(
"namespace {}", namespace_name());
665 printer->pop_block();
670 if (
info.top_verbatim_blocks.empty()) {
673 print_namespace_stop();
675 printer->add_newline(2);
676 print_using_namespace();
678 printing_top_verbatim_blocks =
true;
680 for (
const auto& block:
info.top_verbatim_blocks) {
681 printer->add_newline(2);
682 block->accept(*
this);
685 printing_top_verbatim_blocks =
false;
687 print_namespace_start();
700 printer->push_block();
703 const auto& statements =
node.get_statements();
704 for (
const auto& statement: statements) {
705 if (statement_to_skip(*statement)) {
709 if (!statement->is_verbatim() && !statement->is_mutex_lock() &&
710 !statement->is_mutex_unlock() && !statement->is_protect_statement()) {
711 printer->add_indent();
713 statement->accept(*
this);
714 if (need_semicolon(*statement)) {
715 printer->add_text(
';');
717 if (!statement->is_mutex_lock() && !statement->is_mutex_unlock()) {
718 printer->add_newline();
723 printer->pop_block_nl(0);
741 auto model_symbol_table = std::make_shared<symtab::ModelSymbolTable>();
749 auto is_functor_const =
true;
751 for (
const auto& variable: variables) {
752 const auto& chain =
v.analyze(complete_block, variable->get_node_name());
753 is_functor_const = !(chain.eval() ==
DUState::D || chain.eval() == DUState::LD ||
754 chain.eval() == DUState::CD);
755 if (!is_functor_const) {
760 return is_functor_const;
765 for (
const auto& functor_name:
info.functor_names) {
766 printer->add_newline(2);
767 print_functor_definition(*functor_name.first);
774 auto float_type = default_float_data_type();
775 int N =
node.get_n_state_vars()->get_value();
777 const auto functor_name =
info.functor_names[&
node];
778 printer->fmt_push_block(
"struct {}", functor_name);
780 auto params = functor_params();
781 for (
const auto&
param: params) {
782 printer->fmt_line(
"{}{} {};", std::get<0>(
param), std::get<1>(
param), std::get<3>(
param));
785 if (ion_variable_struct_required()) {
786 print_ion_variable();
789 print_statement_block(*
node.get_variable_block(),
false,
false);
790 printer->add_newline();
792 printer->push_block(
"void initialize()");
793 print_statement_block(*
node.get_initialize_block(),
false,
false);
794 printer->pop_block();
795 printer->add_newline();
797 printer->fmt_line(
"{}({})", functor_name, get_parameter_str(params));
798 printer->increase_indent();
799 auto initializers = std::vector<std::string>();
800 for (
const auto&
param: params) {
801 initializers.push_back(fmt::format(
"{0}({0})", std::get<3>(
param)));
804 printer->add_multi_line(
": " + fmt::format(
"{}", fmt::join(initializers,
", ")));
805 printer->decrease_indent();
806 printer->add_line(
"{}");
808 printer->add_indent();
810 const auto& variable_block = *
node.get_variable_block();
811 const auto& functor_block = *
node.get_functor_block();
814 "void operator()(const Eigen::Matrix<{0}, {1}, 1>& nmodl_eigen_xm, Eigen::Matrix<{0}, {1}, "
815 "1>& nmodl_eigen_dxm, Eigen::Matrix<{0}, {1}, "
816 "1>& nmodl_eigen_fm, "
817 "Eigen::Matrix<{0}, {1}, {1}>& nmodl_eigen_jm) {2}",
820 is_functor_const(variable_block, functor_block) ?
"const " :
"");
821 printer->push_block();
822 printer->fmt_line(
"const {}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
823 printer->fmt_line(
"{}* nmodl_eigen_dx = nmodl_eigen_dxm.data();", float_type);
824 printer->fmt_line(
"{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type);
825 printer->fmt_line(
"{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type);
827 for (
size_t i = 0;
i < N; ++
i) {
829 "nmodl_eigen_dx[{0}] = std::max(1e-6, 0.02*std::fabs(nmodl_eigen_x[{0}]));",
i);
832 print_statement_block(functor_block,
false,
false);
833 printer->pop_block();
834 printer->add_newline();
837 printer->push_block(
"void finalize()");
838 print_statement_block(*
node.get_finalize_block(),
false,
false);
839 printer->pop_block();
841 printer->pop_block(
";");
848 printer->add_multi_line(R
"CODE(
850 nmodl_eigen_jm.computeInverseWithCheck(nmodl_eigen_jm_inv,invertible);
851 nmodl_eigen_xm = nmodl_eigen_jm_inv*nmodl_eigen_fm;
852 if (!invertible) assert(false && "Singular or ill-conditioned matrix (Eigen::inverse)!");
859 printer->add_line(
"if (!nmodl_eigen_jm.IsRowMajor) nmodl_eigen_jm.transposeInPlace();");
862 printer->fmt_line(
"Eigen::Matrix<int, {}, 1> pivot;", N);
863 printer->fmt_line(
"Eigen::Matrix<{0}, {1}, 1> rowmax;", float_type, N);
867 "if (nmodl::crout::Crout<{0}>({1}, nmodl_eigen_jm.data(), pivot.data(), rowmax.data()) "
868 "< 0) assert(false && \"Singular or ill-conditioned matrix (nmodl::crout)!\");",
874 "nmodl::crout::solveCrout<{0}>({1}, nmodl_eigen_jm.data(), nmodl_eigen_fm.data(), "
875 "nmodl_eigen_xm.data(), pivot.data());",
892 if (!
info.factor_definitions.empty()) {
893 printer->add_newline(2);
894 printer->add_line(
"/** constants used in nmodl from UNITS */");
895 for (
const auto& it:
info.factor_definitions) {
896 const std::string format_string =
"static const double {} = {};";
897 printer->fmt_line(format_string, it->get_node_name(), it->get_value()->get_value());
913 if (enable_variable_name_lookup) {
916 printer->add_text(
name);
942 node.visit_children(*
this);
952 throw std::runtime_error(
"PRIME encountered during code generation, ODEs not solved?");
961 const auto& at_index =
node.get_at();
965 printer->add_text(
"@");
966 at_index->accept(*
this);
969 printer->add_text(
"[");
970 printer->add_text(
"static_cast<int>(");
971 index->accept(*
this);
972 printer->add_text(
")");
973 printer->add_text(
"]");
979 node.get_name()->accept(*
this);
980 printer->add_text(
"[");
981 printer->add_text(
"static_cast<int>(");
982 node.get_length()->accept(*
this);
983 printer->add_text(
")");
984 printer->add_text(
"]");
989 printer->add_text(local_var_type(),
' ');
990 print_vector_elements(
node.get_variables(),
", ");
995 printer->add_text(
"if (");
996 node.get_condition()->accept(*
this);
997 printer->add_text(
") ");
998 node.get_statement_block()->accept(*
this);
999 print_vector_elements(
node.get_elseifs(),
"");
1000 const auto& elses =
node.get_elses();
1002 elses->accept(*
this);
1008 printer->add_text(
" else if (");
1009 node.get_condition()->accept(*
this);
1010 printer->add_text(
") ");
1011 node.get_statement_block()->accept(*
this);
1016 printer->add_text(
" else ");
1017 node.visit_children(*
this);
1022 printer->add_text(
"while (");
1023 node.get_condition()->accept(*
this);
1024 printer->add_text(
") ");
1025 node.get_statement_block()->accept(*
this);
1031 const auto& from =
node.get_from();
1032 const auto& to =
node.get_to();
1033 const auto& inc =
node.get_increment();
1034 const auto& block =
node.get_statement_block();
1035 printer->fmt_text(
"for (int {} = ",
name);
1036 from->accept(*
this);
1037 printer->fmt_text(
"; {} <= ",
name);
1040 printer->fmt_text(
"; {} += ",
name);
1043 printer->fmt_text(
"; {}++",
name);
1045 printer->add_text(
") ");
1046 block->accept(*
this);
1051 printer->add_text(
"(");
1052 node.get_expression()->accept(*
this);
1053 printer->add_text(
")");
1058 auto op =
node.get_op().eval();
1059 const auto& lhs =
node.get_lhs();
1060 const auto&
rhs =
node.get_rhs();
1062 printer->add_text(
"pow(");
1064 printer->add_text(
", ");
1066 printer->add_text(
")");
1069 printer->add_text(
" " + op +
" ");
1076 printer->add_text(
node.eval());
1081 printer->add_text(
" " +
node.eval());
1091 print_statement_block(
node);
1096 print_function_call(
node);
1106 printer->fmt_line(
"#pragma omp critical ({})",
info.mod_suffix);
1107 printer->add_indent();
1108 printer->push_block();
1113 printer->pop_block();
1118 auto block =
node.get_node_to_solve().get();
1119 if (block->is_statement_block()) {
1121 print_statement_block(*statement_block,
false,
false);
1123 block->accept(*
this);
1130 printer->add_newline();
1132 auto float_type = default_float_data_type();
1133 int N =
node.get_n_state_vars()->get_value();
1134 printer->fmt_line(
"Eigen::Matrix<{}, {}, 1> nmodl_eigen_xm;", float_type, N);
1135 printer->fmt_line(
"{}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
1137 print_statement_block(*
node.get_setup_x_block(),
false,
false);
1140 printer->add_line(
"// call newton solver");
1141 printer->fmt_line(
"{} newton_functor({});",
1143 get_arg_str(functor_params()));
1144 printer->add_line(
"newton_functor.initialize();");
1146 "int newton_iterations = nmodl::newton::newton_solver(nmodl_eigen_xm, newton_functor);");
1148 "if (newton_iterations < 0) assert(false && \"Newton solver did not converge!\");");
1151 print_statement_block(*
node.get_update_states_block(),
false,
false);
1152 printer->add_line(
"newton_functor.initialize(); // TODO mimic calling F again.");
1153 printer->add_line(
"newton_functor.finalize();");
1158 printer->add_newline();
1160 const std::string float_type = default_float_data_type();
1161 int N =
node.get_n_state_vars()->get_value();
1162 printer->fmt_line(
"Eigen::Matrix<{0}, {1}, 1> nmodl_eigen_xm, nmodl_eigen_fm;", float_type, N);
1163 printer->fmt_line(
"Eigen::Matrix<{0}, {1}, {1}> nmodl_eigen_jm;", float_type, N);
1165 printer->fmt_line(
"Eigen::Matrix<{0}, {1}, {1}> nmodl_eigen_jm_inv;", float_type, N);
1167 printer->fmt_line(
"{}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
1168 printer->fmt_line(
"{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type);
1169 printer->fmt_line(
"{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type);
1170 print_statement_block(*
node.get_variable_block(),
false,
false);
1171 print_statement_block(*
node.get_initialize_block(),
false,
false);
1172 print_statement_block(*
node.get_setup_x_block(),
false,
false);
1174 printer->add_newline();
1175 print_eigen_linear_solver(float_type, N);
1176 printer->add_newline();
1178 print_statement_block(*
node.get_update_states_block(),
false,
false);
1179 print_statement_block(*
node.get_finalize_block(),
false,
false);
1186 const std::vector<std::string>
states) {
1189 ast::AstNodeType::NAME,
1191 std::vector<int> var_indices;
1192 for (
const auto&
var: vars) {
1193 for (
int state_index = 0; state_index <
states.size(); state_index++) {
1194 if (
states[state_index] ==
var->get_node_name()) {
1195 var_indices.push_back(state_index);
1204 const auto& state_symbols =
info.state_vars;
1205 std::vector<std::string>
states;
1206 for (
const auto& sym: state_symbols) {
1207 states.push_back(sym->get_name());
1210 const std::string vector_type =
"Eigen::Matrix<" + float_type +
", " + n_states +
", 1>";
1211 const std::string matrix_type =
"Eigen::Matrix<" + float_type +
", " + n_states +
", " +
1214 if (
node.get_steadystate()->eval()) {
1215 printer->fmt_line(
"{} nmodl_dt = 24.0 * 60.0 * 60.0 * 1000.0;", float_type);
1217 printer->add_indent();
1218 printer->fmt_text(
"{} nmodl_dt = ", float_type);
1219 auto dt_var =
ast::Name(std::make_shared<ast::String>(
"dt"));
1220 dt_var.accept(*
this);
1221 printer->add_text(
";");
1222 printer->add_newline();
1225 printer->add_newline();
1226 printer->fmt_line(
"{} nmodl_eigen_jm = {}::Zero();", matrix_type, matrix_type);
1227 printer->fmt_line(
"{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type);
1229 print_statement_block(*
node.get_jacobian_block(),
false,
false);
1231 printer->fmt_line(
"{} nmodl_eigen_xm;", vector_type);
1232 printer->fmt_line(
"{}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
1234 const auto& conserve_statements =
node.get_conserve();
1235 if (
node.get_steadystate()->eval() && !conserve_statements.empty()) {
1236 for (
int i = 0;
i <
states.size();
i++) {
1237 printer->fmt_line(
"nmodl_eigen_x[{}] = 0;",
i);
1239 for (
int i = 0;
i < conserve_statements.size();
i++) {
1241 printer->add_indent();
1242 printer->fmt_text(
"const {} nmodl_conserve_steadystate_{} = {}(",
1246 conserve_statements[
i]->get_expr()->accept(*
this);
1247 printer->fmt_text(
") / {}.0;", var_indices.size());
1248 printer->add_newline();
1249 for (
const auto& state_index: var_indices) {
1250 printer->fmt_line(
"nmodl_eigen_x[{}] = nmodl_conserve_steadystate_{};",
1257 for (
int i = 0;
i <
states.size();
i++) {
1258 printer->add_indent();
1259 printer->fmt_text(
"nmodl_eigen_x[{}] = ",
i);
1261 state_var.accept(*
this);
1262 printer->add_text(
";");
1263 printer->add_newline();
1267 printer->fmt_line(
"{} nmodl_eigen_ym = nmodl_eigen_jm.exp() * nmodl_eigen_xm;", vector_type);
1268 printer->fmt_line(
"{}* nmodl_eigen_y = nmodl_eigen_ym.data();", float_type);
1270 for (
int i = 0;
i < conserve_statements.size();
i++) {
1273 printer->add_indent();
1274 printer->fmt_text(
"const {} nmodl_conserve_{} = (", float_type,
i);
1275 conserve_statements[
i]->get_expr()->accept(*
this);
1276 printer->add_text(
") / (");
1277 for (
int var_index = 0; var_index < var_indices.size(); var_index++) {
1278 const auto& state_index = var_indices[var_index];
1279 printer->fmt_text(
"nmodl_eigen_y[{}]", state_index);
1280 if (var_index < var_indices.size() - 1) {
1281 printer->add_text(
" + ");
1284 printer->add_text(
");");
1285 printer->add_newline();
1287 for (
const auto& state_index: var_indices) {
1288 printer->fmt_line(
"nmodl_eigen_ym[{}] *= nmodl_conserve_{};", state_index,
i);
1292 for (
int i = 0;
i <
states.size();
i++) {
1293 printer->add_indent();
1295 state_var.accept(*
this);
1296 printer->fmt_text(
" = nmodl_eigen_y[{}];",
i);
1297 printer->add_newline();
1308 info.semantics.clear();
1310 if (
info.point_process) {
1314 for (
const auto& ion:
info.ions) {
1315 for (
auto i = 0;
i < ion.reads.size(); ++
i) {
1316 info.semantics.emplace_back(
index++, ion.name +
"_ion", 1);
1318 for (
const auto&
var: ion.writes) {
1320 if (
std::find(ion.reads.begin(), ion.reads.end(),
var) == ion.reads.end()) {
1321 info.semantics.emplace_back(
index++, ion.name +
"_ion", 1);
1323 if (ion.is_ionic_current(
var)) {
1324 info.semantics.emplace_back(
index++, ion.name +
"_ion", 1);
1327 if (ion.need_style) {
1328 info.semantics.emplace_back(
index++, fmt::format(
"{}_ion", ion.name), 1);
1329 info.semantics.emplace_back(
index++, fmt::format(
"#{}_ion", ion.name), 1);
1332 for (
auto&
var:
info.pointer_variables) {
1333 if (
info.first_pointer_var_index == -1) {
1336 int size =
var->get_length();
1337 if (
var->has_any_property(NmodlType::pointer_var)) {
1345 for (
auto&
var:
info.random_variables) {
1346 if (
info.first_random_var_index == -1) {
1349 int size =
var->get_length();
1354 if (
info.diam_used) {
1358 if (
info.area_used) {
1362 if (
info.net_send_used) {
1370 if (!
info.watch_statements.empty()) {
1371 for (
int i = 0;
i <
info.watch_statements.size() + 1;
i++) {
1376 if (
info.for_netcon_used) {
1385 return first->get_definition_order() < second->get_definition_order();
1388 auto assigned =
info.assigned_vars;
1392 for (
const auto& state:
states) {
1393 auto name =
"D" + state->get_name();
1394 auto symbol = make_symbol(
name);
1395 if (state->is_array()) {
1396 symbol->set_as_array(state->get_length());
1398 symbol->set_definition_order(state->get_definition_order());
1399 assigned.push_back(symbol);
1401 std::sort(assigned.begin(), assigned.end(), comparator);
1403 auto variables =
info.range_parameter_vars;
1404 variables.insert(variables.end(),
1405 info.range_assigned_vars.begin(),
1406 info.range_assigned_vars.end());
1407 variables.insert(variables.end(),
info.range_state_vars.begin(),
info.range_state_vars.end());
1409 for (
const auto&
v: assigned) {
1410 auto it = std::find_if(
info.external_variables.begin(),
1411 info.external_variables.end(),
1412 [&
v](
auto it) { return it->get_name() == get_name(v); });
1414 if (it ==
info.external_variables.end()) {
1415 variables.push_back(
v);
1419 if (needs_v_unused()) {
1423 if (breakpoint_exist()) {
1428 if (
auto r = std::find_if(variables.cbegin(),
1430 [&](
const auto&
s) { return name == s->get_name(); });
1431 r == variables.cend()) {
1432 variables.push_back(make_symbol(
name));
1436 if (net_receive_exist()) {
1459 std::vector<IndexVariableInfo> variables;
1460 if (
info.point_process) {
1462 variables.back().is_constant =
true;
1464 add_variable_point_process(variables);
1467 for (
auto& ion:
info.ions) {
1468 bool need_style =
false;
1469 std::unordered_map<std::string, int> ion_vars;
1473 auto const has_var = [&ion](
const char*
suffix) ->
bool {
1474 auto const pred = [
name = ion.name +
suffix](
auto const& x) {
return x ==
name; };
1475 return std::any_of(ion.reads.begin(), ion.reads.end(), pred) ||
1476 std::any_of(ion.writes.begin(), ion.writes.end(), pred);
1478 auto const add_implicit_read = [&ion](
const char*
suffix) {
1480 ion.reads.push_back(
name);
1483 bool const have_ionin{has_var(
"i")}, have_ionout{has_var(
"o")};
1484 if (have_ionin && !have_ionout) {
1485 add_implicit_read(
"o");
1486 }
else if (have_ionout && !have_ionin) {
1487 add_implicit_read(
"i");
1489 for (
const auto&
var: ion.reads) {
1491 variables.emplace_back(make_symbol(
name));
1492 variables.back().is_constant =
true;
1493 ion_vars[
name] =
static_cast<int>(variables.size() - 1);
1497 std::shared_ptr<symtab::Symbol> ion_di_dv_var =
nullptr;
1499 for (
const auto&
var: ion.writes) {
1502 const auto ion_vars_it = ion_vars.find(
name);
1503 if (ion_vars_it != ion_vars.end()) {
1504 variables[ion_vars_it->second].is_constant =
false;
1508 if (ion.is_ionic_current(
var)) {
1512 if (ion.is_intra_cell_conc(
var) || ion.is_extra_cell_conc(
var)) {
1518 if (ion_di_dv_var !=
nullptr) {
1519 variables.emplace_back(ion_di_dv_var);
1524 variables.emplace_back(make_symbol(
"style_" + ion.name),
false,
true);
1525 variables.back().is_constant =
true;
1529 for (
const auto&
var:
info.pointer_variables) {
1531 if (
var->has_any_property(NmodlType::pointer_var)) {
1532 variables.emplace_back(make_symbol(
name));
1534 variables.emplace_back(make_symbol(
name),
true);
1538 for (
const auto&
var:
info.random_variables) {
1540 variables.emplace_back(make_symbol(
name),
true);
1541 variables.back().symbol->add_properties(NmodlType::random_var);
1544 if (
info.diam_used) {
1548 if (
info.area_used) {
1552 add_variable_tqitem(variables);
1559 if (!
info.watch_statements.empty()) {
1560 for (
int i = 0;
i <
info.watch_statements.size() + 1;
i++) {
1561 variables.emplace_back(make_symbol(fmt::format(
"watch{}",
i)),
false,
false,
true);
1565 if (
info.for_netcon_used) {
1573 program_symtab =
node.get_symbol_table();
1577 info.mod_file = mod_filename;
1579 if (
info.mod_suffix ==
"") {
1580 info.mod_suffix = std::filesystem::path(mod_filename).stem().string();
1582 info.rsuffix =
info.point_process ?
"" :
"_" +
info.mod_suffix;
1583 if (
info.mod_suffix ==
"nothing") {
1587 if (!
info.vectorize) {
1588 logger->warn(
"CodegenCppVisitor : MOD file uses non-thread safe constructs of NMODL");
1591 codegen_float_variables = get_float_variables();
1592 codegen_int_variables = get_int_variables();
1594 update_index_semantics();
1596 info.semantic_variable_count = int_variables_size();
1619 throw std::logic_error(
"compute_method_name not implemented");
1625 print_codegen_routines();
1631 auto statement = get_table_statement(
node);
1632 auto table_variables = statement->get_table_vars();
1633 auto with = statement->get_with()->eval();
1635 auto tmin_name = get_variable_name(
"tmin_" +
name);
1636 auto mfac_name = get_variable_name(
"mfac_" +
name);
1637 auto function_name = method_name(
"f_" +
name);
1639 printer->add_newline(2);
1640 print_function_declaration(
node,
name);
1641 printer->push_block();
1643 const auto& params =
node.get_parameters();
1644 printer->fmt_push_block(
"if ({} == 0)", use_table_var);
1645 if (
node.is_procedure_block()) {
1646 printer->fmt_line(
"{}({}, {});",
1648 internal_method_arguments(),
1649 params[0].
get()->get_node_name());
1650 printer->add_line(
"return 0;");
1652 printer->fmt_line(
"return {}({}, {});",
1654 internal_method_arguments(),
1655 params[0].
get()->get_node_name());
1657 printer->pop_block();
1659 printer->fmt_line(
"double xi = {} * ({} - {});",
1661 params[0].
get()->get_node_name(),
1663 printer->push_block(
"if (isnan(xi))");
1664 if (
node.is_procedure_block()) {
1665 for (
const auto&
var: table_variables) {
1666 auto var_name = get_variable_name(
var->get_node_name());
1667 auto [
is_array, array_length] = check_if_var_is_array(
var->get_node_name());
1669 for (
int j = 0;
j < array_length;
j++) {
1670 printer->fmt_line(
"{}[{}] = xi;", var_name,
j);
1673 printer->fmt_line(
"{} = xi;", var_name);
1676 printer->add_line(
"return 0;");
1678 printer->add_line(
"return xi;");
1680 printer->pop_block();
1682 printer->fmt_push_block(
"if (xi <= 0. || xi >= {}.)", with);
1683 printer->fmt_line(
"int index = (xi <= 0.) ? 0 : {};", with);
1684 if (
node.is_procedure_block()) {
1685 for (
const auto& variable: table_variables) {
1686 auto var_name = variable->get_node_name();
1687 auto instance_name = get_variable_name(var_name);
1688 auto table_name = get_variable_name(
"t_" + var_name);
1689 auto [
is_array, array_length] = check_if_var_is_array(var_name);
1691 for (
int j = 0;
j < array_length;
j++) {
1693 "{}[{}] = {}[{}][index];", instance_name,
j, table_name,
j);
1696 printer->fmt_line(
"{} = {}[index];", instance_name, table_name);
1699 printer->add_line(
"return 0;");
1701 auto table_name = get_variable_name(
"t_" +
name);
1702 printer->fmt_line(
"return {}[index];", table_name);
1704 printer->pop_block();
1706 printer->add_line(
"int i = int(xi);");
1707 printer->add_line(
"double theta = xi - double(i);");
1708 if (
node.is_procedure_block()) {
1709 for (
const auto&
var: table_variables) {
1710 auto var_name =
var->get_node_name();
1711 auto instance_name = get_variable_name(var_name);
1712 auto table_name = get_variable_name(
"t_" + var_name);
1713 auto [
is_array, array_length] = check_if_var_is_array(
var->get_node_name());
1715 for (
size_t j = 0;
j < array_length;
j++) {
1717 "{0}[{1}] = {2}[{1}][i] + theta*({2}[{1}][i+1]-{2}[{1}][i]);",
1723 printer->fmt_line(
"{0} = {1}[i] + theta*({1}[i+1]-{1}[i]);",
1728 printer->add_line(
"return 0;");
1730 auto table_name = get_variable_name(
"t_" +
name);
1731 printer->fmt_line(
"return {0}[i] + theta * ({0}[i+1] - {0}[i]);", table_name);
1734 printer->pop_block();
1739 auto statement = get_table_statement(
node);
1740 auto table_variables = statement->get_table_vars();
1741 auto depend_variables = statement->get_depend_vars();
1742 const auto& from = statement->get_from();
1743 const auto& to = statement->get_to();
1745 auto internal_params = internal_method_parameters();
1746 auto with = statement->get_with()->eval();
1748 auto tmin_name = get_variable_name(
"tmin_" +
name);
1749 auto mfac_name = get_variable_name(
"mfac_" +
name);
1750 auto float_type = default_float_data_type();
1752 printer->add_newline(2);
1753 printer->fmt_push_block(
"void {}({})",
1754 table_update_function_name(
name),
1755 get_parameter_str(internal_params));
1757 printer->fmt_push_block(
"if ({} == 0)", use_table_var);
1758 printer->add_line(
"return;");
1759 printer->pop_block();
1761 printer->add_line(
"static bool make_table = true;");
1762 for (
const auto& variable: depend_variables) {
1763 printer->fmt_line(
"static {} save_{};", float_type, variable->get_node_name());
1766 for (
const auto& variable: depend_variables) {
1767 const auto& var_name = variable->get_node_name();
1768 const auto& instance_name = get_variable_name(var_name);
1769 printer->fmt_push_block(
"if (save_{} != {})", var_name, instance_name);
1770 printer->add_line(
"make_table = true;");
1771 printer->pop_block();
1774 printer->push_block(
"if (make_table)");
1776 printer->add_line(
"make_table = false;");
1778 printer->add_indent();
1779 printer->add_text(tmin_name,
" = ");
1780 from->accept(*
this);
1781 printer->add_text(
';');
1782 printer->add_newline();
1784 printer->add_indent();
1785 printer->add_text(
"double tmax = ");
1787 printer->add_text(
';');
1788 printer->add_newline();
1791 printer->fmt_line(
"double dx = (tmax-{}) / {}.;", tmin_name, with);
1792 printer->fmt_line(
"{} = 1./dx;", mfac_name);
1794 printer->fmt_line(
"double x = {};", tmin_name);
1795 printer->fmt_push_block(
"for (std::size_t i = 0; i < {}; x += dx, i++)", with + 1);
1796 auto function = method_name(
"f_" +
name);
1797 if (
node.is_procedure_block()) {
1798 printer->fmt_line(
"{}({}, x);",
function, internal_method_arguments());
1799 for (
const auto& variable: table_variables) {
1800 auto var_name = variable->get_node_name();
1801 auto instance_name = get_variable_name(var_name);
1802 auto table_name = get_variable_name(
"t_" + var_name);
1803 auto [
is_array, array_length] = check_if_var_is_array(var_name);
1805 for (
int j = 0;
j < array_length;
j++) {
1807 "{}[{}][i] = {}[{}];", table_name,
j, instance_name,
j);
1810 printer->fmt_line(
"{}[i] = {};", table_name, instance_name);
1814 auto table_name = get_variable_name(
"t_" +
name);
1815 printer->fmt_line(
"{}[i] = {}({}, x);",
1818 internal_method_arguments());
1820 printer->pop_block();
1822 for (
const auto& variable: depend_variables) {
1823 auto var_name = variable->get_node_name();
1824 auto instance_name = get_variable_name(var_name);
1825 printer->fmt_line(
"save_{} = {};", var_name, instance_name);
1828 printer->pop_block();
1830 printer->pop_block();
1835 const std::unordered_set<CppObjectSpecifier>& specifiers) {
1837 for (
const auto& specifier: specifiers) {
1841 result += object_specifier_map[specifier];
1848 const auto& table_statements =
collect_nodes(
node, {AstNodeType::TABLE_STATEMENT});
1850 if (table_statements.size() != 1) {
1851 auto message = fmt::format(
"One table statement expected in {} found {}",
1852 node.get_node_name(),
1853 table_statements.size());
1854 throw std::runtime_error(message);
1861 auto symbol = program_symtab->lookup_in_scope(
name);
1863 throw std::runtime_error(
1864 fmt::format(
"CodegenCppVisitor:: {} not found in symbol table!",
name));
1866 if (symbol->is_array()) {
1867 return {
true, symbol->get_length()};
1875 for (
const auto& state:
info.state_vars) {
1876 auto state_name = state->get_name();
1877 if (!
info.is_ionic_conc(state_name)) {
1878 auto lhs = get_variable_name(state_name);
1879 auto rhs = get_variable_name(state_name +
"0");
1881 if (state->is_array()) {
1882 for (
int i = 0;
i < state->get_length(); ++
i) {
1883 printer->fmt_line(
"{}[{}] = {};", lhs,
i,
rhs);
1886 printer->fmt_line(
"{} = {};", lhs,
rhs);
Auto generated AST classes declaration.
Represents binary expression in the NMODL.
Operator used in ast::BinaryExpression.
Base class for all block scoped nodes.
Represents a boolean variable.
Represent CONSERVE statement in NMODL.
std::shared_ptr< Expression > get_react() const noexcept
Getter for member variable Conserve::react.
Represents a double variable.
Represent linear solver solution block based on Eigen.
Represent newton solver solution block based on Eigen.
Represents a float variable.
Represents specific element of an array variable.
Represents an integer variable.
Represent matexp solver solution block based on Eigen.
Represent MUTEXLOCK statement in NMODL.
Represent MUTEXUNLOCK statement in NMODL.
Represents a prime variable (for ODE)
Represents top level AST node for whole NMODL input.
Represent solution of a block in the AST.
Represents block encapsulating list of statements.
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
StatementVector::const_iterator insert_statement(StatementVector::const_iterator position, const std::shared_ptr< Statement > &n)
Insert member to statements.
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
Represents TABLE statement in NMODL.
Statement to indicate a change in timestep in a given block.
std::vector< SymbolType > get_float_variables() const
Determine all float variables required during code generation.
bool nrn_cur_required() const noexcept
Check if nrn_cur function is required.
void visit_prime_name(const ast::PrimeName &node) override
visit node of type ast::PrimeName
void visit_matexp_block(const ast::MatexpBlock &node) override
visit node of type ast::MatexpBlock
virtual void print_nrn_pointing(const ast::FunctionCall &node)
Print nrn_pointing.
void print_functors_definitions()
Print all Newton functor structs.
bool range_variable_setup_required() const noexcept
Check if setup_range_variable function is required.
void visit_from_statement(const ast::FromStatement &node) override
visit node of type ast::FromStatement
void print_top_verbatim_blocks()
Print top level (global scope) verbatim blocks.
void visit_unit(const ast::Unit &node) override
visit node of type ast::Unit
void visit_binary_operator(const ast::BinaryOperator &node) override
visit node of type ast::BinaryOperator
std::vector< std::string > ion_read_statements_optimized(BlockType type) const
For a given output block type, return minimal statements for all read ion variables.
bool net_receive_buffering_required() const noexcept
Check if net receive/send buffering kernels required.
bool is_function_table_call(const std::string &name) const
std::string format_float_string(const std::string &value)
Convert a given float value to its string representation.
void print_namespace_start()
Prints the start of the simulator namespace.
std::vector< ShadowUseStatement > ion_write_statements(BlockType type)
For a given output block type, return statements for writing back ion variables.
bool defined_method(const std::string &name) const
Check if given method is defined in this model.
void visit_integer(const ast::Integer &node) override
visit node of type ast::Integer
bool nrn_state_required() const noexcept
Check if nrn_state function is required.
virtual void print_parallel_iteration_hint(BlockType type, const ast::Block *block)
Print pragma annotations for channel iterations.
bool net_receive_exist() const noexcept
Check if net_receive node exists.
void visit_function_call(const ast::FunctionCall &node) override
visit node of type ast::FunctionCall
void visit_eigen_linear_solver_block(const ast::EigenLinearSolverBlock &node) override
visit node of type ast::EigenLinearSolverBlock
void visit_if_statement(const ast::IfStatement &node) override
visit node of type ast::IfStatement
static std::pair< std::string, std::string > read_ion_variable_name(const std::string &name)
Return ion variable name and corresponding ion read variable name.
void visit_solution_expression(const ast::SolutionExpression &node) override
visit node of type ast::SolutionExpression
void visit_else_statement(const ast::ElseStatement &node) override
visit node of type ast::ElseStatement
std::string get_object_specifiers(const std::unordered_set< CppObjectSpecifier > &)
std::tuple< bool, int > check_if_var_is_array(const std::string &name)
Check if the given name exists in the symbol table.
void visit_program(const ast::Program &program) override
Main and only member function to call after creating an instance of this class.
virtual void print_global_var_struct_assertions() const
Print static assertions about the global variable struct.
void print_namespace_stop()
Prints the end of the simulator namespace.
void visit_var_name(const ast::VarName &node) override
void visit_string(const ast::String &node) override
visit node of type ast::String
static std::string get_parameter_str(const ParamVector &params)
Generate the string representing the procedure parameter declaration.
void print_prcellstate_macros() const
Print declaration of macro NRN_PRCELLSTATE for debugging.
void visit_float(const ast::Float &node) override
visit node of type ast::Float
bool breakpoint_exist() const noexcept
Check if breakpoint node exists.
bool net_receive_required() const noexcept
Check if net_receive function is required.
void print_procedure(const ast::ProcedureBlock &node)
Print NMODL procedure in target backend code.
void print_using_namespace()
Prints f"using namespace {namespace_name()}".
virtual void print_global_struct_function_table_ptrs()
Print the entries for FUNCTION_TABLEs in the global struct.
static std::string get_arg_str(const ParamVector &params)
Generate the string representing the parameters in a function call.
void visit_update_dt(const ast::UpdateDt &node) override
visit node of type ast::UpdateDt
std::vector< std::string > ion_read_statements(BlockType type) const
For a given output block type, return statements for all read ion variables.
bool ion_variable_struct_required() const
Check if a structure for ion variables is required.
void visit_unary_operator(const ast::UnaryOperator &node) override
visit node of type ast::UnaryOperator
const ast::TableStatement * get_table_statement(const ast::Block &)
std::vector< std::tuple< std::string, std::string, std::string, std::string > > ParamVector
A vector of parameters represented by a 4-tuple of strings:
void print_eigen_linear_solver(const std::string &float_type, int N)
Print linear solver using Eigen.
static bool statement_to_skip(const ast::Statement &node)
Check if given statement should be skipped during code generation.
void visit_else_if_statement(const ast::ElseIfStatement &node) override
visit node of type ast::ElseIfStatement
void visit_name(const ast::Name &node) override
visit node of type ast::Name
std::shared_ptr< symtab::Symbol > SymbolType
void print_statement_block(const ast::StatementBlock &node, bool open_brace=true, bool close_brace=true)
Print any statement block in nmodl with option to (not) print braces.
static std::pair< std::string, std::string > write_ion_variable_name(const std::string &name)
Return ion variable name and corresponding ion write variable name.
void print_nmodl_constants()
Print the nmodl constants used in backend code.
void visit_statement_block(const ast::StatementBlock &node) override
void visit_mutex_lock(const ast::MutexLock &node) override
visit node of type ast::MutexLock
int float_variables_size() const
Number of float variables in the model.
std::string table_update_function_name(const std::string &block_name) const
The name of the function that updates the table value if the parameters changed.
void visit_boolean(const ast::Boolean &node) override
visit node of type ast::Boolean
void visit_paren_expression(const ast::ParenExpression &node) override
visit node of type ast::ParenExpression
void visit_double(const ast::Double &node) override
visit node of type ast::Double
void print_table_replacement_function(const ast::Block &)
Print replacement function for function or procedure using table.
std::vector< IndexVariableInfo > get_int_variables()
Determine all int variables required during code generation.
void print_table_check_function(const ast::Block &)
Print check_function() for functions or procedures using table.
std::string format_double_string(const std::string &value)
Convert a given double value to its string representation.
std::string process_shadow_update_statement(const ShadowUseStatement &statement, BlockType type)
Process shadow update statement.
int int_variables_size() const
Number of integer variables in the model.
std::string breakpoint_current(std::string current) const
Determine the variable name for the "current" used in breakpoint block taking into account intermedia...
virtual void print_global_var_struct_decl()
Instantiate global var instance.
bool is_functor_const(const ast::StatementBlock &variable_block, const ast::StatementBlock &functor_block)
Checks whether the functor_block generated by sympy solver modifies any variable outside its scope.
std::string update_if_ion_variable_name(const std::string &name) const
Determine the updated name if the ion variable has been optimized.
void visit_indexed_name(const ast::IndexedName &node) override
visit node of type ast::IndexedName
void visit_mutex_unlock(const ast::MutexUnlock &node) override
visit node of type ast::MutexUnlock
bool has_parameter_of_name(const T &node, const std::string &name)
Check if function or procedure node has parameter with given name.
void visit_local_list_statement(const ast::LocalListStatement &node) override
visit node of type ast::LocalListStatement
void visit_while_statement(const ast::WhileStatement &node) override
visit node of type ast::WhileStatement
bool net_send_buffer_required() const noexcept
Check if net_send_buffer is required.
void print_rename_state_vars() const
virtual void print_function_call(const ast::FunctionCall &node)
Print call to internal or external function.
void visit_eigen_newton_solver_block(const ast::EigenNewtonSolverBlock &node) override
visit node of type ast::EigenNewtonSolverBlock
void visit_binary_expression(const ast::BinaryExpression &node) override
visit node of type ast::BinaryExpression
std::string compute_method_name(BlockType type) const
int get_int_variable_index(const std::string &var_name)
void update_index_semantics()
populate all index semantics needed for registration with coreneuron
void print_mechanism_info()
Print backend code for byte array that has mechanism information (to be registered with NEURON/CoreNE...
virtual void setup(const ast::Program &node)
void print_function(const ast::FunctionBlock &node)
Print NMODL function in target backend code.
void print_function_tables(const ast::FunctionTableBlock &node)
Print the internal function for FUNCTION_TABLES.
void print_functor_definition(const ast::EigenNewtonSolverBlock &node)
Based on the EigenNewtonSolverBlock passed print the definition needed for its functor.
void print_backend_info()
Print top file header printed in generated code.
static bool need_semicolon(const ast::Statement &node)
Check if a semicolon is required at the end of given statement.
Helper visitor to gather AST information to help code generation.
Visitor to return Def-Use chain for a given variable in the block/node
Blindly rename given variable to new name
Concrete visitor for constructing symbol table from AST.
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
Visitor for printing C++ code compatible with legacy api of CoreNEURON
Helper visitor to gather AST information to help code generation.
Implement utility functions for codegen visitors.
Visitor to return Def-Use chain for a given variable in the block/node
@ MUTEX_UNLOCK
type of ast::MutexUnlock
@ MUTEX_LOCK
type of ast::MutexLock
@ PROTECT_STATEMENT
type of ast::ProtectStatement
@ LOCAL_VAR
type of ast::LocalVar
int get_index_from_name(const std::vector< T > &variables, const std::string &name)
BlockType
Helper to represent various block types.
@ Destructor
destructor block
@ Constructor
constructor block
@ Equation
breakpoint block
bool is_array(const Symbol &sym)
double var(InputIterator begin, InputIterator end)
void move(Item *q1, Item *q2, Item *q3)
std::string to_string(EnumT e, const std::array< std::pair< EnumT, std::string_view >, N > &mapping, const std::string_view enum_name)
Converts an enum value to its corresponding string representation.
static constexpr char POINT_PROCESS_SEMANTIC[]
semantic type for point process variable
static constexpr char AREA_VARIABLE[]
similar to node_area but user can explicitly declare it as area
static constexpr char NRN_CONSTRUCTOR_METHOD[]
nrn_constructor method in generated code
static constexpr char NET_SEND_SEMANTIC[]
semantic type for net send call
static constexpr char NRN_INIT_METHOD[]
nrn_init method in generated code
static constexpr char T_SAVE_VARIABLE[]
variable t indicating last execution time of net receive block
static constexpr char CONDUCTANCE_UNUSED_VARIABLE[]
range variable when conductance is not used (for vectorized model)
static constexpr char NRN_STATE_METHOD[]
nrn_state method in generated code
static constexpr char CORE_POINTER_SEMANTIC[]
semantic type for core pointer variable
static constexpr char USE_TABLE_VARIABLE[]
global variable to indicate if table is used
static constexpr char WATCH_SEMANTIC[]
semantic type for watch statement
static constexpr char CONDUCTANCE_VARIABLE[]
range variable for conductance
static constexpr char FOR_NETCON_SEMANTIC[]
semantic type for for_netcon statement
static constexpr char DEFAULT_FLOAT_TYPE[]
default float variable type
static constexpr char NRN_CUR_METHOD[]
nrn_cur method in generated code
static constexpr char FOR_NETCON_VARIABLE[]
name of the integer variable to store FOR_NETCON info.
static constexpr char NRN_DESTRUCTOR_METHOD[]
nrn_destructor method in generated code
static constexpr char DIAM_VARIABLE[]
inbuilt neuron variable for diameter of the compartment
static std::unordered_map< std::string, std::string > RANDOM_FUNCTIONS_MAPPING
static constexpr char RANDOM_SEMANTIC[]
semantic type for RANDOM variable
static constexpr char ION_VARNAME_PREFIX[]
prefix for ion variable
static constexpr char POINTER_SEMANTIC[]
semantic type for pointer variable
static constexpr char NRN_WATCH_CHECK_METHOD[]
nrn_watch_check method in generated c++ file
static constexpr char AREA_SEMANTIC[]
semantic type for area variable
static constexpr char VOLTAGE_UNUSED_VARIABLE[]
range variable for voltage when unused (for vectorized model)
static constexpr char NODE_AREA_VARIABLE[]
inbuilt neuron variable for area of the compartment
std::string format_float_string(const std::string &s_value)
Handles the float constants format being printed in the generated code.
std::string format_double_string(const std::string &s_value)
Handles the double constants format being printed in the generated code.
std::vector< int > get_conserve_variable_indices(const ast::Conserve &conserve, const std::vector< std::string > states)
Read the names of the state variable being conserved, and return their indices into the given list of...
const std::regex regex_special_chars
NmodlType
NMODL variable properties.
DUState
Represent a state in Def-Use chain.
encapsulates code generation backend implementations
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
bool is_nrn_pointing(const std::string &name)
Is given name nrn_pointing.
Symbol * breakpoint_current(Symbol *s)
static Node * node(Object *)
static double param[NPARAM]
Blindly rename given variable to new name
int find(const int, const int, const int, const int, const int)
Implement string manipulation functions.
static const std::string NMODL_VERSION
project tagged version in the cmake
static const std::string GIT_REVISION
git revision id
Represent semantic information for index variable.
Represents ion write statement during code generation.
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
Utility functions for visitors implementation.