float_collections.rs 3.61 KiB
extern crate hercules_ir;
use self::hercules_ir::*;
use crate::*;
/*
* Float collections constants out of device functions, where allocation isn't
* allowed.
*/
pub fn float_collections(
editors: &mut [FunctionEditor],
typing: &ModuleTyping,
callgraph: &CallGraph,
devices: &Vec<Device>,
) {
let topo = callgraph.topo();
for to_float_id in topo {
// Collection constants float until reaching an AsyncRust function.
if devices[to_float_id.idx()] == Device::AsyncRust {
continue;
}
// Find the target constant nodes in the function.
let cons: Vec<(NodeID, Node)> = editors[to_float_id.idx()]
.func()
.nodes
.iter()
.enumerate()
.filter(|(_, node)| {
node.try_constant()
.map(|cons_id| !editors[to_float_id.idx()].get_constant(cons_id).is_scalar())
.unwrap_or(false)
})
.map(|(idx, node)| (NodeID::new(idx), node.clone()))
.collect();
if cons.is_empty() {
continue;
}
// Each constant node becomes a new parameter.
let mut new_param_types = editors[to_float_id.idx()].func().param_types.clone();
let old_num_params = new_param_types.len();
for (id, _) in cons.iter() {
new_param_types.push(typing[to_float_id.idx()][id.idx()]);
}
let success = editors[to_float_id.idx()].edit(|mut edit| {
for (idx, (id, _)) in cons.iter().enumerate() {
let param = edit.add_node(Node::Parameter {
index: idx + old_num_params,
});
edit = edit.replace_all_uses(*id, param)?;
edit = edit.delete_node(*id)?;
}
edit.set_param_types(new_param_types);
Ok(edit)
});
if !success {
continue;
}
// Add constants in callers and pass them into calls.
for caller in callgraph.get_callers(to_float_id) {
let calls: Vec<(NodeID, Node)> = editors[caller.idx()]
.func()
.nodes
.iter()
.enumerate()
.filter(|(_, node)| {
node.try_call()
.map(|(_, callee, _, _)| callee == to_float_id)
.unwrap_or(false)
})
.map(|(idx, node)| (NodeID::new(idx), node.clone()))
.collect();
let success = editors[caller.idx()].edit(|mut edit| {
let cons_ids: Vec<_> = cons
.iter()
.map(|(_, node)| edit.add_node(node.clone()))
.collect();
for (id, node) in calls {
let Node::Call {
control,
function,
dynamic_constants,
args,
} = node
else {
panic!()
};
let mut args = Vec::from(args);
args.extend(cons_ids.iter());
let new_call = edit.add_node(Node::Call {
control,
function,
dynamic_constants,
args: args.into_boxed_slice(),
});
edit = edit.replace_all_uses(id, new_call)?;
edit = edit.delete_node(id)?;
}
Ok(edit)
});
assert!(success);
}
}
}