Skip to content
Snippets Groups Projects
float_collections.rs 3.61 KiB
extern crate hercules_ir;

use std::collections::HashMap;

use self::hercules_ir::*;

use crate::*;

/*
 * Float collections constants out of device functions, where allocation isn't
 * allowed.
 */
pub fn float_collections(
    editors: &mut [FunctionEditor],
    typing: &ModuleTyping,
    callgraph: &CallGraph,
    devices: &Vec<Device>,
) {
    let topo = callgraph.topo();
    for to_float_id in topo {
        // Collection constants float until reaching an AsyncRust function.
        if devices[to_float_id.idx()] == Device::AsyncRust {
            continue;
        }

        // Find the target constant nodes in the function.
        let cons: Vec<(NodeID, Node)> = editors[to_float_id.idx()]
            .func()
            .nodes
            .iter()
            .enumerate()
            .filter(|(_, node)| {
                node.try_constant()
                    .map(|cons_id| !editors[to_float_id.idx()].get_constant(cons_id).is_scalar())
                    .unwrap_or(false)
            })
            .map(|(idx, node)| (NodeID::new(idx), node.clone()))
            .collect();
        if cons.is_empty() {
            continue;
        }

        // Each constant node becomes a new parameter.
        let mut new_param_types = editors[to_float_id.idx()].func().param_types.clone();
        let old_num_params = new_param_types.len();
        for (id, _) in cons.iter() {
            new_param_types.push(typing[to_float_id.idx()][id.idx()]);
        }
        let success = editors[to_float_id.idx()].edit(|mut edit| {
            for (idx, (id, _)) in cons.iter().enumerate() {
                let param = edit.add_node(Node::Parameter {
                    index: idx + old_num_params,
                });
                edit = edit.replace_all_uses(*id, param)?;
                edit = edit.delete_node(*id)?;
            }
            edit.set_param_types(new_param_types);
            Ok(edit)
        });
        if !success {
            continue;
        }

        // Add constants in callers and pass them into calls.
        for caller in callgraph.get_callers(to_float_id) {
            let calls: Vec<(NodeID, Node)> = editors[caller.idx()]
                .func()
                .nodes
                .iter()
                .enumerate()
                .filter(|(_, node)| {
                    node.try_call()
                        .map(|(_, callee, _, _)| callee == to_float_id)
                        .unwrap_or(false)
                })
                .map(|(idx, node)| (NodeID::new(idx), node.clone()))
                .collect();
            let success = editors[caller.idx()].edit(|mut edit| {
                let cons_ids: Vec<_> = cons
                    .iter()
                    .map(|(_, node)| edit.add_node(node.clone()))
                    .collect();
                for (id, node) in calls {
                    let Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args,
                    } = node
                    else {
                        panic!()
                    };
                    let mut args = Vec::from(args);
                    args.extend(cons_ids.iter());
                    let new_call = edit.add_node(Node::Call {
                        control,
                        function,
                        dynamic_constants,
                        args: args.into_boxed_slice(),
                    });
                    edit = edit.replace_all_uses(id, new_call)?;
                    edit = edit.delete_node(id)?;
                }
                Ok(edit)
            });
            assert!(success);
        }
    }
}