Skip to content
Snippets Groups Projects

Add (dynamic) constants and types to editor

Merged Ryan Ziegler requested to merge ryanjz2/hercules-onnx:editor-types-constants into main
All threads resolved!
Files
2
+ 179
24
@@ -3,6 +3,7 @@ extern crate either;
@@ -3,6 +3,7 @@ extern crate either;
extern crate hercules_ir;
extern crate hercules_ir;
extern crate itertools;
extern crate itertools;
 
use std::cell::{Ref, RefCell};
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
use std::iter::FromIterator;
use std::iter::FromIterator;
use std::mem::take;
use std::mem::take;
@@ -26,12 +27,19 @@ pub type Edit = (HashSet<NodeID>, HashSet<NodeID>);
@@ -26,12 +27,19 @@ pub type Edit = (HashSet<NodeID>, HashSet<NodeID>);
/*
/*
* Helper object for editing Hercules functions in a trackable manner. Edits are
* Helper object for editing Hercules functions in a trackable manner. Edits are
* recorded in order to repair partitions and debug info.
* recorded in order to repair partitions and debug info.
 
* Edits must be made atomically, that is, only one `.edit` may be called at a time
 
* across all editors.
*/
*/
#[derive(Debug)]
#[derive(Debug)]
pub struct FunctionEditor<'a> {
pub struct FunctionEditor<'a> {
// Wraps a mutable reference to a function. Doesn't provide access to this
// Wraps a mutable reference to a function. Doesn't provide access to this
// reference directly, so that we can monitor edits.
// reference directly, so that we can monitor edits.
function: &'a mut Function,
function: &'a mut Function,
 
// Keep a RefCell to (dynamic) constants and types to allow function changes
 
// to update these
 
constants: &'a RefCell<Vec<Constant>>,
 
dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
 
types: &'a RefCell<Vec<Type>>,
// Most optimizations need def use info, so provide an iteratively updated
// Most optimizations need def use info, so provide an iteratively updated
// mutable version that's automatically updated based on recorded edits.
// mutable version that's automatically updated based on recorded edits.
mut_def_use: Vec<HashSet<NodeID>>,
mut_def_use: Vec<HashSet<NodeID>>,
@@ -63,17 +71,28 @@ pub struct FunctionEdit<'a: 'b, 'b> {
@@ -63,17 +71,28 @@ pub struct FunctionEdit<'a: 'b, 'b> {
// Reference the active function editor.
// Reference the active function editor.
editor: &'b mut FunctionEditor<'a>,
editor: &'b mut FunctionEditor<'a>,
// Keep track of deleted node IDs.
// Keep track of deleted node IDs.
deleted: HashSet<NodeID>,
deleted_nodeids: HashSet<NodeID>,
// Keep track of added node IDs.
// Keep track of added node IDs.
added: HashSet<NodeID>,
added_nodeids: HashSet<NodeID>,
// Keep track of added and use updated nodes.
// Keep track of added and use updated nodes.
added_and_updated: BTreeMap<NodeID, Node>,
added_and_updated_nodes: BTreeMap<NodeID, Node>,
 
// Keep track of added (dynamic) constants and types
 
added_constants: RefCell<Vec<Constant>>,
 
added_dynamic_constants: RefCell<Vec<DynamicConstant>>,
 
added_types: RefCell<Vec<Type>>,
// Compute a def-use map entries iteratively.
// Compute a def-use map entries iteratively.
updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>,
updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>,
 
updated_return_type: Option<TypeID>,
}
}
impl<'a: 'b, 'b> FunctionEditor<'a> {
impl<'a: 'b, 'b> FunctionEditor<'a> {
pub fn new(function: &'a mut Function, def_use: &ImmutableDefUseMap) -> Self {
pub fn new(
 
function: &'a mut Function,
 
constants: &'a RefCell<Vec<Constant>>,
 
dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
 
types: &'a RefCell<Vec<Type>>,
 
def_use: &ImmutableDefUseMap,
 
) -> Self {
let mut_def_use = (0..function.nodes.len())
let mut_def_use = (0..function.nodes.len())
.map(|idx| {
.map(|idx| {
def_use
def_use
@@ -87,6 +106,9 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
@@ -87,6 +106,9 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
FunctionEditor {
FunctionEditor {
function,
function,
 
constants,
 
dynamic_constants,
 
types,
mut_def_use,
mut_def_use,
edits: vec![],
edits: vec![],
mutable_nodes,
mutable_nodes,
@@ -100,10 +122,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
@@ -100,10 +122,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
// Create the edit helper struct and perform the edit using it.
// Create the edit helper struct and perform the edit using it.
let edit_obj = FunctionEdit {
let edit_obj = FunctionEdit {
editor: self,
editor: self,
deleted: HashSet::new(),
deleted_nodeids: HashSet::new(),
added: HashSet::new(),
added_nodeids: HashSet::new(),
added_and_updated: BTreeMap::new(),
added_constants: Vec::new().into(),
 
added_dynamic_constants: Vec::new().into(),
 
added_types: Vec::new().into(),
 
added_and_updated_nodes: BTreeMap::new(),
updated_def_use: BTreeMap::new(),
updated_def_use: BTreeMap::new(),
 
updated_return_type: None,
};
};
if let Ok(populated_edit) = edit(edit_obj) {
if let Ok(populated_edit) = edit(edit_obj) {
@@ -111,10 +137,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
@@ -111,10 +137,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
// without modifying immutable nodes.
// without modifying immutable nodes.
let FunctionEdit {
let FunctionEdit {
editor,
editor,
deleted,
deleted_nodeids,
added,
added_nodeids,
added_and_updated,
added_constants,
 
added_dynamic_constants,
 
added_types,
 
added_and_updated_nodes: added_and_updated,
updated_def_use,
updated_def_use,
 
updated_return_type,
} = populated_edit;
} = populated_edit;
// Step 1: update the mutable def use map.
// Step 1: update the mutable def use map.
for (u, new_users) in updated_def_use {
for (u, new_users) in updated_def_use {
@@ -148,14 +178,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
@@ -148,14 +178,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
// Step 3: delete nodes. This is done using "gravestones", where a
// Step 3: delete nodes. This is done using "gravestones", where a
// node other than node ID 0 being a start node is considered a
// node other than node ID 0 being a start node is considered a
// gravestone.
// gravestone.
for id in deleted.iter() {
for id in deleted_nodeids.iter() {
// Check that there are no users of deleted nodes.
// Check that there are no users of deleted nodes.
assert!(editor.mut_def_use[id.idx()].is_empty());
assert!(editor.mut_def_use[id.idx()].is_empty());
editor.function.nodes[id.idx()] = Node::Start;
editor.function.nodes[id.idx()] = Node::Start;
}
}
// Step 4: add a single edit to the edit list.
// Step 4: add a single edit to the edit list.
editor.edits.push((deleted, added));
editor.edits.push((deleted_nodeids, added_nodeids));
// Step 5: update the length of mutable_nodes. All added nodes are
// Step 5: update the length of mutable_nodes. All added nodes are
// mutable.
// mutable.
@@ -163,6 +193,20 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
@@ -163,6 +193,20 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
.mutable_nodes
.mutable_nodes
.resize(editor.function.nodes.len(), true);
.resize(editor.function.nodes.len(), true);
 
// Step 6: update types and constants
 
let mut editor_constants = editor.constants.borrow_mut();
 
let mut editor_dynamic_constants = editor.dynamic_constants.borrow_mut();
 
let mut editor_types = editor.types.borrow_mut();
 
 
editor_constants.extend(added_constants.take());
 
editor_dynamic_constants.extend(added_dynamic_constants.take());
 
editor_types.extend(added_types.take());
 
 
// Step 7: update return type if necessary
 
if let Some(return_type) = updated_return_type {
 
editor.function.return_type = return_type;
 
}
 
true
true
} else {
} else {
false
false
@@ -200,7 +244,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
@@ -200,7 +244,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
}
}
pub fn add_node(&mut self, node: Node) -> NodeID {
pub fn add_node(&mut self, node: Node) -> NodeID {
let id = NodeID::new(self.editor.function.nodes.len() + self.added.len());
let id = NodeID::new(self.editor.function.nodes.len() + self.added_nodeids.len());
// Added nodes need to have an entry in the def-use map.
// Added nodes need to have an entry in the def-use map.
self.updated_def_use.insert(id, HashSet::new());
self.updated_def_use.insert(id, HashSet::new());
// Added nodes use other nodes, and we need to update their def-use
// Added nodes use other nodes, and we need to update their def-use
@@ -210,8 +254,8 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
@@ -210,8 +254,8 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
self.updated_def_use.get_mut(u).unwrap().insert(id);
self.updated_def_use.get_mut(u).unwrap().insert(id);
}
}
// Add the node.
// Add the node.
self.added_and_updated.insert(id, node);
self.added_and_updated_nodes.insert(id, node);
self.added.insert(id);
self.added_nodeids.insert(id);
id
id
}
}
@@ -220,7 +264,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
@@ -220,7 +264,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
// immutable node, as it means the whole edit should be aborted.
// immutable node, as it means the whole edit should be aborted.
if self.editor.mutable_nodes[id.idx()] {
if self.editor.mutable_nodes[id.idx()] {
assert!(
assert!(
!self.added.contains(&id),
!self.added_nodeids.contains(&id),
"PANIC: Please don't delete a node that was added in the same edit."
"PANIC: Please don't delete a node that was added in the same edit."
);
);
// Deleted nodes use other nodes, and we need to update their def-
// Deleted nodes use other nodes, and we need to update their def-
@@ -232,7 +276,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
@@ -232,7 +276,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
self.ensure_updated_def_use_entry(u);
self.ensure_updated_def_use_entry(u);
self.updated_def_use.get_mut(&u).unwrap().remove(&id);
self.updated_def_use.get_mut(&u).unwrap().remove(&id);
}
}
self.deleted.insert(id);
self.deleted_nodeids.insert(id);
Ok(self)
Ok(self)
} else {
} else {
Err(self)
Err(self)
@@ -248,14 +292,14 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
@@ -248,14 +292,14 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
self.ensure_updated_def_use_entry(old);
self.ensure_updated_def_use_entry(old);
for user_id in self.updated_def_use[&old].iter() {
for user_id in self.updated_def_use[&old].iter() {
// Replace uses of old with new.
// Replace uses of old with new.
let mut updated_user = self.node(*user_id).clone();
let mut updated_user = self.get_node(*user_id).clone();
for u in get_uses_mut(&mut updated_user).as_mut() {
for u in get_uses_mut(&mut updated_user).as_mut() {
if **u == old {
if **u == old {
**u = new;
**u = new;
}
}
}
}
// Add the updated user to added_and_updated.
// Add the updated user to added_and_updated.
self.added_and_updated.insert(*user_id, updated_user);
self.added_and_updated_nodes.insert(*user_id, updated_user);
}
}
// All of the users of the old node become users of the new node, so
// All of the users of the old node become users of the new node, so
@@ -273,9 +317,9 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
@@ -273,9 +317,9 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
}
}
}
}
pub fn node(&self, id: NodeID) -> &Node {
pub fn get_node(&self, id: NodeID) -> &Node {
assert!(!self.deleted.contains(&id));
assert!(!self.deleted_nodeids.contains(&id));
if let Some(node) = self.added_and_updated.get(&id) {
if let Some(node) = self.added_and_updated_nodes.get(&id) {
// Refer to added or updated node. This node is guaranteed to be
// Refer to added or updated node. This node is guaranteed to be
// updated with uses after replace_all_uses is called.
// updated with uses after replace_all_uses is called.
node
node
@@ -286,7 +330,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
@@ -286,7 +330,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
}
}
pub fn users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ {
pub fn users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ {
assert!(!self.deleted.contains(&id));
assert!(!self.deleted_nodeids.contains(&id));
if let Some(users) = self.updated_def_use.get(&id) {
if let Some(users) = self.updated_def_use.get(&id) {
// Refer to the updated users set.
// Refer to the updated users set.
Either::Left(users.iter().map(|x| *x))
Either::Left(users.iter().map(|x| *x))
@@ -295,6 +339,108 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
@@ -295,6 +339,108 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
Either::Right(self.editor.mut_def_use[id.idx()].iter().map(|x| *x))
Either::Right(self.editor.mut_def_use[id.idx()].iter().map(|x| *x))
}
}
}
}
 
 
pub fn add_type(&mut self, ty: Type) -> TypeID {
 
let pos = self
 
.editor
 
.types
 
.borrow()
 
.iter()
 
.chain(self.added_types.borrow().iter())
 
.position(|t| *t == ty);
 
if let Some(idx) = pos {
 
TypeID::new(idx)
 
} else {
 
let id =
 
TypeID::new(self.editor.types.borrow().len() + self.added_types.borrow().len());
 
self.added_types.borrow_mut().push(ty);
 
id
 
}
 
}
 
 
pub fn get_type(&self, id: TypeID) -> Ref<'_, Type> {
 
if id.idx() < self.editor.types.borrow().len() {
 
Ref::map(self.editor.types.borrow(), |types| &types[id.idx()])
 
} else {
 
Ref::map(self.added_types.borrow(), |added_types| {
 
&added_types[id.idx() - self.editor.types.borrow().len()]
 
})
 
}
 
}
 
 
pub fn add_constant(&mut self, constant: Constant) -> ConstantID {
 
let pos = self
 
.editor
 
.constants
 
.borrow()
 
.iter()
 
.chain(self.added_constants.borrow().iter())
 
.position(|c| *c == constant);
 
if let Some(idx) = pos {
 
ConstantID::new(idx)
 
} else {
 
let id = ConstantID::new(
 
self.editor.constants.borrow().len() + self.added_constants.borrow().len(),
 
);
 
self.added_constants.borrow_mut().push(constant);
 
id
 
}
 
}
 
 
pub fn get_constant(&self, id: ConstantID) -> Ref<'_, Constant> {
 
if id.idx() < self.editor.constants.borrow().len() {
 
Ref::map(self.editor.constants.borrow(), |constants| {
 
&constants[id.idx()]
 
})
 
} else {
 
Ref::map(self.added_constants.borrow(), |added_constants| {
 
&added_constants[id.idx() - self.editor.constants.borrow().len()]
 
})
 
}
 
}
 
 
pub fn add_dynamic_constant(&mut self, dynamic_constant: DynamicConstant) -> DynamicConstantID {
 
let pos = self
 
.editor
 
.dynamic_constants
 
.borrow()
 
.iter()
 
.chain(self.added_dynamic_constants.borrow().iter())
 
.position(|c| *c == dynamic_constant);
 
if let Some(idx) = pos {
 
DynamicConstantID::new(idx)
 
} else {
 
let id = DynamicConstantID::new(
 
self.editor.dynamic_constants.borrow().len()
 
+ self.added_dynamic_constants.borrow().len(),
 
);
 
self.added_dynamic_constants
 
.borrow_mut()
 
.push(dynamic_constant);
 
id
 
}
 
}
 
 
pub fn get_dynamic_constant(&self, id: DynamicConstantID) -> Ref<'_, DynamicConstant> {
 
if id.idx() < self.editor.dynamic_constants.borrow().len() {
 
Ref::map(
 
self.editor.dynamic_constants.borrow(),
 
|dynamic_constants| &dynamic_constants[id.idx()],
 
)
 
} else {
 
Ref::map(
 
self.added_dynamic_constants.borrow(),
 
|added_dynamic_constants| {
 
&added_dynamic_constants
 
[id.idx() - self.editor.dynamic_constants.borrow().len()]
 
},
 
)
 
}
 
}
 
 
pub fn set_return_type(&mut self, ty: TypeID) {
 
self.updated_return_type = Some(ty);
 
}
}
}
/*
/*
@@ -606,8 +752,17 @@ fn func(x: i32) -> i32
@@ -606,8 +752,17 @@ fn func(x: i32) -> i32
.next()
.next()
.unwrap();
.unwrap();
 
let constants_ref = RefCell::new(src_module.constants);
 
let dynamic_constants_ref = RefCell::new(src_module.dynamic_constants);
 
let types_ref = RefCell::new(src_module.types);
// Edit the function by replacing the add with a multiply.
// Edit the function by replacing the add with a multiply.
let mut editor = FunctionEditor::new(func, &def_use(func));
let mut editor = FunctionEditor::new(
 
func,
 
&constants_ref,
 
&dynamic_constants_ref,
 
&types_ref,
 
&def_use(func),
 
);
let success = editor.edit(|mut edit| {
let success = editor.edit(|mut edit| {
let mul = edit.add_node(Node::Binary {
let mul = edit.add_node(Node::Binary {
op: BinaryOperator::Mul,
op: BinaryOperator::Mul,
Loading