From bb1ed178e5a787d42b9ec5d170820e3e3b56227a Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 10 Sep 2023 16:51:01 -0500
Subject: [PATCH] RefCell parse context, parse prod and sum types

---
 hercules_ir/src/parse.rs | 138 ++++++++++++++++++++++++++-------------
 1 file changed, 94 insertions(+), 44 deletions(-)

diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs
index aae3914d..6798c9f9 100644
--- a/hercules_ir/src/parse.rs
+++ b/hercules_ir/src/parse.rs
@@ -1,6 +1,8 @@
 extern crate nom;
 
+use std::cell::RefCell;
 use std::collections::HashMap;
+use std::str::FromStr;
 
 use crate::*;
 
@@ -69,11 +71,13 @@ impl<'a> Context<'a> {
     }
 }
 
-fn parse_module<'a>(ir_text: &'a str, mut context: Context<'a>) -> nom::IResult<&'a str, Module> {
+fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a str, Module> {
+    let context = RefCell::new(context);
     let (rest, functions) =
-        nom::combinator::all_consuming(nom::multi::many0(|x| parse_function(x, &mut context)))(
+        nom::combinator::all_consuming(nom::multi::many0(|x| parse_function(x, &context)))(
             ir_text,
         )?;
+    let mut context = context.into_inner();
     let mut fixed_functions = vec![
         Function {
             name: String::from(""),
@@ -115,9 +119,9 @@ fn parse_module<'a>(ir_text: &'a str, mut context: Context<'a>) -> nom::IResult<
 
 fn parse_function<'a>(
     ir_text: &'a str,
-    context: &mut Context<'a>,
+    context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, Function> {
-    context.node_ids.clear();
+    context.borrow_mut().node_ids.clear();
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::bytes::complete::tag("fn")(ir_text)?.0;
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
@@ -135,28 +139,30 @@ fn parse_function<'a>(
             nom::character::complete::multispace0,
         )),
     )(ir_text)?;
-    context.node_ids.insert("start", NodeID::new(0));
+    context
+        .borrow_mut()
+        .node_ids
+        .insert("start", NodeID::new(0));
     for param in params.iter() {
-        context
-            .node_ids
-            .insert(param.1, NodeID::new(context.node_ids.len()));
+        let id = NodeID::new(context.borrow().node_ids.len());
+        context.borrow_mut().node_ids.insert(param.1, id);
     }
     let ir_text = nom::character::complete::char(')')(ir_text)?.0;
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0;
     let (ir_text, return_type) = parse_type_id(ir_text, context)?;
     let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context))(ir_text)?;
-    let mut fixed_nodes = vec![Node::Start; context.node_ids.len()];
+    let mut fixed_nodes = vec![Node::Start; context.borrow().node_ids.len()];
     for (name, node) in nodes {
-        fixed_nodes[context.node_ids.remove(name).unwrap().idx()] = node;
+        fixed_nodes[context.borrow_mut().node_ids.remove(name).unwrap().idx()] = node;
     }
-    for (_, id) in context.node_ids.iter() {
+    for (_, id) in context.borrow().node_ids.iter() {
         if id.idx() != 0 {
             fixed_nodes[id.idx()] = Node::Parameter { index: id.idx() }
         }
     }
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
-    context.get_function_id(function_name);
+    context.borrow_mut().get_function_id(function_name);
     Ok((
         ir_text,
         Function {
@@ -171,7 +177,7 @@ fn parse_function<'a>(
 
 fn parse_node<'a>(
     ir_text: &'a str,
-    context: &mut Context<'a>,
+    context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, (&'a str, Node)> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let (ir_text, node_name) = nom::character::complete::alphanumeric1(ir_text)?;
@@ -186,11 +192,14 @@ fn parse_node<'a>(
         "call" => parse_call(ir_text, context)?,
         _ => todo!(),
     };
-    context.get_node_id(node_name);
+    context.borrow_mut().get_node_id(node_name);
     Ok((ir_text, (node_name, node)))
 }
 
-fn parse_return<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&'a str, Node> {
+fn parse_return<'a>(
+    ir_text: &'a str,
+    context: &RefCell<Context<'a>>,
+) -> nom::IResult<&'a str, Node> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::character::complete::char('(')(ir_text)?.0;
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
@@ -201,18 +210,14 @@ fn parse_return<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult
     let (ir_text, value) = nom::character::complete::alphanumeric1(ir_text)?;
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::character::complete::char(')')(ir_text)?.0;
-    Ok((
-        ir_text,
-        Node::Return {
-            control: context.get_node_id(control),
-            value: context.get_node_id(value),
-        },
-    ))
+    let control = context.borrow_mut().get_node_id(control);
+    let value = context.borrow_mut().get_node_id(value);
+    Ok((ir_text, Node::Return { control, value }))
 }
 
 fn parse_constant_node<'a>(
     ir_text: &'a str,
-    context: &mut Context<'a>,
+    context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, Node> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::character::complete::char('(')(ir_text)?.0;
@@ -227,7 +232,7 @@ fn parse_constant_node<'a>(
     Ok((ir_text, Node::Constant { id }))
 }
 
-fn parse_add<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&'a str, Node> {
+fn parse_add<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::character::complete::char('(')(ir_text)?.0;
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
@@ -242,17 +247,20 @@ fn parse_add<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&'
     let (ir_text, right) = nom::character::complete::alphanumeric1(ir_text)?;
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::character::complete::char(')')(ir_text)?.0;
+    let control = context.borrow_mut().get_node_id(control);
+    let left = context.borrow_mut().get_node_id(left);
+    let right = context.borrow_mut().get_node_id(right);
     Ok((
         ir_text,
         Node::Add {
-            control: context.get_node_id(control),
-            left: context.get_node_id(left),
-            right: context.get_node_id(right),
+            control,
+            left,
+            right,
         },
     ))
 }
 
-fn parse_call<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&'a str, Node> {
+fn parse_call<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::character::complete::char('(')(ir_text)?.0;
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
@@ -271,28 +279,33 @@ fn parse_call<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&
     let function = function_and_args.remove(0);
     let args: Vec<NodeID> = function_and_args
         .into_iter()
-        .map(|x| context.get_node_id(x))
+        .map(|x| context.borrow_mut().get_node_id(x))
         .collect();
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let ir_text = nom::character::complete::char(')')(ir_text)?.0;
+    let control = context.borrow_mut().get_node_id(control);
+    let function = context.borrow_mut().get_function_id(function);
     Ok((
         ir_text,
         Node::Call {
-            control: context.get_node_id(control),
-            function: context.get_function_id(function),
+            control,
+            function,
             args: args.into_boxed_slice(),
         },
     ))
 }
 
-fn parse_type_id<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&'a str, TypeID> {
+fn parse_type_id<'a>(
+    ir_text: &'a str,
+    context: &RefCell<Context<'a>>,
+) -> nom::IResult<&'a str, TypeID> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let (ir_text, ty) = parse_type(ir_text, context)?;
-    let id = context.get_type_id(ty);
+    let id = context.borrow_mut().get_type_id(ty);
     Ok((ir_text, id))
 }
 
-fn parse_type<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&'a str, Type> {
+fn parse_type<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Type> {
     let ir_text = nom::character::complete::multispace0(ir_text)?.0;
     let (ir_text, ty) = nom::branch::alt((
         nom::combinator::map(
@@ -322,16 +335,56 @@ fn parse_type<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&
         }),
         nom::combinator::map(nom::bytes::complete::tag("f32"), |_| Type::Float32),
         nom::combinator::map(nom::bytes::complete::tag("f64"), |_| Type::Float64),
+        nom::combinator::map(
+            nom::sequence::tuple((
+                nom::bytes::complete::tag("prod"),
+                nom::character::complete::multispace0,
+                nom::character::complete::char('('),
+                nom::character::complete::multispace0,
+                nom::multi::separated_list1(
+                    nom::sequence::tuple((
+                        nom::character::complete::multispace0,
+                        nom::character::complete::char(','),
+                        nom::character::complete::multispace0,
+                    )),
+                    |x| parse_type_id(x, context),
+                ),
+                nom::character::complete::multispace0,
+                nom::character::complete::char(')'),
+            )),
+            |(_, _, _, _, ids, _, _)| Type::Product(ids.into_boxed_slice()),
+        ),
+        nom::combinator::map(
+            nom::sequence::tuple((
+                nom::bytes::complete::tag("sum"),
+                nom::character::complete::multispace0,
+                nom::character::complete::char('('),
+                nom::character::complete::multispace0,
+                nom::multi::separated_list1(
+                    nom::sequence::tuple((
+                        nom::character::complete::multispace0,
+                        nom::character::complete::char(','),
+                        nom::character::complete::multispace0,
+                    )),
+                    |x| parse_type_id(x, context),
+                ),
+                nom::character::complete::multispace0,
+                nom::character::complete::char(')'),
+            )),
+            |(_, _, _, _, ids, _, _)| Type::Summation(ids.into_boxed_slice()),
+        ),
     ))(ir_text)?;
     Ok((ir_text, ty))
 }
 
 fn parse_dynamic_constant_id<'a>(
     ir_text: &'a str,
-    context: &mut Context<'a>,
+    context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, DynamicConstantID> {
     let (ir_text, dynamic_constant) = parse_dynamic_constant(ir_text)?;
-    let id = context.get_dynamic_constant_id(dynamic_constant);
+    let id = context
+        .borrow_mut()
+        .get_dynamic_constant_id(dynamic_constant);
     Ok((ir_text, id))
 }
 
@@ -355,17 +408,17 @@ fn parse_dynamic_constant<'a>(ir_text: &'a str) -> nom::IResult<&'a str, Dynamic
 fn parse_constant_id<'a>(
     ir_text: &'a str,
     ty: Type,
-    context: &mut Context<'a>,
+    context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, ConstantID> {
     let (ir_text, constant) = parse_constant(ir_text, ty, context)?;
-    let id = context.get_constant_id(constant);
+    let id = context.borrow_mut().get_constant_id(constant);
     Ok((ir_text, id))
 }
 
 fn parse_constant<'a>(
     ir_text: &'a str,
     ty: Type,
-    context: &mut Context<'a>,
+    context: &RefCell<Context<'a>>,
 ) -> nom::IResult<&'a str, Constant> {
     let (ir_text, constant) = match ty {
         Type::Integer8 => parse_integer8(ir_text)?,
@@ -380,14 +433,11 @@ fn parse_constant<'a>(
         Type::Float64 => parse_float64(ir_text)?,
         _ => todo!(),
     };
-    context.get_type_id(ty);
+    context.borrow_mut().get_type_id(ty);
     Ok((ir_text, constant))
 }
 
-fn parse_prim<'a, T: std::str::FromStr>(
-    ir_text: &'a str,
-    chars: &'static str,
-) -> nom::IResult<&'a str, T> {
+fn parse_prim<'a, T: FromStr>(ir_text: &'a str, chars: &'static str) -> nom::IResult<&'a str, T> {
     let (ir_text, x_text) = nom::bytes::complete::is_a(chars)(ir_text)?;
     let x = x_text.parse::<T>().map_err(|_| {
         nom::Err::Error(nom::error::Error {
-- 
GitLab