Skip to content
Snippets Groups Projects
Commit 5f3f78b2 authored by Xavier Routh's avatar Xavier Routh Committed by rarbore2
Browse files

projection_node

parent 64daf66f
No related branches found
No related tags found
1 merge request!29projection_node
Showing
with 138 additions and 168 deletions
......@@ -608,14 +608,6 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hercules_cg"
version = "0.1.0"
dependencies = [
"bitvec",
"hercules_ir",
]
[[package]]
name = "hercules_driver"
version = "0.1.0"
......@@ -662,7 +654,6 @@ name = "hercules_opt"
version = "0.1.0"
dependencies = [
"bitvec",
"hercules_cg",
"hercules_ir",
"ordered-float",
"postcard",
......
[workspace]
resolver = "2"
members = [
"hercules_cg",
"hercules_ir",
"hercules_opt",
"hercules_rt",
......
......@@ -394,9 +394,7 @@ impl<'a> Builder<'a> {
Index::Position(idx)
}
pub fn create_control_index(&self, idx: usize) -> Index {
Index::Control(idx)
}
pub fn create_function(
&mut self,
......@@ -478,6 +476,10 @@ impl NodeBuilder {
};
}
pub fn build_projection(&mut self, control: NodeID, selection: usize) {
self.node = Node::Projection { control, selection };
}
pub fn build_return(&mut self, control: NodeID, data: NodeID) {
self.node = Node::Return { control, data };
}
......
......@@ -332,7 +332,7 @@ pub fn control_output_flow(
let node = &function.nodes[node_id.idx()];
// Step 2: clear all bits, if applicable.
if node.is_strictly_control() || node.is_thread_id() || node.is_reduce() || node.is_phi() {
if node.is_control() || node.is_thread_id() || node.is_reduce() || node.is_phi() {
out = UnionNodeSet::Empty;
}
......
......@@ -210,6 +210,10 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> {
NodeUses::Two([*collect, *data])
}
}
Node::Projection { control, selection: _ } => {
NodeUses::One([*control])
}
}
}
......@@ -291,5 +295,8 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> {
NodeUsesMut::Two([collect, data])
}
}
Node::Projection { control, selection } => {
NodeUsesMut::One([control])
},
}
}
......@@ -128,15 +128,12 @@ pub enum DynamicConstant {
* However, each of these types are indexed differently. Thus, these two nodes
* operate on an index list, composing indices at different levels in a type
* tree. Each type that can be indexed has a unique variant in the index enum.
* Read nodes are overloaded to select between control successors of if and
* match nodes.
*/
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Index {
Field(usize),
Variant(usize),
Position(Box<[NodeID]>),
Control(usize),
}
/*
......@@ -225,6 +222,10 @@ pub enum Node {
data: NodeID,
indices: Box<[Index]>,
},
Projection {
control: NodeID,
selection: usize,
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
......@@ -862,7 +863,6 @@ macro_rules! define_pattern_predicate {
}
impl Index {
define_pattern_predicate!(is_control, Index::Control(_));
pub fn try_field(&self) -> Option<usize> {
if let Index::Field(field) = self {
......@@ -880,20 +880,11 @@ impl Index {
}
}
pub fn try_control(&self) -> Option<usize> {
if let Index::Control(val) = self {
Some(*val)
} else {
None
}
}
pub fn lower_case_name(&self) -> &'static str {
match self {
Index::Field(_) => "field",
Index::Variant(_) => "variant",
Index::Position(_) => "position",
Index::Control(_) => "control",
}
}
}
......@@ -958,6 +949,13 @@ impl Node {
}
);
define_pattern_predicate!(is_match, Node::Match { control: _, sum: _ });
define_pattern_predicate!(
is_projection,
Node::Projection {
control: _,
selection: _
}
);
pub fn try_region(&self) -> Option<&[NodeID]> {
if let Node::Region { preds } = self {
......@@ -1017,17 +1015,6 @@ impl Node {
}
}
pub fn try_control_read(&self, branch: usize) -> Option<NodeID> {
if let Node::Read { collect, indices } = self
&& indices.len() == 1
&& indices[0] == Index::Control(branch)
{
Some(*collect)
} else {
None
}
}
pub fn is_zero_constant(&self, constants: &Vec<Constant>) -> bool {
if let Node::Constant { id } = self
&& constants[id.idx()].is_zero()
......@@ -1038,19 +1025,14 @@ impl Node {
}
}
/*
* Read nodes can be considered control when following an if or match
* node. However, it is sometimes useful to exclude such nodes when
* considering control nodes.
*/
pub fn is_strictly_control(&self) -> bool {
self.is_start()
|| self.is_region()
|| self.is_if()
|| self.is_match()
|| self.is_fork()
|| self.is_join()
|| self.is_return()
pub fn try_projection(&self, branch: usize) -> Option<NodeID> {
if let Node::Projection { control, selection } = self
&& branch == *selection
{
Some(*control)
} else {
None
}
}
pub fn upper_case_name(&self) -> &'static str {
......@@ -1110,6 +1092,10 @@ impl Node {
data: _,
indices: _,
} => "Write",
Node::Projection {
control: _,
selection: _
} => "Projection",
}
}
......@@ -1170,23 +1156,22 @@ impl Node {
data: _,
indices: _,
} => "write",
Node::Projection {
control: _,
selection: _
} => "projection",
}
}
pub fn is_control(&self) -> bool {
if self.is_strictly_control() {
return true;
}
if let Node::Read {
collect: _,
indices,
} = self
{
return indices.len() == 1 && indices[0].is_control();
}
false
self.is_start()
|| self.is_region()
|| self.is_if()
|| self.is_match()
|| self.is_fork()
|| self.is_join()
|| self.is_return()
|| self.is_projection()
}
}
......@@ -1491,6 +1476,9 @@ impl IRDisplay for Node {
third.0
)
}
Node::Projection { control, selection } => {
write!(f, "projection({}, {})", control.0, selection)
}
}
}
}
......@@ -1500,7 +1488,6 @@ impl IRDisplay for Index {
match self {
Index::Field(idx) => write!(f, "field({})", idx),
Index::Variant(idx) => write!(f, "variant({})", idx),
Index::Control(idx) => write!(f, "control({})", idx),
Index::Position(indices) => {
write!(f, "position(")?;
for (i, idx) in indices.iter().enumerate() {
......
......@@ -288,6 +288,7 @@ fn parse_node<'a>(
"return" => parse_return(ir_text, context)?,
"constant" => parse_constant_node(ir_text, context)?,
"dynamic_constant" => parse_dynamic_constant_node(ir_text, context)?,
"projection" => parse_projection(ir_text, context)?,
// Unary and binary ops are spelled out in the textual format, but we
// parse them into Unary or Binary node kinds.
"not" => parse_unary(ir_text, context, UnaryOperator::Not)?,
......@@ -619,24 +620,18 @@ fn parse_index<'a>(
)
},
),
nom::combinator::map(
nom::sequence::tuple((
nom::character::complete::multispace0,
nom::bytes::complete::tag("control"),
nom::character::complete::multispace0,
nom::character::complete::char('('),
nom::character::complete::multispace0,
|x| parse_prim::<usize>(x, "1234567890"),
nom::character::complete::multispace0,
nom::character::complete::char(')'),
nom::character::complete::multispace0,
)),
|(_, _, _, _, _, x, _, _, _)| Index::Control(x),
),
))(ir_text)?;
Ok((ir_text, idx))
}
fn parse_projection<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>,
) -> nom::IResult<&'a str, Node> {
let parse_usize = |x| parse_prim::<usize>(x, "1234567890");
let (ir_text, (control, index)) = parse_tuple2(parse_identifier, parse_usize)(ir_text)?;
let control = context.borrow_mut().get_node_id(control);
Ok((ir_text, Node::Projection { control, selection: index }))
}
fn parse_read<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> {
let ir_text = nom::character::complete::multispace0(ir_text)?.0;
let ir_text = nom::character::complete::char('(')(ir_text)?.0;
......
......@@ -824,7 +824,6 @@ fn typeflow(
}
collect_id = *elem_ty_id;
}
(Type::Control(_), Index::Control(_)) => {}
_ => {
return Error(String::from(
"Read node has mismatched input type and indices.",
......@@ -870,7 +869,6 @@ fn typeflow(
}
collect_id = *elem_ty_id;
}
(Type::Control(_), Index::Control(_)) => {}
_ => {
return Error(String::from(
"Write node has mismatched input type and indices.",
......@@ -920,6 +918,10 @@ fn typeflow(
TypeSemilattice::Error(msg) => TypeSemilattice::Error(msg),
}
}
Node::Projection { control: _, selection: _ } => {
// Type is the type of the _if node
inputs[0].clone()
},
}
}
......
......@@ -127,7 +127,7 @@ fn verify_structure(
| Node::Constant { id: _ }
| Node::DynamicConstant { id: _ } => {}
_ => {
if function.nodes[user.idx()].is_strictly_control() {
if function.nodes[user.idx()].is_control() {
if found_control {
Err("A start node must have exactly one control user.")?;
} else {
......@@ -154,7 +154,7 @@ fn verify_structure(
data: _,
} => {}
_ => {
if function.nodes[user.idx()].is_strictly_control() {
if function.nodes[user.idx()].is_control() {
if found_control {
Err("A region node must have exactly one control user.")?;
} else {
......@@ -181,7 +181,7 @@ fn verify_structure(
match function.nodes[user.idx()] {
Node::ThreadID { control: _ } => {}
_ => {
if function.nodes[user.idx()].is_strictly_control() {
if function.nodes[user.idx()].is_control() {
if found_control {
Err("A fork node must have exactly one control user.")?;
} else {
......@@ -209,7 +209,7 @@ fn verify_structure(
reduct: _,
} => {}
_ => {
if function.nodes[user.idx()].is_strictly_control() {
if function.nodes[user.idx()].is_control() {
if found_control {
Err("A join node must have exactly one control user.")?;
} else {
......@@ -235,23 +235,21 @@ fn verify_structure(
Err(format!("If node must have 2 users, not {}.", users.len()))?;
}
if let (
Node::Read {
collect: _,
indices: indices1,
Node::Projection {
control: _,
selection: result1,
},
Node::Read {
collect: _,
indices: indices2,
Node::Projection {
control: _,
selection: result2,
},
) = (
&function.nodes[users[0].idx()],
&function.nodes[users[1].idx()],
) {
if indices1.len() != 1
|| indices2.len() != 1
|| !((indices1[0] == Index::Control(0) && indices2[0] == Index::Control(1))
|| (indices1[0] == Index::Control(1)
&& indices2[0] == Index::Control(0)))
if
!((*result1 == 0 && *result2 == 1) || (*result2 == 0 && *result1 == 1))
{
Err("If node's user Read nodes must reference different elements of output product.")?;
}
......@@ -315,21 +313,13 @@ fn verify_structure(
}
let mut users_covered = bitvec![u8, Lsb0; 0; users.len()];
for user in users {
if let Node::Read {
collect: _,
ref indices,
if let Node::Projection {
control: _,
ref selection,
} = function.nodes[user.idx()]
{
if indices.len() != 1 {
Err("Match node's user Read nodes must have a single index.")?;
}
let index = if let Index::Control(index) = indices[0] {
index
} else {
Err("Match node's user Read node must use a control index.")?
};
assert!(index < users.len(), "Read child of match node reads from bad index, but ran after typecheck succeeded.");
users_covered.set(index, true);
assert!(*selection < users.len(), "Read child of match node reads from bad index, but ran after typecheck succeeded.");
users_covered.set(*selection, true);
}
}
if users_covered.count_ones() != users.len() {
......
......@@ -10,4 +10,3 @@ take_mut = "*"
postcard = { version = "*", features = ["alloc"] }
serde = { version = "*", features = ["derive"] }
hercules_ir = { path = "../hercules_ir" }
hercules_cg = { path = "../hercules_cg" }
......@@ -732,22 +732,28 @@ fn ccp_flow_function(
}),
constant: ConstantLattice::bottom(),
},
// Read handles reachability when following an if or match.
Node::Read { collect, indices } => match &function.nodes[collect.idx()] {
Node::Read { collect, indices: _ } => {
CCPLattice {
reachability: inputs[collect.idx()].reachability.clone(),
constant: ConstantLattice::bottom(),
}
}
// Projection handles reachability when following an if or match.
Node::Projection { control, selection } => match &function.nodes[control.idx()] {
Node::If { control: _, cond } => {
let cond_constant = &inputs[cond.idx()].constant;
let if_reachability = &inputs[collect.idx()].reachability;
let if_constant = &inputs[collect.idx()].constant;
let if_reachability = &inputs[control.idx()].reachability;
let if_constant = &inputs[control.idx()].constant;
let new_reachability = if cond_constant.is_top() {
ReachabilityLattice::top()
} else if let ConstantLattice::Constant(cons) = cond_constant {
if let Constant::Boolean(val) = cons {
if *val && indices[0] == Index::Control(0) {
if *val && *selection == 0 {
// If condition is true and this is the false
// branch, then unreachable.
ReachabilityLattice::top()
} else if !val && indices[0] == Index::Control(1) {
} else if !val && *selection == 1 {
// If condition is true and this is the true branch,
// then unreachable.
ReachabilityLattice::top()
......@@ -768,14 +774,14 @@ fn ccp_flow_function(
}
Node::Match { control: _, sum } => {
let sum_constant = &inputs[sum.idx()].constant;
let if_reachability = &inputs[collect.idx()].reachability;
let if_constant = &inputs[collect.idx()].constant;
let if_reachability = &inputs[control.idx()].reachability;
let if_constant = &inputs[control.idx()].constant;
let new_reachability = if sum_constant.is_top() {
ReachabilityLattice::top()
} else if let ConstantLattice::Constant(cons) = sum_constant {
if let Constant::Summation(_, variant, _) = cons {
if Index::Control(*variant as usize) != indices[0] {
if *variant as usize != *selection {
// If match variant is not the same as this branch,
// then unreachable.
ReachabilityLattice::top()
......@@ -783,7 +789,7 @@ fn ccp_flow_function(
if_reachability.clone()
}
} else {
panic!("Attempted to interpret Read node, where corresponding match node has a non-summation constant input. Did typechecking succeed?")
panic!("Attempted to interpret projection node, where corresponding match node has a non-summation constant input. Did typechecking succeed?")
}
} else {
if_reachability.clone()
......@@ -794,10 +800,7 @@ fn ccp_flow_function(
constant: if_constant.clone(),
}
}
_ => CCPLattice {
reachability: inputs[collect.idx()].reachability.clone(),
constant: ConstantLattice::bottom(),
},
_ => panic!("Projection predecessor can only be an if or match node."),
},
// Write is uninterpreted for now.
Node::Write {
......
......@@ -42,19 +42,18 @@ fn guarded_fork(function: &Function,
// Identify fork nodes
let Node::Fork { control, factor } = node else { return None; };
// Whose predecessor is a read from an if
let Node::Read { collect : if_node, ref indices }
let Node::Projection { control : if_node, ref selection }
= function.nodes[control.idx()] else { return None; };
if indices.len() != 1 { return None; }
let Index::Control(branchIdx) = indices[0] else { return None; };
let Node::If { control : if_pred, cond } = function.nodes[if_node.idx()]
else { return None; };
// Whose condition is appropriate
let Node::Binary { left, right, op } = function.nodes[cond.idx()]
else { return None; };
let branch_idx = *selection;
// branchIdx == 1 means the true branch so we want the condition to be
// 0 < n or n > 0
if branchIdx == 1
if branch_idx == 1
&& !((op == BinaryOperator::LT && function.nodes[left.idx()].is_zero_constant(constants)
&& function.nodes[right.idx()].try_dynamic_constant() == Some(*factor))
|| (op == BinaryOperator::GT && function.nodes[right.idx()].is_zero_constant(constants)
......@@ -63,7 +62,7 @@ fn guarded_fork(function: &Function,
}
// branchIdx == 0 means the false branch so we want the condition to be
// n < 0 or 0 > n
if branchIdx == 0
if branch_idx == 0
&& !((op == BinaryOperator::LT && function.nodes[left.idx()].try_dynamic_constant() == Some(*factor)
&& function.nodes[right.idx()].is_zero_constant(constants))
|| (op == BinaryOperator::GT && function.nodes[right.idx()].try_dynamic_constant() == Some(*factor)
......@@ -97,13 +96,12 @@ fn guarded_fork(function: &Function,
return None;
};
// Other predecessor needs to be the other read from the guard's if
let Node::Read { collect : read_control, ref indices }
let Node::Projection { control : if_node2, ref selection }
= function.nodes[other_pred.idx()]
else { return None; };
if indices.len() != 1 { return None; }
let Index::Control(elseBranch) = indices[0] else { return None; };
if elseBranch == branchIdx { return None; }
if read_control != if_node { return None; }
let else_branch = *selection;
if else_branch == branch_idx { return None; }
if if_node2 != if_node { return None; }
// Finally, identify the phi nodes associated with the region and match
// them with the reduce nodes of the fork-join
......
......@@ -40,7 +40,7 @@ pub fn forkify(
}
// Check for a very particular loop indexing structure.
let if_ctrl = function.nodes[single_pred_loop.idx()].try_control_read(1)?;
let if_ctrl = function.nodes[single_pred_loop.idx()].try_projection(1)?;
let (_, if_cond) = function.nodes[if_ctrl.idx()].try_if()?;
let (idx, bound) = function.nodes[if_cond.idx()].try_binary(BinaryOperator::LT)?;
let (phi, one) = function.nodes[idx.idx()].try_binary(BinaryOperator::Add)?;
......@@ -123,13 +123,13 @@ pub fn forkify(
.next()
.unwrap();
let loop_end = function.nodes[loop_true_read.idx()]
.try_control_read(1)
.try_projection(1)
.unwrap();
let loop_false_read = *def_use
.get_users(loop_end)
.iter()
.filter_map(|id| {
if function.nodes[id.idx()].try_control_read(0).is_some() {
if function.nodes[id.idx()].try_projection(0).is_some() {
Some(id)
} else {
None
......
extern crate hercules_cg;
// extern crate hercules_cg;
extern crate hercules_ir;
extern crate postcard;
extern crate serde;
......@@ -12,7 +12,7 @@ use std::process::*;
use self::serde::Deserialize;
use self::hercules_cg::*;
// use self::hercules_cg::*;
use self::hercules_ir::*;
use crate::*;
......@@ -434,6 +434,7 @@ impl PassManager {
continue;
}
Pass::Codegen(output_file_name) => {
/*
self.make_def_uses();
self.make_reverse_postorders();
self.make_typing();
......@@ -485,8 +486,8 @@ impl PassManager {
file.write_all(&hbin_contents)
.expect("PANIC: Unable to write output file contents.");
// Codegen doesn't require clearing analysis results.
continue;
// Codegen doesn't require clearing analysis results.*/
continue;
}
}
......
......@@ -104,16 +104,14 @@ pub fn predication(
}
// Introduce condition variables into sets, as this is where
// branching occurs.
Node::Read {
collect,
ref indices,
Node::Projection {
control,
ref selection,
} => {
assert_eq!(indices.len(), 1);
let truth_value = indices[0].try_control().unwrap();
assert!(truth_value < 2);
let mut sets = condition_valuations[&collect].clone();
let condition = function.nodes[collect.idx()].try_if().unwrap().1;
if truth_value == 0 {
assert!(*selection < 2);
let mut sets = condition_valuations[&control].clone();
let condition = function.nodes[control.idx()].try_if().unwrap().1;
if *selection == 0 {
sets.0.insert(condition);
} else {
sets.1.insert(condition);
......
......@@ -6,14 +6,14 @@ fn tricky(x: i32) -> i32
val = phi(loop, one, later_val)
b = ne(one, val)
if1 = if(loop, b)
if1_false = read(if1, control(0))
if1_true = read(if1, control(1))
if1_false = projection(if1, 0)
if1_true = projection(if1, 1)
middle = region(if1_false, if1_true)
inter_val = sub(two, val)
later_val = phi(middle, inter_val, two)
idx_dec = sub(idx, one)
cond = gte(idx_dec, one)
if2 = if(middle, cond)
if2_false = read(if2, control(0))
if2_true = read(if2, control(1))
if2_false = projection(if2, 0)
if2_true = projection(if2, 1)
r = return(if2_false, later_val)
......@@ -8,6 +8,6 @@ fn simple2(x: i32) -> i32
fac_acc = mul(fac, idx_inc)
in_bounds = lt(idx_inc, x)
if = if(loop, in_bounds)
if_false = read(if, control(0))
if_true = read(if, control(1))
if_false = projection(if, 0)
if_true = projection(if, 1)
r = return(if_false, fac_acc)
\ No newline at end of file
......@@ -12,6 +12,6 @@ fn strset<1>(str: array(u8, #0), byte: u8) -> array(u8, #0)
continue = ne(read, byte)
if_cond = and(continue, in_bounds)
if = if(loop, if_cond)
if_false = read(if, control(0))
if_true = read(if, control(1))
if_false = projection(if, 0)
if_true = projection(if, 1)
r = return(if_false, str_inc)
......@@ -11,8 +11,8 @@ fn sum<1>(a: array(f32, #0)) -> f32
red_add = add(red, read)
in_bounds = lt(idx_inc, bound)
if = if(loop, in_bounds)
if_false = read(if, control(0))
if_true = read(if, control(1))
if_false = projection(if, 0)
if_true = projection(if, 1)
r = return(if_false, red_add)
fn alt_sum<1>(a: array(f32, #0)) -> f32
......@@ -28,8 +28,8 @@ fn alt_sum<1>(a: array(f32, #0)) -> f32
rem = rem(idx, two_idx)
odd = eq(rem, one_idx)
negate_if = if(loop, odd)
negate_if_false = read(negate_if, control(0))
negate_if_true = read(negate_if, control(1))
negate_if_false = projection(negate_if, 0)
negate_if_true = projection(negate_if, 1)
negate_bottom = region(negate_if_false, negate_if_true)
read = read(a, position(idx))
read_neg = neg(read)
......@@ -37,6 +37,6 @@ fn alt_sum<1>(a: array(f32, #0)) -> f32
red_add = add(red, read_phi)
in_bounds = lt(idx_inc, bound)
if = if(negate_bottom, in_bounds)
if_false = read(if, control(0))
if_true = read(if, control(1))
if_false = projection(if, 0)
if_true = projection(if, 1)
r = return(if_false, red_add)
......@@ -42,12 +42,10 @@ impl SSA {
let right_proj = right_builder.id();
// True branch
let proj_left = builder.create_control_index(1);
left_builder.build_read(if_builder.id(), vec![proj_left].into());
left_builder.build_projection(if_builder.id(), 1);
// False branch
let proj_right = builder.create_control_index(0);
right_builder.build_read(if_builder.id(), vec![proj_right].into());
right_builder.build_projection(if_builder.id(), 0);
let _ = builder.add_node(left_builder);
let _ = builder.add_node(right_builder);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment