Skip to content
Snippets Groups Projects
float_collections.rs 4.74 KiB
use std::collections::{BTreeMap, HashMap};

use hercules_ir::*;

use crate::*;

/// Float collection constants out of device functions, where allocation isn't
/// allowed.
///
/// Functions are visited in the order returned by `callgraph.topo()`,
/// restricted to those present in `editors`. For every visited function whose
/// device is not `Device::AsyncRust`:
///
/// 1. Each non-scalar constant node is replaced by a new trailing parameter
///    whose type is that node's type from `typing`.
/// 2. Every call to the function in each of its callers is replaced by a call
///    that passes freshly materialized copies of those constants. In each copy,
///    the callee's dynamic-constant parameters are substituted with the call's
///    dynamic-constant arguments.
///
/// The materialized constants are ordinary constant nodes in the caller. They
/// can therefore be floated again when the caller itself is visited, until an
/// `AsyncRust` function ends up holding them. This relies on `topo()` yielding
/// callees before callers (NOTE(review): confirm against `CallGraph::topo`).
///
/// # Panics
///
/// - If a function that has collection constants to float has a caller that is
///   not in `editors`, since that caller could not be rewritten to supply them.
/// - If rewriting the call sites in a caller fails after the callee's signature
///   has already been changed. Continuing would leave calls with mismatched
///   arguments.
pub fn float_collections(
    editors: &mut BTreeMap<FunctionID, FunctionEditor>,
    typing: &ModuleTyping,
    callgraph: &CallGraph,
    devices: &[Device],
) {
    let topo: Vec<_> = callgraph
        .topo()
        .into_iter()
        .filter(|id| editors.contains_key(&id))
        .collect();
    for to_float_id in topo {
        // Collection constants float until reaching an AsyncRust function.
        if devices[to_float_id.idx()] == Device::AsyncRust {
            continue;
        }

        // Find the target (non-scalar, i.e. collection) constant nodes.
        let cons: Vec<(NodeID, Node)> = editors[&to_float_id]
            .func()
            .nodes
            .iter()
            .enumerate()
            .filter(|(_, node)| {
                node.try_constant()
                    .map(|cons_id| !editors[&to_float_id].get_constant(cons_id).is_scalar())
                    .unwrap_or(false)
            })
            .map(|(idx, node)| (NodeID::new(idx), node.clone()))
            .collect();
        if cons.is_empty() {
            // Nothing to float: the signature is unchanged and no caller is
            // touched, so callers outside the selection are irrelevant here.
            continue;
        }

        // Floating changes this function's signature, so every caller must be
        // in the selection to be rewritten with the new arguments.
        for caller in callgraph.get_callers(to_float_id) {
            assert!(editors.contains_key(&caller), "PANIC: FloatCollections called where a function ({:?}, {:?}) is in the selection but one of its callers ({:?}) is not. This means no collections will be floated from the callee, since the caller can't be modified to hold floated collections.", to_float_id, editors[&to_float_id].func().name, caller);
        }

        // Each constant node becomes a new parameter, appended after the
        // existing ones in the same order as `cons`.
        let mut new_param_types = editors[&to_float_id].func().param_types.clone();
        let old_num_params = new_param_types.len();
        for (id, _) in cons.iter() {
            new_param_types.push(typing[to_float_id.idx()][id.idx()]);
        }
        let success = editors.get_mut(&to_float_id).unwrap().edit(|mut edit| {
            for (idx, (id, _)) in cons.iter().enumerate() {
                let param = edit.add_node(Node::Parameter {
                    index: idx + old_num_params,
                });
                edit = edit.replace_all_uses(*id, param)?;
                edit = edit.delete_node(*id)?;
            }
            edit.set_param_types(new_param_types);
            Ok(edit)
        });
        if !success {
            // The callee edit was not committed, so callers must not be
            // changed either.
            continue;
        }

        // Add constants in callers and pass them into calls.
        for caller in callgraph.get_callers(to_float_id) {
            let calls: Vec<(NodeID, Node)> = editors[&caller]
                .func()
                .nodes
                .iter()
                .enumerate()
                .filter(|(_, node)| {
                    node.try_call()
                        .map(|(_, callee, _, _)| callee == to_float_id)
                        .unwrap_or(false)
                })
                .map(|(idx, node)| (NodeID::new(idx), node.clone()))
                .collect();
            let success = editors.get_mut(&caller).unwrap().edit(|mut edit| {
                for (id, node) in calls {
                    let Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args,
                    } = node
                    else {
                        unreachable!("`calls` only contains Node::Call nodes");
                    };
                    // Map the callee's dynamic-constant parameters to this
                    // call's dynamic-constant arguments, so the materialized
                    // constants are typed in the caller's terms.
                    let dc_args = (0..dynamic_constants.len())
                        .map(|i| edit.add_dynamic_constant(DynamicConstant::Parameter(i)));
                    let substs = dc_args
                        .zip(dynamic_constants.iter().copied())
                        .collect::<HashMap<_, _>>();
                    let cons_ids: Vec<_> = cons
                        .iter()
                        .map(|(_, node)| {
                            let mut node = node.clone();
                            substitute_dynamic_constants_in_node(&substs, &mut node, &mut edit);
                            edit.add_node(node)
                        })
                        .collect();
                    let mut args = args.into_vec();
                    args.extend(cons_ids);
                    let new_call = edit.add_node(Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args: args.into_boxed_slice(),
                    });
                    edit = edit.replace_all_uses(id, new_call)?;
                    edit = edit.delete_node(id)?;
                }
                Ok(edit)
            });
            // The callee's signature has already changed; a failed caller
            // rewrite would leave calls with the wrong number of arguments.
            assert!(
                success,
                "PANIC: FloatCollections failed to rewrite calls to {:?} in caller {:?} after the callee's signature was already extended with floated collections.",
                to_float_id, caller
            );
        }
    }
}