diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 74b4352bad44e32ade7abce04d081f5303425450..ef3eceab898147011f1ab932686afbbda6b4fc45 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -2,6 +2,7 @@ extern crate bitvec; extern crate hercules_ir; use std::collections::HashMap; +use std::iter::zip; use self::bitvec::prelude::*; @@ -199,7 +200,7 @@ pub fn sroa( // Assemble a mapping from old nodes IDs acting on the product constant // to new nodes IDs operating on the field constants. - let map: HashMap<NodeID, Vec<NodeID>> = to_replace + let old_to_new_id_map: HashMap<NodeID, Vec<NodeID>> = to_replace .iter() .map(|old_id| match function.nodes[old_id.idx()] { Node::Phi { @@ -240,7 +241,55 @@ pub fn sroa( }) .collect(); - // Replace the old nodes with the new nodes. + // Replace the old nodes with the new nodes. Since we've already + // allocated the node IDs, at this point we can iterate through the to- + // replace nodes in an arbitrary order. + for (old_id, new_ids) in &old_to_new_id_map { + // First, add the new nodes to the node list. + let node = function.nodes[old_id.idx()].clone(); + match node { + // Replace the original constant with constants for each field. + Node::Constant { id: _ } => { + for (new_id, field_id) in zip(new_ids.iter(), constant_fields.iter()) { + function.nodes[new_id.idx()] = Node::Constant { id: *field_id }; + } + } + // Replace writes using the constant as the data use with a + // series of writes writing the invidiual constant fields. TODO: + // handle the case where the constant is the collect use of the + // write node. + Node::Write { + collect, + data, + ref indices, + } => { + // Create the write chain. + assert!(old_to_new_id_map.contains_key(&data), "PANIC: Can't handle case where write node depends on constant to SROA in the collect use yet."); + let mut collect_def = collect; + for (idx, (new_id, new_data_def)) in + zip(new_ids.iter(), old_to_new_id_map[&data].iter()).enumerate() + { + let mut new_indices = indices.clone().into_vec(); + new_indices.push(Index::Field(idx)); + function.nodes[new_id.idx()] = Node::Write { + collect: collect_def, + data: *new_data_def, + indices: new_indices.into_boxed_slice(), + }; + collect_def = *new_id; + } + + // Replace uses of the old write with the new write. + for user in def_use.get_users(*old_id) { + get_uses_mut(&mut function.nodes[user.idx()]).map(*old_id, collect_def); + } + } + _ => todo!(), + } + + // Delete the old node. + function.nodes[old_id.idx()] = Node::Start; + } } }