diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index f040848ee0de807f1ed552fb66da4de2bd420b11..eaaf68fbd6ea762fcc0ec2a42e12d806d41d8ed8 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -5,8 +5,7 @@ use std::collections::HashMap; pub fn write_dot<W: std::fmt::Write>(module: &Module, w: &mut W) -> std::fmt::Result { write!(w, "digraph \"Module\" {{\n")?; for i in 0..module.functions.len() { - let function = &module.functions[i]; - write_function(i, function, &module.constants, w)?; + write_function(i, module, &module.constants, w)?; } write!(w, "}}")?; Ok(()) @@ -14,13 +13,14 @@ pub fn write_dot<W: std::fmt::Write>(module: &Module, w: &mut W) -> std::fmt::Re fn write_function<W: std::fmt::Write>( i: usize, - function: &Function, + module: &Module, constants: &Vec<Constant>, w: &mut W, ) -> std::fmt::Result { let mut visited = HashMap::default(); + let function = &module.functions[i]; for j in 0..function.nodes.len() { - visited = write_node(i, j, &function.nodes, constants, visited, w)?.1; + visited = write_node(i, j, module, constants, visited, w)?.1; } Ok(()) } @@ -28,7 +28,7 @@ fn write_function<W: std::fmt::Write>( fn write_node<W: std::fmt::Write>( i: usize, j: usize, - nodes: &Vec<Node>, + module: &Module, constants: &Vec<Constant>, mut visited: HashMap<NodeID, String>, w: &mut W, @@ -37,7 +37,7 @@ fn write_node<W: std::fmt::Write>( if visited.contains_key(&id) { Ok((visited.get(&id).unwrap().clone(), visited)) } else { - let node = &nodes[j]; + let node = &module.functions[i].nodes[j]; let name = format!("{}_{}_{}", get_string_node_kind(node), i, j); visited.insert(NodeID::new(j), name.clone()); let visited = match node { @@ -47,9 +47,9 @@ fn write_node<W: std::fmt::Write>( } Node::Return { control, value } => { let (control_name, visited) = - write_node(i, control.idx(), nodes, constants, visited, w)?; + write_node(i, control.idx(), module, constants, visited, w)?; let (value_name, visited) = - write_node(i, value.idx(), nodes, constants, visited, w)?; + write_node(i, value.idx(), module, constants, visited, w)?; write!(w, "{} [label=\"return\"];\n", name)?; write!(w, "{} -> {};\n", control_name, name)?; write!(w, "{} -> {};\n", value_name, name)?; @@ -69,16 +69,34 @@ fn write_node<W: std::fmt::Write>( right, } => { let (control_name, visited) = - write_node(i, control.idx(), nodes, constants, visited, w)?; - let (left_name, visited) = write_node(i, left.idx(), nodes, constants, visited, w)?; + write_node(i, control.idx(), module, constants, visited, w)?; + let (left_name, visited) = + write_node(i, left.idx(), module, constants, visited, w)?; let (right_name, visited) = - write_node(i, right.idx(), nodes, constants, visited, w)?; + write_node(i, right.idx(), module, constants, visited, w)?; write!(w, "{} [label=\"add\"];\n", name)?; write!(w, "{} -> {};\n", control_name, name)?; write!(w, "{} -> {};\n", left_name, name)?; write!(w, "{} -> {};\n", right_name, name)?; visited } + Node::Call { + control, + function, + args, + } => { + let (control_name, mut visited) = + write_node(i, control.idx(), module, constants, visited, w)?; + for arg in args.iter() { + let (arg_name, tmp_visited) = + write_node(i, arg.idx(), module, constants, visited, w)?; + visited = tmp_visited; + write!(w, "{} -> {};\n", arg_name, name)?; + } + write!(w, "{} [label=\"call({})\"];\n", name, function.idx())?; + write!(w, "{} -> {};\n", control_name, name)?; + visited + } _ => todo!(), }; Ok((visited.get(&id).unwrap().clone(), visited)) @@ -131,6 +149,10 @@ fn get_string_node_kind(node: &Node) -> &'static str { left: _, right: _, } => "div", - Node::Call { args: _ } => "call", + Node::Call { + control: _, + function: _, + args: _, + } => "call", } } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 21bc50be7acda12bee8b4423820d7818e505792b..172f5ad9dd74462507415caffaeb07b720d2e908 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -97,6 +97,8 @@ pub enum Node { right: NodeID, }, Call { + control: NodeID, + function: FunctionID, args: Box<[NodeID]>, }, } diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 667ced3d14d961bbfa951a271fc88cfe67eba38c..9f401ea99dfbf417d7fc630711f73ea531a9b7a4 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -63,6 +63,20 @@ fn parse_module<'a>(ir_text: &'a str, mut context: Context<'a>) -> nom::IResult< nom::combinator::all_consuming(nom::multi::many0(|x| parse_function(x, &mut context)))( ir_text, )?; + let mut fixed_functions = vec![ + Function { + name: String::from(""), + param_types: vec![], + return_type: TypeID::new(0), + nodes: vec![] + }; + context.function_ids.len() + ]; + for function in functions { + let function_name = function.name.clone(); + let function_id = context.function_ids.remove(function_name.as_str()).unwrap(); + fixed_functions[function_id.idx()] = function; + } let mut types = vec![Type::Control(0); context.interned_types.len()]; for (ty, id) in context.interned_types { types[id.idx()] = ty; @@ -74,7 +88,7 @@ fn parse_module<'a>(ir_text: &'a str, mut context: Context<'a>) -> nom::IResult< Ok(( rest, Module { - functions, + functions: fixed_functions, types, constants, }, @@ -123,6 +137,8 @@ fn parse_function<'a>( fixed_nodes[id.idx()] = Node::Parameter { index: id.idx() } } } + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + context.get_function_id(function_name); Ok(( ir_text, Function { @@ -148,6 +164,7 @@ fn parse_node<'a>( "return" => parse_return(ir_text, context)?, "constant" => parse_constant_node(ir_text, context)?, "add" => parse_add(ir_text, context)?, + "call" => parse_call(ir_text, context)?, _ => todo!(), }; context.get_node_id(node_name); @@ -216,6 +233,39 @@ fn parse_add<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&' )) } +fn parse_call<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&'a str, Node> { + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char('(')(ir_text)?.0; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let (ir_text, control) = nom::character::complete::alphanumeric1(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(',')(ir_text)?.0; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let (ir_text, mut function_and_args) = nom::multi::separated_list1( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::character::complete::char(','), + nom::character::complete::multispace0, + )), + nom::character::complete::alphanumeric1, + )(ir_text)?; + let function = function_and_args.remove(0); + let args: Vec<NodeID> = function_and_args + .into_iter() + .map(|x| context.get_node_id(x)) + .collect(); + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(')')(ir_text)?.0; + Ok(( + ir_text, + Node::Call { + control: context.get_node_id(control), + function: context.get_function_id(function), + args: args.into_boxed_slice(), + }, + )) +} + fn parse_type_id<'a>(ir_text: &'a str, context: &mut Context<'a>) -> nom::IResult<&'a str, TypeID> { let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, ty) = parse_type(ir_text)?; @@ -284,8 +334,19 @@ mod tests { #[test] fn parse_ir1() { - let module = - parse("fn add(x: i32, y: i32) -> i32 c = constant(i8, 5) r = return(start, w) w = add(start, z, c) z = add(start, x, y)"); + let module = parse( + " +fn myfunc(x: i32) -> i32 + y = call(start, add, x, x) + r = return(start, y) + +fn add(x: i32, y: i32) -> i32 + c = constant(i8, 5) + r = return(start, w) + w = add(start, z, c) + z = add(start, x, y) +", + ); println!("{:?}", module); let mut dot = String::new(); write_dot(&module, &mut dot).unwrap();