editor.rs 38.15 KiB
use std::borrow::Borrow;
use std::cell::{Ref, RefCell};
use std::collections::{BTreeMap, HashSet};
use std::mem::take;
use std::ops::Deref;
use bitvec::prelude::*;
use either::Either;
use hercules_ir::def_use::*;
use hercules_ir::ir::*;
use hercules_ir::DynamicConstantView;
/*
 * Helper object for editing Hercules functions in a trackable manner. Edits
 * must be made atomically, that is, only one `.edit` may be called at a time
 * across all editors, and individual edits must leave the function in a valid
 * state.
 */
#[derive(Debug)]
pub struct FunctionEditor<'a> {
// Wraps a mutable reference to a function. Doesn't provide access to this
// reference directly, so that we can monitor edits.
function: &'a mut Function,
// ID of the wrapped function, returned by `func_id`.
function_id: FunctionID,
// Keep a RefCell to (dynamic) constants and types to allow function changes
// to update these tables. They are only appended to when an edit succeeds.
constants: &'a RefCell<Vec<Constant>>,
dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
types: &'a RefCell<Vec<Type>>,
// Keep a RefCell to the string table that tracks labels, so that new labels
// can be added as needed.
labels: &'a RefCell<Vec<String>>,
// Most optimizations need def-use info, so provide a mutable version that is
// automatically kept up to date based on recorded edits. Entry `i` holds the
// set of users of node `i`.
mut_def_use: Vec<HashSet<NodeID>>,
// The pass manager may indicate that only a certain subset of nodes should
// be modified in a function - what this actually means is that some nodes
// are off limits for deletion (equivalently modification) or being replaced
// as a use. A set bit marks the node with that index as mutable.
mutable_nodes: BitVec<u8, Lsb0>,
// Tracks whether this editor has been used to make any edits to the IR of
// this function.
modified: bool,
}
/*
 * Helper object used to make a single edit. All changes are staged in this
 * object and are only written back to the function by `FunctionEditor::edit`
 * if the edit closure returns `Ok`.
 */
#[derive(Debug)]
pub struct FunctionEdit<'a: 'b, 'b> {
// Reference the active function editor.
editor: &'b mut FunctionEditor<'a>,
// Keep track of deleted node IDs.
deleted_nodeids: HashSet<NodeID>,
// Keep track of added node IDs.
added_nodeids: HashSet<NodeID>,
// Keep track of added nodes and of existing nodes whose uses were updated.
added_and_updated_nodes: BTreeMap<NodeID, Node>,
// Keep track of added and updated schedules.
added_and_updated_schedules: BTreeMap<NodeID, Vec<Schedule>>,
// Keep track of added and updated labels.
added_and_updated_labels: BTreeMap<NodeID, HashSet<LabelID>>,
// Keep track of added (dynamic) constants, types, and labels. The ID of an
// added item is the length of the editor's table plus its index here.
added_constants: Vec<Constant>,
added_dynamic_constants: Vec<DynamicConstant>,
added_types: Vec<Type>,
added_labels: Vec<String>,
// Def-use map entries that this edit has changed, computed incrementally as
// nodes are added, deleted, and have their uses replaced.
updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>,
// Replacement parameter / return types for the function, if this edit set
// them; `None` means the function's current types are kept.
updated_param_types: Option<Vec<TypeID>>,
updated_return_types: Option<Vec<TypeID>>,
// Keep track of which deleted and added node IDs directly correspond, as
// (source, destination) pairs recorded by `sub_edit`.
sub_edits: Vec<(NodeID, NodeID)>,
}
impl<'a: 'b, 'b> FunctionEditor<'a> {
pub fn new(
function: &'a mut Function,
function_id: FunctionID,
constants: &'a RefCell<Vec<Constant>>,
dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
types: &'a RefCell<Vec<Type>>,
labels: &'a RefCell<Vec<String>>,
def_use: &ImmutableDefUseMap,
) -> Self {
let mut_def_use = (0..function.nodes.len())
.map(|idx| {
def_use
.get_users(NodeID::new(idx))
.into_iter()
.map(|x| *x)
.collect()
})
.collect();
let mutable_nodes = bitvec![u8, Lsb0; 1; function.nodes.len()];
FunctionEditor {
function,
function_id,
constants,
dynamic_constants,
types,
labels,
mut_def_use,
mutable_nodes,
modified: false,
}
}
// Constructs an editor with a specified mask determining which nodes are mutable
pub fn new_mask(
function: &'a mut Function,
function_id: FunctionID,
constants: &'a RefCell<Vec<Constant>>,
dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
types: &'a RefCell<Vec<Type>>,
labels: &'a RefCell<Vec<String>>,
def_use: &ImmutableDefUseMap,
mask: BitVec<u8, Lsb0>,
) -> Self {
let mut_def_use = (0..function.nodes.len())
.map(|idx| {
def_use
.get_users(NodeID::new(idx))
.into_iter()
.map(|x| *x)
.collect()
})
.collect();
FunctionEditor {
function,
function_id,
constants,
dynamic_constants,
types,
labels,
mut_def_use,
mutable_nodes: mask,
modified: false,
}
}
// Constructs an editor but makes every node immutable.
pub fn new_immutable(
function: &'a mut Function,
function_id: FunctionID,
constants: &'a RefCell<Vec<Constant>>,
dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
types: &'a RefCell<Vec<Type>>,
labels: &'a RefCell<Vec<String>>,
def_use: &ImmutableDefUseMap,
) -> Self {
let mut_def_use = (0..function.nodes.len())
.map(|idx| {
def_use
.get_users(NodeID::new(idx))
.into_iter()
.map(|x| *x)
.collect()
})
.collect();
let mutable_nodes = bitvec![u8, Lsb0; 0; function.nodes.len()];
FunctionEditor {
function,
function_id,
constants,
dynamic_constants,
types,
labels,
mut_def_use,
mutable_nodes,
modified: false,
}
}
// Returns whether any successful edit made through this editor has changed
// the function. The flag is only ever set by `edit`, never cleared.
pub fn modified(&self) -> bool {
self.modified
}
// Performs one atomic edit. The closure receives a fresh `FunctionEdit` in
// which it stages changes; returning `Ok(edit)` commits all staged changes to
// the function, returning `Err(edit)` discards them (this is what the edit
// helpers return when an immutable node would be modified).
//
// Returns true if the edit was committed and false if it was aborted.
//
// Panics if a deleted node still has users after the edit, or if the staged
// new nodes / def-use entries are not densely numbered.
pub fn edit<F>(&'b mut self, edit: F) -> bool
where
    F: FnOnce(FunctionEdit<'a, 'b>) -> Result<FunctionEdit<'a, 'b>, FunctionEdit<'a, 'b>>,
{
    // Create the edit helper struct and perform the edit using it.
    let edit_obj = FunctionEdit {
        editor: self,
        deleted_nodeids: HashSet::new(),
        added_nodeids: HashSet::new(),
        added_and_updated_nodes: BTreeMap::new(),
        added_and_updated_schedules: BTreeMap::new(),
        added_and_updated_labels: BTreeMap::new(),
        added_constants: Vec::new(),
        added_dynamic_constants: Vec::new(),
        added_types: Vec::new(),
        added_labels: Vec::new(),
        updated_def_use: BTreeMap::new(),
        updated_param_types: None,
        updated_return_types: None,
        sub_edits: vec![],
    };

    // If the populated edit is returned, then the edit can be performed
    // without modifying immutable nodes. Otherwise nothing is committed.
    let Ok(populated_edit) = edit(edit_obj) else {
        return false;
    };
    let FunctionEdit {
        editor,
        deleted_nodeids,
        added_nodeids,
        added_and_updated_nodes,
        added_and_updated_schedules,
        added_and_updated_labels,
        added_constants,
        added_dynamic_constants,
        added_types,
        added_labels,
        updated_def_use,
        updated_param_types,
        updated_return_types,
        sub_edits,
    } = populated_edit;

    // Step 0: determine whether the edit changed the IR by checking if any
    // nodes were deleted, added, or updated in any way, or if the function's
    // parameter / return types were replaced.
    editor.modified |= !deleted_nodeids.is_empty()
        || !added_nodeids.is_empty()
        || !added_and_updated_nodes.is_empty()
        || !added_and_updated_schedules.is_empty()
        || !added_and_updated_labels.is_empty()
        || updated_param_types.is_some()
        || updated_return_types.is_some();

    // Step 1: update the mutable def use map.
    for (u, new_users) in updated_def_use {
        // Go through new def-use entries in order. These are either updates
        // to existing nodes, in which case we just modify them in place, or
        // they are user entries for new nodes, in which case we push them.
        if u.idx() < editor.mut_def_use.len() {
            editor.mut_def_use[u.idx()] = new_users;
        } else {
            // The new nodes must be traversed in order in a packed fashion -
            // if a new node was created without an accompanying new users
            // entry, something is wrong!
            assert_eq!(editor.mut_def_use.len(), u.idx());
            editor.mut_def_use.push(new_users);
        }
    }

    // Step 2: add and update nodes.
    for (id, node) in added_and_updated_nodes {
        if id.idx() < editor.function.nodes.len() {
            editor.function.nodes[id.idx()] = node;
        } else {
            // New nodes should've been assigned increasing IDs starting at
            // the previous number of nodes, so check that.
            assert_eq!(editor.function.nodes.len(), id.idx());
            editor.function.nodes.push(node);
        }
    }

    // Step 3.0: add and update schedules.
    editor
        .function
        .schedules
        .resize(editor.function.nodes.len(), vec![]);
    for (id, schedule) in added_and_updated_schedules {
        editor.function.schedules[id.idx()] = schedule;
    }

    // Step 3.1: add and update labels.
    editor
        .function
        .labels
        .resize(editor.function.nodes.len(), HashSet::new());
    for (id, label) in added_and_updated_labels {
        editor.function.labels[id.idx()] = label;
    }

    // Step 4: delete nodes. This is done using "gravestones", where a node
    // other than node ID 0 being a start node is considered a gravestone.
    for id in deleted_nodeids.iter() {
        // Check that there are no users of deleted nodes.
        assert!(
            editor.mut_def_use[id.idx()].is_empty(),
            "PANIC: Attempted to delete node {:?}, but there are still users of this node ({:?}).",
            id, editor.mut_def_use[id.idx()]
        );
        editor.function.nodes[id.idx()] = Node::Start;
    }

    // Step 5.0: propagate schedules along sub-edit edges.
    for (src, dst) in sub_edits.iter() {
        let mut dst_schedules = take(&mut editor.function.schedules[dst.idx()]);
        for src_schedule in editor.function.schedules[src.idx()].iter() {
            if !dst_schedules.contains(src_schedule) {
                dst_schedules.push(src_schedule.clone());
            }
        }
        editor.function.schedules[dst.idx()] = dst_schedules;
    }

    // Step 5.1: update and propagate labels.
    editor.labels.borrow_mut().extend(added_labels);
    // We propagate labels in two steps, first along sub-edits and then all
    // the labels on any deleted node not used in any sub-edit to all added
    // nodes not in any sub-edit.
    let mut sources = deleted_nodeids.clone();
    let mut dests = added_nodeids.clone();
    for (src, dst) in sub_edits {
        let mut dst_labels = take(&mut editor.function.labels[dst.idx()]);
        dst_labels.extend(editor.function.labels[src.idx()].iter());
        editor.function.labels[dst.idx()] = dst_labels;
        sources.remove(&src);
        dests.remove(&dst);
    }
    let mut src_labels = HashSet::new();
    for src in sources {
        src_labels.extend(editor.function.labels[src.idx()].clone());
    }
    for dst in dests {
        editor.function.labels[dst.idx()].extend(src_labels.clone());
    }

    // Step 6: update the length of mutable_nodes. All added nodes are
    // mutable.
    editor
        .mutable_nodes
        .resize(editor.function.nodes.len(), true);

    // Step 7: update types and constants.
    let mut editor_constants = editor.constants.borrow_mut();
    let mut editor_dynamic_constants = editor.dynamic_constants.borrow_mut();
    let mut editor_types = editor.types.borrow_mut();
    editor_constants.extend(added_constants);
    editor_dynamic_constants.extend(added_dynamic_constants);
    editor_types.extend(added_types);

    // Step 8: update parameter types if necessary.
    if let Some(param_types) = updated_param_types {
        editor.function.param_types = param_types;
    }

    // Step 9: update return types if necessary.
    if let Some(return_types) = updated_return_types {
        editor.function.return_types = return_types;
    }

    true
}
// Returns a shared (read-only) reference to the function being edited.
pub fn func(&self) -> &Function {
&self.function
}
// Returns the ID of the function being edited.
pub fn func_id(&self) -> FunctionID {
self.function_id
}
// Returns the node with the given ID as currently stored in the function.
// Panics if the ID is out of range.
pub fn node(&self, node: impl Borrow<NodeID>) -> &Node {
    let id: &NodeID = node.borrow();
    &self.function.nodes[id.idx()]
}
// Borrows the shared type table. The returned guard holds a RefCell borrow
// for as long as it is alive.
pub fn get_types(&self) -> Ref<'_, Vec<Type>> {
self.types.borrow()
}
// Borrows the shared constant table. The returned guard holds a RefCell
// borrow for as long as it is alive.
pub fn get_constants(&self) -> Ref<'_, Vec<Constant>> {
self.constants.borrow()
}
// Borrows the shared dynamic constant table. The returned guard holds a
// RefCell borrow for as long as it is alive.
pub fn get_dynamic_constants(&self) -> Ref<'_, Vec<DynamicConstant>> {
self.dynamic_constants.borrow()
}
// Iterates over the users of `id` according to the editor's def-use map.
// Panics if the ID is out of range.
pub fn get_users(&self, id: NodeID) -> impl ExactSizeIterator<Item = NodeID> + '_ {
    let users = &self.mut_def_use[id.idx()];
    users.iter().copied()
}
// Iterates over the nodes used by `id`. The uses are copied into a vector
// first so the returned iterator has a known length.
pub fn get_uses(&self, id: NodeID) -> impl ExactSizeIterator<Item = NodeID> + '_ {
    let node = &self.function.nodes[id.idx()];
    let uses: Vec<NodeID> = get_uses(node).as_ref().into_iter().copied().collect();
    uses.into_iter()
}
// Borrows a single type from the shared type table. Panics if the ID is out
// of range.
pub fn get_type(&self, id: TypeID) -> Ref<'_, Type> {
    let table = self.types.borrow();
    Ref::map(table, |all| &all[id.idx()])
}
// Borrows a single constant from the shared constant table. Panics if the ID
// is out of range.
pub fn get_constant(&self, id: ConstantID) -> Ref<'_, Constant> {
    let table = self.constants.borrow();
    Ref::map(table, |all| &all[id.idx()])
}
// Borrows a single dynamic constant from the shared dynamic constant table.
// Panics if the ID is out of range.
pub fn get_dynamic_constant(&self, id: DynamicConstantID) -> Ref<'_, DynamicConstant> {
    let table = self.dynamic_constants.borrow();
    Ref::map(table, |all| &all[id.idx()])
}
// Returns whether the node is marked mutable in this editor's mask. Indexes
// the mask directly, so an ID beyond the mask's length panics.
pub fn is_mutable(&self, id: NodeID) -> bool {
self.mutable_nodes[id.idx()]
}
// Iterates over every node ID in the function, in increasing order. The
// count is captured up front, so the iterator does not borrow the editor.
pub fn node_ids(&self) -> impl ExactSizeIterator<Item = NodeID> {
    let count = self.function.nodes.len();
    (0..count).map(|idx| NodeID::new(idx))
}
// Iterates over every ID in the shared dynamic constant table, in increasing
// order. The count is captured up front, so no RefCell borrow is held.
pub fn dynamic_constant_ids(&self) -> impl ExactSizeIterator<Item = DynamicConstantID> {
    let count = self.dynamic_constants.borrow().len();
    (0..count).map(|idx| DynamicConstantID::new(idx))
}
}
impl<'a, 'b> FunctionEdit<'a, 'b> {
// Makes sure `updated_def_use` has an entry for `id`. A missing entry is
// seeded with a copy of the editor's current user set for that node, or with
// an empty set if the editor has no entry for it.
fn ensure_updated_def_use_entry(&mut self, id: NodeID) {
    let committed = &self.editor.mut_def_use;
    self.updated_def_use
        .entry(id)
        .or_insert_with(|| committed.get(id.idx()).cloned().unwrap_or_default());
}
// Returns whether this edit may modify the node. IDs beyond the editor's mask
// are treated as mutable; otherwise the mask bit decides.
fn is_mutable(&self, id: NodeID) -> bool {
    let idx = id.idx();
    let mask = &self.editor.mutable_nodes;
    idx >= mask.len() || mask[idx]
}
// Number of dynamic constants visible to this edit: those already in the
// shared table plus those added by this edit.
pub fn num_dynamic_constants(&self) -> usize {
    let committed = self.editor.dynamic_constants.borrow().len();
    let pending = self.added_dynamic_constants.len();
    committed + pending
}
// Number of node IDs visible to this edit: the function's current nodes plus
// the nodes added by this edit. This is also the ID the next added node gets.
pub fn num_node_ids(&self) -> usize {
    let committed = self.editor.function.nodes.len();
    let pending = self.added_nodeids.len();
    committed + pending
}
// Adds a new node that is a clone of the given node as it is currently stored
// in the function, and returns the new node's ID.
pub fn copy_node(&mut self, node: NodeID) -> NodeID {
    let duplicate = self.editor.func().nodes[node.idx()].clone();
    self.add_node(duplicate)
}
// Stages a new node and returns its ID. The ID is the next unused node ID,
// and the def-use entries of every node it uses are updated to include it.
pub fn add_node(&mut self, node: Node) -> NodeID {
    let id = NodeID::new(self.num_node_ids());
    // Added nodes need to have an entry in the def-use map, even if empty.
    self.updated_def_use.entry(id).or_default();
    // Register the new node as a user of everything it uses.
    for used in get_uses(&node).as_ref().into_iter().copied() {
        self.ensure_updated_def_use_entry(used);
        self.updated_def_use.get_mut(&used).unwrap().insert(id);
    }
    // Record the node itself.
    self.added_nodeids.insert(id);
    self.added_and_updated_nodes.insert(id, node);
    id
}
// Stages the deletion of a node. Only mutable nodes can be deleted: returns
// `Err(self)` if the node is immutable, as it means the whole edit should be
// aborted, and `Ok(self)` otherwise.
//
// Panics if the node was added in this same edit.
pub fn delete_node(mut self, id: NodeID) -> Result<Self, Self> {
    if !self.is_mutable(id) {
        return Err(self);
    }
    assert!(
        !self.added_nodeids.contains(&id),
        "PANIC: Please don't delete a node that was added in the same edit."
    );
    // Deleted nodes use other nodes, and we need to update their def-use
    // entries. Look at the pending version of the node if this edit has
    // already rewritten its uses (e.g. via `replace_all_uses`); otherwise the
    // node's new uses would keep a stale user entry for the deleted node.
    let current = match self.added_and_updated_nodes.get(&id) {
        Some(updated) => updated,
        None => &self.editor.function.nodes[id.idx()],
    };
    let uses: Box<[NodeID]> = get_uses(current).as_ref().into();
    for u in uses.iter().copied() {
        self.ensure_updated_def_use_entry(u);
        self.updated_def_use.get_mut(&u).unwrap().remove(&id);
    }
    self.deleted_nodeids.insert(id);
    Ok(self)
}
// Replaces every use of `old` with `new`. Equivalent to
// `replace_all_uses_where` with a predicate that accepts every user; returns
// `Err(self)` under the same conditions.
pub fn replace_all_uses(self, old: NodeID, new: NodeID) -> Result<Self, Self> {
self.replace_all_uses_where(old, new, |_| true)
}
// Replaces uses of `old` with `new` in every user of `old` for which `pred`
// returns true. Users rejected by `pred` keep using `old`.
//
// We can only replace uses of mutable nodes: returns `Err(self)` if `old` is
// immutable, as it means the whole edit should be aborted.
pub fn replace_all_uses_where<P>(
    mut self,
    old: NodeID,
    new: NodeID,
    pred: P,
) -> Result<Self, Self>
where
    P: Fn(&NodeID) -> bool,
{
    if !self.is_mutable(old) {
        return Err(self);
    }

    // Rewrite each selected user so that it refers to `new` instead of `old`.
    self.ensure_updated_def_use_entry(old);
    for user_id in self.updated_def_use[&old].iter() {
        if !pred(user_id) {
            continue;
        }
        let mut rewritten = self.get_node(*user_id).clone();
        for slot in get_uses_mut(&mut rewritten).as_mut() {
            if **slot == old {
                **slot = new;
            }
        }
        // Stage the rewritten user.
        self.added_and_updated_nodes.insert(*user_id, rewritten);
    }

    // The selected users of the old node become users of the new node, so
    // move their def-use entries over; the rest go back to the old node.
    let previous_users = take(self.updated_def_use.get_mut(&old).unwrap());
    let (moved, kept): (Vec<_>, Vec<_>) =
        previous_users.into_iter().partition(|user| pred(user));
    self.ensure_updated_def_use_entry(new);
    self.updated_def_use.get_mut(&new).unwrap().extend(moved);
    self.updated_def_use.get_mut(&old).unwrap().extend(kept);
    Ok(self)
}
// Records that `dst` directly corresponds to `src`. When the edit is
// committed, the schedules and labels of `src` are propagated onto `dst`.
// Panics if `src` was added in this edit or `dst` was deleted in this edit.
pub fn sub_edit(&mut self, src: NodeID, dst: NodeID) {
assert!(!self.added_nodeids.contains(&src));
assert!(!self.deleted_nodeids.contains(&dst));
self.sub_edits.push((src, dst));
}
// Returns the name of the function being edited.
pub fn get_name(&self) -> &str {
&self.editor.function.name
}
// Returns the function's `num_dynamic_constants` field, i.e. how many dynamic
// constant parameters the function being edited declares.
pub fn get_num_dynamic_constant_params(&self) -> u32 {
self.editor.function.num_dynamic_constants
}
// Returns the node as seen by this edit: the staged version if the node was
// added or updated in this edit (so it reflects any `replace_all_uses` calls
// made so far), otherwise the node stored in the function.
//
// Panics if the node was deleted in this edit.
pub fn get_node(&self, id: NodeID) -> &Node {
    assert!(!self.deleted_nodeids.contains(&id));
    match self.added_and_updated_nodes.get(&id) {
        Some(staged) => staged,
        None => &self.editor.function.nodes[id.idx()],
    }
}
// Returns the schedules of a node as seen by this edit: the staged schedules
// if this edit added or updated them, otherwise the function's schedules.
// Unlike `get_node`, this may be called for a to-be deleted node.
pub fn get_schedule(&self, id: NodeID) -> &Vec<Schedule> {
    match self.added_and_updated_schedules.get(&id) {
        Some(staged) => staged,
        None => &self.editor.function.schedules[id.idx()],
    }
}
// Stages adding a schedule to a node. A schedule the node already has is not
// added again. Returns `Err(self)` if the node is immutable.
pub fn add_schedule(mut self, id: NodeID, schedule: Schedule) -> Result<Self, Self> {
    if !self.is_mutable(id) {
        return Err(self);
    }
    match self.added_and_updated_schedules.get_mut(&id) {
        // Schedules for this node are already staged: extend them in place.
        Some(staged) => {
            if !staged.contains(&schedule) {
                staged.push(schedule);
            }
        }
        // Otherwise start from the function's schedules for this node (none
        // for a node added in this edit) and stage a copy only on a change.
        None => {
            let committed = self.editor.function.schedules.get(id.idx());
            let already_present = committed.map_or(false, |s| s.contains(&schedule));
            if !already_present {
                let mut schedules = committed.cloned().unwrap_or_default();
                schedules.push(schedule);
                self.added_and_updated_schedules.insert(id, schedules);
            }
        }
    }
    Ok(self)
}
// Stages removing every schedule from a node by staging an empty schedule
// list for it. Returns `Err(self)` if the node is immutable.
pub fn clear_schedule(mut self, id: NodeID) -> Result<Self, Self> {
    if !self.is_mutable(id) {
        return Err(self);
    }
    self.added_and_updated_schedules.insert(id, Vec::new());
    Ok(self)
}
// Returns the labels of a node as seen by this edit: the staged labels if
// this edit added or updated them, otherwise the function's labels. Unlike
// `get_node`, this may be called for a to-be deleted node.
pub fn get_label(&self, id: NodeID) -> &HashSet<LabelID> {
    match self.added_and_updated_labels.get(&id) {
        Some(staged) => staged,
        None => &self.editor.function.labels[id.idx()],
    }
}
// Stages adding a label to a node. Returns `Err(self)` if the node is
// immutable.
pub fn add_label(mut self, id: NodeID, label: LabelID) -> Result<Self, Self> {
    if !self.is_mutable(id) {
        return Err(self);
    }
    // If no labels are staged for this node yet, start from a copy of the
    // function's labels for it (empty for a node added in this edit).
    let committed = &self.editor.function.labels;
    self.added_and_updated_labels
        .entry(id)
        .or_insert_with(|| committed.get(id.idx()).cloned().unwrap_or_default())
        .insert(label);
    Ok(self)
}
// Creates or returns the LabelID for a given label name. Both the shared
// label table and the labels added by this edit are searched; a new label is
// staged only if the name is found in neither.
pub fn new_label(&mut self, name: String) -> LabelID {
    let (num_committed, found) = {
        let committed = self.editor.labels.borrow();
        let found = committed
            .iter()
            .chain(self.added_labels.iter())
            .position(|l| *l == name);
        (committed.len(), found)
    };
    let idx = match found {
        Some(idx) => idx,
        None => {
            let idx = num_committed + self.added_labels.len();
            self.added_labels.push(name);
            idx
        }
    };
    LabelID::new(idx)
}
// Creates an entirely fresh label and returns its LabelID. The label is named
// after its own index and is staged without checking for an existing label of
// the same name.
pub fn fresh_label(&mut self) -> LabelID {
    let num_committed = self.editor.labels.borrow().len();
    let idx = num_committed + self.added_labels.len();
    let name = format!("#fresh_{}", idx);
    self.added_labels.push(name);
    LabelID::new(idx)
}
// Iterates over the users of a node as seen by this edit: the updated user
// set if this edit has touched it, otherwise the editor's user set.
//
// Panics if the node was deleted in this edit.
pub fn get_users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ {
    assert!(!self.deleted_nodeids.contains(&id));
    match self.updated_def_use.get(&id) {
        Some(updated) => Either::Left(updated.iter().copied()),
        None => Either::Right(self.editor.mut_def_use[id.idx()].iter().copied()),
    }
}
// Returns the ID of the given type, staging it as a new type only if it is in
// neither the shared type table nor the types already added by this edit.
pub fn add_type(&mut self, ty: Type) -> TypeID {
    let (num_committed, found) = {
        let committed = self.editor.types.borrow();
        let found = committed
            .iter()
            .chain(self.added_types.iter())
            .position(|t| *t == ty);
        (committed.len(), found)
    };
    let idx = match found {
        Some(idx) => idx,
        None => {
            let idx = num_committed + self.added_types.len();
            self.added_types.push(ty);
            idx
        }
    };
    TypeID::new(idx)
}
// Looks up a type by ID. IDs below the shared table's length refer to the
// shared table; larger IDs refer to types added by this edit. Panics if the
// ID refers to neither.
pub fn get_type(&self, id: TypeID) -> impl Deref<Target = Type> + '_ {
    let committed = self.editor.types.borrow();
    match id.idx().checked_sub(committed.len()) {
        // The ID is within the shared table.
        None => Either::Left(Ref::map(committed, |types| &types[id.idx()])),
        // The ID is past the shared table: index into this edit's additions.
        Some(offset) => Either::Right(self.added_types.get(offset).unwrap()),
    }
}
// Returns the ID of the given constant, staging it as a new constant only if
// it is in neither the shared constant table nor the constants already added
// by this edit.
pub fn add_constant(&mut self, constant: Constant) -> ConstantID {
    let (num_committed, found) = {
        let committed = self.editor.constants.borrow();
        let found = committed
            .iter()
            .chain(self.added_constants.iter())
            .position(|c| *c == constant);
        (committed.len(), found)
    };
    let idx = match found {
        Some(idx) => idx,
        None => {
            let idx = num_committed + self.added_constants.len();
            self.added_constants.push(constant);
            idx
        }
    };
    ConstantID::new(idx)
}
// Adds (or finds) the zero constant of the given type and returns its ID.
// Products are zeroed element-wise and summations use variant 0 with a zero
// payload. Panics for Float8, BFloat16, control, and multi-return types.
pub fn add_zero_constant(&mut self, id: TypeID) -> ConstantID {
    // Clone the type so that no borrow of the type table is held across the
    // recursive calls below.
    let ty = self.get_type(id).clone();
    let zero = match ty {
        Type::Control => panic!("PANIC: Can't create zero constant for the control type."),
        Type::MultiReturn(_) => {
            panic!("PANIC: Can't create zero constant for multi-return types.")
        }
        Type::Float8 | Type::BFloat16 => panic!(),
        Type::Boolean => Constant::Boolean(false),
        Type::Integer8 => Constant::Integer8(0),
        Type::Integer16 => Constant::Integer16(0),
        Type::Integer32 => Constant::Integer32(0),
        Type::Integer64 => Constant::Integer64(0),
        Type::UnsignedInteger8 => Constant::UnsignedInteger8(0),
        Type::UnsignedInteger16 => Constant::UnsignedInteger16(0),
        Type::UnsignedInteger32 => Constant::UnsignedInteger32(0),
        Type::UnsignedInteger64 => Constant::UnsignedInteger64(0),
        Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(0.0)),
        Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(0.0)),
        Type::Product(fields) => {
            let mut elems = Vec::with_capacity(fields.len());
            for field in fields.iter() {
                elems.push(self.add_zero_constant(*field));
            }
            Constant::Product(id, elems.into_boxed_slice())
        }
        Type::Summation(variants) => {
            let payload = self.add_zero_constant(variants[0]);
            Constant::Summation(id, 0, payload)
        }
        Type::Array(_, _) => Constant::Array(id),
    };
    self.add_constant(zero)
}
// Adds (or finds) the "one" constant of the given type and returns its ID
// (`true` for booleans). Panics for Float8, BFloat16, control, collection,
// and multi-return types.
pub fn add_one_constant(&mut self, id: TypeID) -> ConstantID {
    let one = match *self.get_type(id) {
        Type::Control => panic!("PANIC: Can't create one constant for the control type."),
        Type::MultiReturn(_) => {
            panic!("PANIC: Can't create one constant for multi-return types.")
        }
        Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
            panic!("PANIC: Can't create one constant of a collection type.")
        }
        Type::Float8 | Type::BFloat16 => panic!(),
        Type::Boolean => Constant::Boolean(true),
        Type::Integer8 => Constant::Integer8(1),
        Type::Integer16 => Constant::Integer16(1),
        Type::Integer32 => Constant::Integer32(1),
        Type::Integer64 => Constant::Integer64(1),
        Type::UnsignedInteger8 => Constant::UnsignedInteger8(1),
        Type::UnsignedInteger16 => Constant::UnsignedInteger16(1),
        Type::UnsignedInteger32 => Constant::UnsignedInteger32(1),
        Type::UnsignedInteger64 => Constant::UnsignedInteger64(1),
        Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(1.0)),
        Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(1.0)),
    };
    self.add_constant(one)
}
// Adds (or finds) the largest constant of the given type and returns its ID:
// `true` for booleans, `MAX` for integers, positive infinity for floats.
// Panics for Float8, BFloat16, control, collection, and multi-return types.
pub fn add_largest_constant(&mut self, id: TypeID) -> ConstantID {
    let largest = match *self.get_type(id) {
        Type::Control => panic!("PANIC: Can't create largest constant for the control type."),
        Type::MultiReturn(_) => {
            panic!("PANIC: Can't create largest constant for multi-return types.")
        }
        Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
            panic!("PANIC: Can't create largest constant of a collection type.")
        }
        Type::Float8 | Type::BFloat16 => panic!(),
        Type::Boolean => Constant::Boolean(true),
        Type::Integer8 => Constant::Integer8(i8::MAX),
        Type::Integer16 => Constant::Integer16(i16::MAX),
        Type::Integer32 => Constant::Integer32(i32::MAX),
        Type::Integer64 => Constant::Integer64(i64::MAX),
        Type::UnsignedInteger8 => Constant::UnsignedInteger8(u8::MAX),
        Type::UnsignedInteger16 => Constant::UnsignedInteger16(u16::MAX),
        Type::UnsignedInteger32 => Constant::UnsignedInteger32(u32::MAX),
        Type::UnsignedInteger64 => Constant::UnsignedInteger64(u64::MAX),
        Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(f32::INFINITY)),
        Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(f64::INFINITY)),
    };
    self.add_constant(largest)
}
// Adds (or finds) the smallest constant of the given type and returns its ID:
// `false` for booleans (since `false < true`), `MIN` for integers, negative
// infinity for floats. Panics for Float8, BFloat16, control, collection, and
// multi-return types.
pub fn add_smallest_constant(&mut self, id: TypeID) -> ConstantID {
    let smallest = match *self.get_type(id) {
        // The smallest boolean is `false`; `true` is the largest.
        Type::Boolean => Constant::Boolean(false),
        Type::Integer8 => Constant::Integer8(i8::MIN),
        Type::Integer16 => Constant::Integer16(i16::MIN),
        Type::Integer32 => Constant::Integer32(i32::MIN),
        Type::Integer64 => Constant::Integer64(i64::MIN),
        Type::UnsignedInteger8 => Constant::UnsignedInteger8(u8::MIN),
        Type::UnsignedInteger16 => Constant::UnsignedInteger16(u16::MIN),
        Type::UnsignedInteger32 => Constant::UnsignedInteger32(u32::MIN),
        Type::UnsignedInteger64 => Constant::UnsignedInteger64(u64::MIN),
        Type::Float8 | Type::BFloat16 => panic!(),
        Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(f32::NEG_INFINITY)),
        Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(f64::NEG_INFINITY)),
        Type::Control => panic!("PANIC: Can't create smallest constant for the control type."),
        Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
            panic!("PANIC: Can't create smallest constant of a collection type.")
        }
        Type::MultiReturn(_) => {
            panic!("PANIC: Can't create smallest constant for multi-return types.")
        }
    };
    self.add_constant(smallest)
}
// Looks up a constant by ID. IDs below the shared table's length refer to the
// shared table; larger IDs refer to constants added by this edit. Panics if
// the ID refers to neither.
pub fn get_constant(&self, id: ConstantID) -> impl Deref<Target = Constant> + '_ {
    let committed = self.editor.constants.borrow();
    match id.idx().checked_sub(committed.len()) {
        // The ID is within the shared table.
        None => Either::Left(Ref::map(committed, |constants| &constants[id.idx()])),
        // The ID is past the shared table: index into this edit's additions.
        Some(offset) => Either::Right(self.added_constants.get(offset).unwrap()),
    }
}
// Adds a dynamic constant and returns its ID. The work is delegated to
// `dc_normalize`, which is not defined in this file (presumably provided by
// the `DynamicConstantView` trait - verify), so the returned ID may refer to
// a different but equivalent dynamic constant.
pub fn add_dynamic_constant(&mut self, dynamic_constant: DynamicConstant) -> DynamicConstantID {
self.dc_normalize(dynamic_constant)
}
// Looks up a dynamic constant by ID. IDs below the shared table's length
// refer to the shared table; larger IDs refer to dynamic constants added by
// this edit. Panics if the ID refers to neither.
pub fn get_dynamic_constant(
    &self,
    id: DynamicConstantID,
) -> impl Deref<Target = DynamicConstant> + '_ {
    let committed = self.editor.dynamic_constants.borrow();
    match id.idx().checked_sub(committed.len()) {
        // The ID is within the shared table.
        None => Either::Left(Ref::map(committed, |dynamic_constants| {
            &dynamic_constants[id.idx()]
        })),
        // The ID is past the shared table: index into this edit's additions.
        Some(offset) => Either::Right(self.added_dynamic_constants.get(offset).unwrap()),
    }
}
// Returns the function's parameter types as seen by this edit: the types set
// by `set_param_types` if any, otherwise the function's current ones.
pub fn get_param_types(&self) -> &Vec<TypeID> {
    match &self.updated_param_types {
        Some(staged) => staged,
        None => &self.editor.function.param_types,
    }
}
// Returns the function's return types as seen by this edit: the types set by
// `set_return_types` if any, otherwise the function's current ones.
pub fn get_return_types(&self) -> &Vec<TypeID> {
    match &self.updated_return_types {
        Some(staged) => staged,
        None => &self.editor.function.return_types,
    }
}
// Stages a replacement for the function's parameter types. The function is
// only updated when the edit is committed; a later call overrides this one.
pub fn set_param_types(&mut self, tys: Vec<TypeID>) {
self.updated_param_types = Some(tys);
}
// Stages a replacement for the function's return types. The function is only
// updated when the edit is committed; a later call overrides this one.
pub fn set_return_types(&mut self, tys: Vec<TypeID>) {
self.updated_return_types = Some(tys);
}
}
// Lets a `FunctionEdit` be used as a view over dynamic constants: lookups see
// both the shared table and this edit's additions, and additions are staged
// in this edit.
impl<'a, 'b> DynamicConstantView for FunctionEdit<'a, 'b> {
    // Looks up a dynamic constant by ID (shared table or added by this edit).
    fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_ {
        self.get_dynamic_constant(id)
    }

    // Returns the ID of the given dynamic constant, staging it as a new one
    // only if it is in neither the shared table nor this edit's additions.
    fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID {
        let (num_committed, found) = {
            let committed = self.editor.dynamic_constants.borrow();
            let found = committed
                .iter()
                .chain(self.added_dynamic_constants.iter())
                .position(|c| *c == dc);
            (committed.len(), found)
        };
        let idx = match found {
            Some(idx) => idx,
            None => {
                let idx = num_committed + self.added_dynamic_constants.len();
                self.added_dynamic_constants.push(dc);
                idx
            }
        };
        DynamicConstantID::new(idx)
    }
}
#[cfg(test)]
mod editor_tests {
#[allow(unused_imports)]
use super::*;
use std::mem::replace;
use hercules_ir::dataflow::reverse_postorder;
use hercules_ir::parse::parse;
fn canonicalize(function: &mut Function) -> Vec<Option<NodeID>> {
// The reverse postorder traversal from the Start node is a map from new
// index to old ID.
let rev_po = reverse_postorder(&def_use(function));
let num_new_nodes = rev_po.len();
// Construct a map from old ID to new ID.
let mut old_to_new = vec![None; function.nodes.len()];
for (new_idx, old_id) in rev_po.into_iter().enumerate() {
old_to_new[old_id.idx()] = Some(NodeID::new(new_idx));
}
// Move the old nodes before permuting them.
let mut old_nodes = take(&mut function.nodes);
function.nodes = vec![Node::Start; num_new_nodes];
// Permute the old nodes back into the function and fix their uses.
for (old_idx, new_id) in old_to_new.iter().enumerate() {
// Check if this old node is in the canonicalized form.
if let Some(new_id) = new_id {
// Get the old node.
let mut node = replace(&mut old_nodes[old_idx], Node::Start);
// Fix its uses.
for u in get_uses_mut(&mut node).as_mut() {
// Map every use using the old-to-new map. If we try to use
// a node that doesn't have a mapping, then the original IR
// had a node reachable from the start using another node
// not reachable from the start, which is malformed.
**u = old_to_new[u.idx()].unwrap();
}
// Insert the fixed node into its new spot.
function.nodes[new_id.idx()] = node;
}
}
old_to_new
}
// End-to-end check of a single edit: replace the add node of a small function
// with a multiply node, then compare the canonicalized result against the
// same function written with a multiply in the first place.
#[test]
fn example1() {
// Define the original function.
let mut src_module = parse(
"
fn func(x: i32) -> i32
c = constant(i32, 7)
y = add(x, c)
r = return(start, y)
",
)
.unwrap();
// Find the ID of the add node and its uses.
let func = &mut src_module.functions[0];
let (add, left, right) = func
.nodes
.iter()
.enumerate()
.filter_map(|(idx, node)| {
node.try_binary(BinaryOperator::Add)
.map(|(left, right)| (NodeID::new(idx), left, right))
})
.next()
.unwrap();
// Move the module-level tables into RefCells, as the editor expects.
let constants_ref = RefCell::new(src_module.constants);
let dynamic_constants_ref = RefCell::new(src_module.dynamic_constants);
let types_ref = RefCell::new(src_module.types);
let labels_ref = RefCell::new(src_module.labels);
// Edit the function by replacing the add with a multiply.
let mut editor = FunctionEditor::new(
func,
FunctionID::new(0),
&constants_ref,
&dynamic_constants_ref,
&types_ref,
&labels_ref,
&def_use(func),
);
let success = editor.edit(|mut edit| {
// Add the multiply, redirect the add's users to it, then delete the add,
// which has no users left at that point.
let mul = edit.add_node(Node::Binary {
op: BinaryOperator::Mul,
left,
right,
});
let edit = edit.replace_all_uses(add, mul)?;
let edit = edit.delete_node(add)?;
Ok(edit)
});
assert!(success);
// Canonicalize the function.
canonicalize(func);
// Check that the function is correct.
let mut dst_module = parse(
"
fn func(x: i32) -> i32
c = constant(i32, 7)
y = mul(x, c)
r = return(start, y)
",
)
.unwrap();
canonicalize(&mut dst_module.functions[0]);
assert_eq!(src_module.functions[0].nodes, dst_module.functions[0].nodes);
}
}