-
Russel Arbore authored
float_collections.rs 4.74 KiB
use std::collections::{BTreeMap, HashMap};
use hercules_ir::*;
use crate::*;
/// Collects every node of `editor`'s function that satisfies `pred`, paired
/// with its `NodeID`. Nodes are cloned so the result does not borrow the editor.
fn nodes_matching(
    editor: &FunctionEditor,
    pred: impl Fn(&Node) -> bool,
) -> Vec<(NodeID, Node)> {
    editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter(|(_, node)| pred(node))
        .map(|(idx, node)| (NodeID::new(idx), node.clone()))
        .collect()
}

/// Floats collection constants out of device functions, where allocation isn't
/// allowed.
///
/// For every function in `editors` whose device is not `Device::AsyncRust`,
/// each non-scalar constant node is replaced by a new parameter appended after
/// the existing ones. Every caller then re-creates those constants (with the
/// callee's dynamic constant parameters substituted by the call's dynamic
/// constant arguments) and passes them as extra trailing arguments to each call
/// of the edited function.
///
/// * `editors` - editors for the selected functions; both the callees and
///   their callers are modified through them.
/// * `typing` - per-function, per-node types, used to type the new parameters.
/// * `callgraph` - supplies the processing order (`topo()`) and the callers of
///   each function.
/// * `devices` - device of each function, indexed by function index.
///
/// A function whose own edit is rejected is skipped without touching its
/// callers.
///
/// # Panics
///
/// * If a processed function has a caller that is not in `editors`.
/// * If updating a caller fails after the callee was already edited.
pub fn float_collections(
    editors: &mut BTreeMap<FunctionID, FunctionEditor>,
    typing: &ModuleTyping,
    callgraph: &CallGraph,
    devices: &Vec<Device>,
) {
    // NOTE(review): presumably `topo()` yields callees before callers, so that
    // constants pushed into a caller are floated again when that caller is
    // processed - confirm against `CallGraph::topo`.
    let topo: Vec<_> = callgraph
        .topo()
        .into_iter()
        .filter(|id| editors.contains_key(id))
        .collect();

    for to_float_id in topo {
        // Collection constants float until reaching an AsyncRust function.
        if devices[to_float_id.idx()] == Device::AsyncRust {
            continue;
        }

        // Check that all callers are in the selection as well.
        for caller in callgraph.get_callers(to_float_id) {
            assert!(
                editors.contains_key(&caller),
                "PANIC: FloatCollections called where a function ({:?}, {:?}) is in the selection but one of its callers ({:?}) is not. This means no collections will be floated from the callee, since the caller can't be modified to hold floated collections.",
                to_float_id,
                editors[&to_float_id].func().name,
                caller
            );
        }

        // Find the target constant nodes in the function.
        let callee_editor = &editors[&to_float_id];
        let cons = nodes_matching(callee_editor, |node| {
            node.try_constant()
                .map(|cons_id| !callee_editor.get_constant(cons_id).is_scalar())
                .unwrap_or(false)
        });
        if cons.is_empty() {
            continue;
        }

        // Each constant node becomes a new parameter, appended in the order the
        // constants were found so that callers can append arguments in the same
        // order.
        // NOTE(review): `typing` is not updated by this pass, while earlier
        // iterations may have added constant nodes to this function as a
        // caller; this lookup relies on `typing` covering those nodes - confirm.
        let mut new_param_types = callee_editor.func().param_types.clone();
        let old_num_params = new_param_types.len();
        new_param_types.extend(
            cons.iter()
                .map(|(id, _)| typing[to_float_id.idx()][id.idx()]),
        );

        let floated = editors
            .get_mut(&to_float_id)
            .expect("topo was filtered to functions present in `editors`")
            .edit(|mut edit| {
                for (offset, (id, _)) in cons.iter().enumerate() {
                    let param = edit.add_node(Node::Parameter {
                        index: old_num_params + offset,
                    });
                    edit = edit.replace_all_uses(*id, param)?;
                    edit = edit.delete_node(*id)?;
                }
                edit.set_param_types(new_param_types);
                Ok(edit)
            });
        if !floated {
            // The callee is unchanged, so its callers must stay unchanged too.
            continue;
        }

        // Add constants in callers and pass them into calls.
        for caller in callgraph.get_callers(to_float_id) {
            let calls = nodes_matching(&editors[&caller], |node| {
                node.try_call()
                    .map(|(_, callee, _, _)| callee == to_float_id)
                    .unwrap_or(false)
            });

            let success = editors
                .get_mut(&caller)
                .expect("every caller was asserted to be present in `editors`")
                .edit(|mut edit| {
                    for (id, node) in calls {
                        let Node::Call {
                            control,
                            function,
                            dynamic_constants,
                            args,
                        } = node
                        else {
                            unreachable!("node selected via try_call() is not a Node::Call")
                        };

                        // Map the callee's i-th dynamic constant parameter to
                        // the i-th dynamic constant argument of this call.
                        let substs: HashMap<_, _> = dynamic_constants
                            .iter()
                            .enumerate()
                            .map(|(i, dc)| {
                                (
                                    edit.add_dynamic_constant(DynamicConstant::Parameter(i)),
                                    *dc,
                                )
                            })
                            .collect();

                        // Re-create each floated constant in the caller and
                        // append it after the call's existing arguments.
                        let mut args = Vec::from(args);
                        args.extend(cons.iter().map(|(_, node)| {
                            let mut node = node.clone();
                            substitute_dynamic_constants_in_node(&substs, &mut node, &mut edit);
                            edit.add_node(node)
                        }));

                        let new_call = edit.add_node(Node::Call {
                            control,
                            function,
                            dynamic_constants,
                            args: args.into_boxed_slice(),
                        });
                        edit = edit.replace_all_uses(id, new_call)?;
                        edit = edit.delete_node(id)?;
                    }
                    Ok(edit)
                });
            // The callee already expects the extra parameters, so a caller that
            // could not be updated would leave its calls with too few arguments.
            assert!(
                success,
                "float_collections: failed to update a caller after its callee was already edited"
            );
        }
    }
}