Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • llvm/hercules
1 result
Show changes
Commits on Source (4)
Showing changed files with 669 additions and 179 deletions
# Highlight juno source files like they're rust source files
*.jn gitlab-language=rust
......@@ -23,6 +23,12 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1"
[[package]]
name = "allocator-api2"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
[[package]]
name = "anstream"
version = "0.6.18"
......@@ -621,6 +627,27 @@ version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "egg"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abb749745461743bb477fba3ef87c663d5965876155c676c9489cfe0963de5ab"
dependencies = [
"env_logger",
"hashbrown",
"indexmap",
"log",
"num-bigint",
"num-traits",
"quanta",
"rustc-hash",
"saturating",
"smallvec",
"symbol_table",
"symbolic_expressions",
"thiserror",
]
[[package]]
name = "either"
version = "1.13.0"
......@@ -639,6 +666,15 @@ version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
[[package]]
name = "env_logger"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7"
dependencies = [
"log",
]
[[package]]
name = "equivalent"
version = "1.0.1"
......@@ -752,6 +788,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f"
[[package]]
name = "funty"
version = "1.1.0"
......@@ -882,6 +924,11 @@ name = "hashbrown"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash",
]
[[package]]
name = "heapless"
......@@ -955,6 +1002,7 @@ version = "0.1.0"
dependencies = [
"bimap",
"bitvec 1.0.1",
"egg",
"either",
"hercules_cg",
"hercules_ir",
......@@ -1119,6 +1167,31 @@ dependencies = [
"with_builtin_macros",
]
[[package]]
name = "juno_backprop"
version = "0.1.0"
dependencies = [
"async-std",
"clap",
"hercules_rt",
"juno_build",
"nom 6.2.2",
"rand 0.9.0",
"with_builtin_macros",
]
[[package]]
name = "juno_bfs"
version = "0.1.0"
dependencies = [
"async-std",
"clap",
"hercules_rt",
"juno_build",
"nom 6.2.2",
"with_builtin_macros",
]
[[package]]
name = "juno_build"
version = "0.1.0"
......@@ -1150,6 +1223,18 @@ dependencies = [
"with_builtin_macros",
]
[[package]]
name = "juno_cfd"
version = "0.1.0"
dependencies = [
"async-std",
"clap",
"hercules_rt",
"juno_build",
"nom 6.2.2",
"with_builtin_macros",
]
[[package]]
name = "juno_concat"
version = "0.1.0"
......@@ -1177,7 +1262,7 @@ dependencies = [
"async-std",
"hercules_rt",
"juno_build",
"rand 0.8.5",
"rand 0.9.0",
"with_builtin_macros",
]
......@@ -1321,6 +1406,18 @@ dependencies = [
"with_builtin_macros",
]
[[package]]
name = "juno_srad"
version = "0.1.0"
dependencies = [
"async-std",
"clap",
"hercules_rt",
"juno_build",
"nom 6.2.2",
"with_builtin_macros",
]
[[package]]
name = "juno_utils"
version = "0.1.0"
......@@ -1921,6 +2018,21 @@ dependencies = [
"bytemuck",
]
[[package]]
name = "quanta"
version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
[[package]]
name = "quick-error"
version = "2.0.1"
......@@ -2061,6 +2173,15 @@ dependencies = [
"rgb",
]
[[package]]
name = "raw-cpuid"
version = "11.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "529468c1335c1c03919960dfefdb1b3648858c20d7ec2d0663e728e4a717efbc"
dependencies = [
"bitflags 2.8.0",
]
[[package]]
name = "rayon"
version = "1.10.0"
......@@ -2119,6 +2240,12 @@ version = "0.8.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
version = "0.4.1"
......@@ -2153,6 +2280,12 @@ version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
[[package]]
name = "saturating"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ece8e78b2f38ec51c51f5d475df0a7187ba5111b2a28bdc761ee05b075d40a71"
[[package]]
name = "scopeguard"
version = "1.2.0"
......@@ -2284,6 +2417,23 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "symbol_table"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f19bffd69fb182e684d14e3c71d04c0ef33d1641ac0b9e81c712c734e83703bc"
dependencies = [
"crossbeam-utils",
"foldhash",
"hashbrown",
]
[[package]]
name = "symbolic_expressions"
version = "5.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c68d531d83ec6c531150584c42a4290911964d5f0d79132b193b67252a23b71"
[[package]]
name = "syn"
version = "1.0.109"
......@@ -2648,6 +2798,28 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows"
version = "0.59.0"
......
......@@ -28,6 +28,10 @@ members = [
"juno_samples/multi_device",
"juno_samples/patterns",
"juno_samples/products",
"juno_samples/rodinia/backprop",
"juno_samples/rodinia/bfs",
"juno_samples/rodinia/cfd",
"juno_samples/rodinia/srad",
"juno_samples/schedule_test",
"juno_samples/simple3",
"juno_scheduler",
......
......@@ -964,7 +964,17 @@ fn convert_type(ty: &Type) -> &'static str {
fn convert_intrinsic(intrinsic: &Intrinsic, ty: &Type) -> String {
let intrinsic = match intrinsic {
Intrinsic::Abs => "abs",
Intrinsic::Abs => {
if ty.is_float() {
"fabs"
} else if ty.is_signed() {
"abs"
} else if ty.is_unsigned() {
panic!("llvm doesn't define abs for unsigned integers")
} else {
panic!()
}
},
Intrinsic::ACos => "acos",
Intrinsic::ASin => "asin",
Intrinsic::ATan => "atan",
......
......@@ -2004,7 +2004,7 @@ extern \"C\" {} {}(",
fn codegen_intrinsic(&self, intrinsic: &Intrinsic, ty: &Type) -> String {
let func_name = match intrinsic {
Intrinsic::Abs => match ty {
Type::Float32 => "__fabsf",
Type::Float32 => "fabsf",
Type::Float64 => "__fabs",
ty if ty.is_signed() => "abs",
ty if ty.is_unsigned() => "uabs",
......
......@@ -597,6 +597,59 @@ impl<'a> RTContext<'a> {
}
write!(block, "){};", postfix)?;
}
Node::LibraryCall {
library_function,
ref args,
ty,
device,
} => match library_function {
LibraryFunction::GEMM => {
assert_eq!(args.len(), 3);
assert_eq!(self.typing[args[0].idx()], ty);
let c_ty = &self.module.types[self.typing[args[0].idx()].idx()];
let a_ty = &self.module.types[self.typing[args[1].idx()].idx()];
let b_ty = &self.module.types[self.typing[args[2].idx()].idx()];
let (
Type::Array(c_elem, c_dims),
Type::Array(a_elem, a_dims),
Type::Array(b_elem, b_dims),
) = (c_ty, a_ty, b_ty)
else {
panic!();
};
assert_eq!(a_elem, b_elem);
assert_eq!(a_elem, c_elem);
assert_eq!(c_dims.len(), 2);
assert_eq!(a_dims.len(), 2);
assert_eq!(b_dims.len(), 2);
assert_eq!(a_dims[1], b_dims[0]);
assert_eq!(a_dims[0], c_dims[0]);
assert_eq!(b_dims[1], c_dims[1]);
let block = &mut blocks.get_mut(&bb).unwrap().data;
let prim_ty = self.library_prim_ty(*a_elem);
write!(block, "::hercules_rt::__library_{}_gemm(", device.name())?;
self.codegen_dynamic_constant(a_dims[0], block)?;
write!(block, ", ")?;
self.codegen_dynamic_constant(a_dims[1], block)?;
write!(block, ", ")?;
self.codegen_dynamic_constant(b_dims[1], block)?;
write!(
block,
", {}.0, {}.0, {}.0, {});",
self.get_value(args[0], bb, false),
self.get_value(args[1], bb, false),
self.get_value(args[2], bb, false),
prim_ty
)?;
write!(
block,
"{} = {};",
self.get_value(id, bb, true),
self.get_value(args[0], bb, false)
)?;
}
},
Node::Unary { op, input } => {
let block = &mut blocks.get_mut(&bb).unwrap().data;
match op {
......@@ -1316,6 +1369,25 @@ impl<'a> RTContext<'a> {
/// Return the Rust source spelling of the Hercules IR type `id`, for
/// splicing into generated RT code. Delegates to the module-level
/// `convert_type` helper on the resolved `Type`.
fn get_type(&self, id: TypeID) -> &'static str {
    convert_type(&self.module.types[id.idx()])
}
/// Map a primitive Hercules IR type to the textual path of the corresponding
/// `::hercules_rt::PrimTy` variant, for emission into generated Rust code
/// (e.g. the element-type argument of library call wrappers such as GEMM).
///
/// Panics if `id` does not refer to a primitive type — callers are expected
/// to have already checked that the element type is primitive.
fn library_prim_ty(&self, id: TypeID) -> &'static str {
    match self.module.types[id.idx()] {
        Type::Boolean => "::hercules_rt::PrimTy::Bool",
        Type::Integer8 => "::hercules_rt::PrimTy::I8",
        Type::Integer16 => "::hercules_rt::PrimTy::I16",
        Type::Integer32 => "::hercules_rt::PrimTy::I32",
        Type::Integer64 => "::hercules_rt::PrimTy::I64",
        Type::UnsignedInteger8 => "::hercules_rt::PrimTy::U8",
        Type::UnsignedInteger16 => "::hercules_rt::PrimTy::U16",
        Type::UnsignedInteger32 => "::hercules_rt::PrimTy::U32",
        Type::UnsignedInteger64 => "::hercules_rt::PrimTy::U64",
        Type::Float8 => "::hercules_rt::PrimTy::F8",
        Type::BFloat16 => "::hercules_rt::PrimTy::BF16",
        Type::Float32 => "::hercules_rt::PrimTy::F32",
        Type::Float64 => "::hercules_rt::PrimTy::F64",
        // Aggregates (products, summations, arrays) have no PrimTy.
        _ => panic!(),
    }
}
}
fn convert_type(ty: &Type) -> &'static str {
......
......@@ -218,6 +218,8 @@ pub fn collection_objects(
// - Constant: may originate an object.
// - Call: may originate an object and may return an object passed in as
// a parameter.
// - LibraryCall: may return an object passed in as a parameter, but may
// not originate an object.
// - Read: may extract a smaller object from the input - this is
// considered to be the same object as the input, as no copy takes
// place.
......@@ -288,6 +290,14 @@ pub fn collection_objects(
}
CollectionObjectLattice { objs }
}
Node::LibraryCall {
library_function,
args: _,
ty: _,
device: _,
} => match library_function {
LibraryFunction::GEMM => inputs[0].clone(),
},
Node::Undef { ty: _ } => {
let obj = origins
.iter()
......@@ -332,7 +342,13 @@ pub fn collection_objects(
for object in objects_per_node[idx].iter() {
mutated[object.idx()].push(NodeID::new(idx));
}
} else if let Some((_, callee, _, args)) = node.try_call() {
} else if let Node::Call {
control: _,
function: callee,
dynamic_constants: _,
args,
} = node
{
let fco = &collection_objects[&callee];
for (param_idx, arg) in args.into_iter().enumerate() {
// If this parameter corresponds to an object and it's
......@@ -347,6 +363,20 @@ pub fn collection_objects(
}
}
}
} else if let Node::LibraryCall {
library_function,
args,
ty: _,
device: _,
} = node
{
match library_function {
LibraryFunction::GEMM => {
for object in objects_per_node[args[0].idx()].iter() {
mutated[object.idx()].push(NodeID::new(idx));
}
}
}
}
}
......
......@@ -178,7 +178,13 @@ pub fn get_uses(node: &Node) -> NodeUses {
uses.extend(args);
NodeUses::Variable(uses.into_boxed_slice())
}
Node::IntrinsicCall { intrinsic: _, args } => NodeUses::Variable(args.clone()),
Node::IntrinsicCall { intrinsic: _, args }
| Node::LibraryCall {
library_function: _,
args,
ty: _,
device: _,
} => NodeUses::Variable(args.clone()),
Node::Read { collect, indices } => {
let mut uses = vec![];
for index in indices.iter() {
......@@ -276,9 +282,13 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> {
uses.extend(args);
NodeUsesMut::Variable(uses.into_boxed_slice())
}
Node::IntrinsicCall { intrinsic: _, args } => {
NodeUsesMut::Variable(args.iter_mut().collect())
}
Node::IntrinsicCall { intrinsic: _, args }
| Node::LibraryCall {
library_function: _,
args,
ty: _,
device: _,
} => NodeUsesMut::Variable(args.iter_mut().collect()),
Node::Read { collect, indices } => {
let mut uses = vec![];
for index in indices.iter_mut() {
......
......@@ -23,118 +23,3 @@ pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec
devices
}
pub type FunctionObjectDeviceDemands = Vec<BTreeSet<Device>>;
pub type ObjectDeviceDemands = Vec<FunctionObjectDeviceDemands>;
/*
* This analysis figures out which device each collection object may be on. At
* first, an object may need to be on different devices at different times. This
* is fine during optimization.
*/
/// Compute, for every function and every collection object in it, the set of
/// devices that object is "demanded" on (see the condition list below).
///
/// Functions are visited in reverse topological order over the call graph so
/// that each callee's demands are known before any of its callers are
/// analyzed (conditions #3 and #4 propagate demands caller-ward).
///
/// Returns one `FunctionObjectDeviceDemands` per function, indexed by
/// function index then object index.
pub fn object_device_demands(
    functions: &Vec<Function>,
    types: &Vec<Type>,
    typing: &ModuleTyping,
    callgraph: &CallGraph,
    objects: &CollectionObjects,
    devices: &Vec<Device>,
) -> ObjectDeviceDemands {
    // An object is "demanded" on a device when:
    // 1. The object is used by a primitive read node or write node in a device
    //    function. This includes objects on the `data` input to write nodes.
    //    Non-primitive reads don't demand an object on a device since they are
    //    lowered to pointer math and no actual memory transfers.
    // 2. The object is a constant / undef defined in a device function.
    // 3. The object is passed as input to a call node where the corresponding
    //    object in the callee is demanded on a device.
    // 4. The object is returned from a call node where the corresponding object
    //    in the callee is demanded on a device.
    // Note that reads and writes in a RT function don't induce a device demand.
    // This is because RT functions can call device functions as necessary to
    // arbitrarily move data onto / off of devices (though this may be slow).

    // Traverse the functions in a module in reverse topological order, since
    // the analysis of a function depends on all functions it calls.
    let mut demands: ObjectDeviceDemands = vec![vec![]; functions.len()];
    let topo = callgraph.topo();
    for func_id in topo {
        let function = &functions[func_id.idx()];
        let typing = &typing[func_id.idx()];
        let device = devices[func_id.idx()];
        // One demand set per collection object in this function.
        demands[func_id.idx()].resize(objects[&func_id].num_objects(), BTreeSet::new());
        match device {
            Device::LLVM | Device::CUDA => {
                for (idx, node) in function.nodes.iter().enumerate() {
                    match node {
                        // Condition #1.
                        Node::Read {
                            collect,
                            indices: _,
                        } if types[typing[idx].idx()].is_primitive() => {
                            for object in objects[&func_id].objects(*collect) {
                                demands[func_id.idx()][object.idx()].insert(device);
                            }
                        }
                        Node::Write {
                            collect,
                            data,
                            indices: _,
                        } => {
                            // Both the written collection and the data input
                            // (which may itself be a collection) are demanded.
                            for object in objects[&func_id]
                                .objects(*collect)
                                .into_iter()
                                .chain(objects[&func_id].objects(*data).into_iter())
                            {
                                demands[func_id.idx()][object.idx()].insert(device);
                            }
                        }
                        // Condition #2.
                        Node::Constant { id: _ } | Node::Undef { ty: _ } => {
                            for object in objects[&func_id].objects(NodeID::new(idx)) {
                                demands[func_id.idx()][object.idx()].insert(device);
                            }
                        }
                        _ => {}
                    }
                }
            }
            Device::AsyncRust => {
                for (idx, node) in function.nodes.iter().enumerate() {
                    if let Node::Call {
                        control: _,
                        function: callee,
                        dynamic_constants: _,
                        args,
                    } = node
                    {
                        // Condition #3.
                        for (param_idx, arg) in args.into_iter().enumerate() {
                            if let Some(callee_obj) = objects[callee].param_to_object(param_idx) {
                                // Temporarily `take` the callee's demand set to
                                // sidestep a double mutable borrow of `demands`;
                                // it is restored below unchanged.
                                let callee_demands =
                                    take(&mut demands[callee.idx()][callee_obj.idx()]);
                                for object in objects[&func_id].objects(*arg) {
                                    demands[func_id.idx()][object.idx()]
                                        .extend(callee_demands.iter());
                                }
                                demands[callee.idx()][callee_obj.idx()] = callee_demands;
                            }
                        }
                        // Condition #4.
                        for callee_obj in objects[callee].returned_objects() {
                            let callee_demands = take(&mut demands[callee.idx()][callee_obj.idx()]);
                            for object in objects[&func_id].objects(NodeID::new(idx)) {
                                demands[func_id.idx()][object.idx()].extend(callee_demands.iter());
                            }
                            demands[callee.idx()][callee_obj.idx()] = callee_demands;
                        }
                    }
                }
            }
        }
    }
    demands
}
......@@ -318,6 +318,12 @@ fn write_node<W: Write>(
Node::IntrinsicCall { intrinsic, args: _ } => {
write!(&mut suffix, "{}", intrinsic.lower_case_name())?
}
Node::LibraryCall {
library_function,
args: _,
ty: _,
device,
} => write!(&mut suffix, "{:?} on {:?}", library_function, device)?,
Node::Read {
collect: _,
indices,
......
......@@ -222,6 +222,12 @@ pub enum Node {
intrinsic: Intrinsic,
args: Box<[NodeID]>,
},
LibraryCall {
library_function: LibraryFunction,
args: Box<[NodeID]>,
ty: TypeID,
device: Device,
},
Read {
collect: NodeID,
indices: Box<[Index]>,
......@@ -336,7 +342,7 @@ pub enum Schedule {
* The authoritative enumeration of supported backends. Multiple backends may
* correspond to the same kind of hardware.
*/
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum Device {
LLVM,
CUDA,
......@@ -345,6 +351,14 @@ pub enum Device {
AsyncRust,
}
/*
* The authoritative enumeration of supported library calls.
*/
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum LibraryFunction {
GEMM,
}
/*
* A single node may have multiple schedules.
*/
......@@ -1531,6 +1545,12 @@ impl Node {
intrinsic: _,
args: _,
} => "Intrinsic",
Node::LibraryCall {
library_function: _,
args: _,
ty: _,
device: _,
} => "Library",
Node::Read {
collect: _,
indices: _,
......@@ -1604,6 +1624,12 @@ impl Node {
intrinsic: _,
args: _,
} => "intrinsic",
Node::LibraryCall {
library_function: _,
args: _,
ty: _,
device: _,
} => "library",
Node::Read {
collect: _,
indices: _,
......
......@@ -935,6 +935,12 @@ fn typeflow(
}
}
}
Node::LibraryCall {
library_function: _,
args: _,
ty,
device: _,
} => Concrete(*ty),
Node::Read {
collect: _,
indices,
......
......@@ -21,4 +21,5 @@ serde = { version = "*", features = ["derive"] }
hercules_cg = { path = "../hercules_cg" }
hercules_ir = { path = "../hercules_ir" }
nestify = "*"
bimap = "*"
\ No newline at end of file
bimap = "*"
egg = "*"
......@@ -933,6 +933,17 @@ fn ccp_flow_function(
constant: new_constant,
}
}
Node::LibraryCall {
library_function: _,
args,
ty: _,
device: _,
} => CCPLattice {
reachability: args.iter().fold(ReachabilityLattice::bottom(), |val, id| {
ReachabilityLattice::join(&val, &inputs[id.idx()].reachability)
}),
constant: ConstantLattice::bottom(),
},
Node::Read { collect, indices } => {
let mut reachability = inputs[collect.idx()].reachability.clone();
for index in indices.iter() {
......
......@@ -655,6 +655,14 @@ fn terminating_reads<'a>(
None
}
})),
Node::LibraryCall {
library_function,
ref args,
ty: _,
device: _,
} => match library_function {
LibraryFunction::GEMM => Box::new(once(args[1]).chain(once(args[2]))),
},
_ => Box::new(empty()),
}
}
......@@ -728,6 +736,16 @@ fn mutating_objects<'a>(
})
.flatten(),
),
Node::LibraryCall {
library_function,
ref args,
ty: _,
device: _,
} => match library_function {
LibraryFunction::GEMM => {
Box::new(objects[&func_id].objects(args[0]).into_iter().map(|id| *id))
}
},
_ => Box::new(empty()),
}
}
......@@ -757,6 +775,14 @@ fn mutating_writes<'a>(
None
}
})),
Node::LibraryCall {
library_function,
ref args,
ty: _,
device: _,
} => match library_function {
LibraryFunction::GEMM => Box::new(once(args[0])),
},
_ => Box::new(empty()),
}
}
......@@ -1311,6 +1337,17 @@ fn color_nodes(
}
}
}
Node::LibraryCall {
library_function: _,
ref args,
ty: _,
device,
} => {
for arg in args {
equations.push((UTerm::Node(*arg), UTerm::Device(device)));
}
equations.push((UTerm::Node(id), UTerm::Device(device)));
}
_ => {}
}
}
......
......@@ -21,6 +21,7 @@ pub mod outline;
pub mod phi_elim;
pub mod pred;
pub mod reuse_products;
pub mod rewrite_math_expressions;
pub mod schedule;
pub mod simplify_cfg;
pub mod slf;
......@@ -49,6 +50,7 @@ pub use crate::outline::*;
pub use crate::phi_elim::*;
pub use crate::pred::*;
pub use crate::reuse_products::*;
pub use crate::rewrite_math_expressions::*;
pub use crate::schedule::*;
pub use crate::simplify_cfg::*;
pub use crate::slf::*;
......
......@@ -180,6 +180,13 @@ pub fn outline(
editor.edit(|mut edit| {
// Step 2: assemble the outlined function.
let u32_ty = edit.add_type(Type::UnsignedInteger32);
let return_types: Box<[_]> = return_idx_to_inside_id
.iter()
.map(|id| typing[id.idx()])
.chain(callee_succ_return_idx.map(|_| u32_ty))
.collect();
let single_return = return_types.len() == 1;
let mut outlined = Function {
name: format!(
"{}_{}",
......@@ -191,13 +198,11 @@ pub fn outline(
.map(|id| typing[id.idx()])
.chain(callee_pred_param_idx.map(|_| u32_ty))
.collect(),
return_type: edit.add_type(Type::Product(
return_idx_to_inside_id
.iter()
.map(|id| typing[id.idx()])
.chain(callee_succ_return_idx.map(|_| u32_ty))
.collect(),
)),
return_type: if single_return {
return_types[0]
} else {
edit.add_type(Type::Product(return_types))
},
num_dynamic_constants: edit.get_num_dynamic_constant_params(),
entry: false,
nodes: vec![],
......@@ -393,18 +398,24 @@ pub fn outline(
data_ids.push(cons_node_id);
}
// Build the return product.
let mut construct_id = NodeID::new(outlined.nodes.len());
outlined.nodes.push(Node::Constant { id: cons_id });
for (idx, data) in data_ids.into_iter().enumerate() {
let write = Node::Write {
collect: construct_id,
data: data,
indices: Box::new([Index::Field(idx)]),
};
construct_id = NodeID::new(outlined.nodes.len());
outlined.nodes.push(write);
}
// Build the return value
let construct_id = if single_return {
assert!(data_ids.len() == 1);
data_ids.pop().unwrap()
} else {
let mut construct_id = NodeID::new(outlined.nodes.len());
outlined.nodes.push(Node::Constant { id: cons_id });
for (idx, data) in data_ids.into_iter().enumerate() {
let write = Node::Write {
collect: construct_id,
data: data,
indices: Box::new([Index::Field(idx)]),
};
construct_id = NodeID::new(outlined.nodes.len());
outlined.nodes.push(write);
}
construct_id
};
// Return the return product.
outlined.nodes.push(Node::Return {
......@@ -505,16 +516,20 @@ pub fn outline(
};
// Create the read nodes from the call node to get the outputs of the
// outlined function.
let output_reads: Vec<_> = (0..return_idx_to_inside_id.len())
.map(|idx| {
let read = Node::Read {
collect: call_id,
indices: Box::new([Index::Field(idx)]),
};
edit.add_node(read)
})
.collect();
// outlined function (if there are multiple returned values)
let output_reads: Vec<_> = if single_return {
vec![call_id]
} else {
(0..return_idx_to_inside_id.len())
.map(|idx| {
let read = Node::Read {
collect: call_id,
indices: Box::new([Index::Field(idx)]),
};
edit.add_node(read)
})
.collect()
};
let indicator_read = callee_succ_return_idx.map(|idx| {
let read = Node::Read {
collect: call_id,
......
use std::collections::{HashMap, HashSet};
use std::fmt::{Error, Write};
use hercules_ir::*;
use egg::*;
use crate::*;
// The egg term language used to pattern-match Hercules einsum (MathExpr)
// trees against known library routines. Expressions are produced by
// `egg_print_math_expr` and parsed back by egg's s-expression parser, so the
// operator spellings here must match that printer exactly.
define_language! {
    enum MathLanguage {
        // Additive / multiplicative identity constants.
        "zero" = Zero,
        "one" = One,
        // A bare integer atom naming a fork dimension.
        ForkDim(i64),
        // Thread ID along a fork dimension.
        "tid" = ThreadID(Id),
        // (sum dim... expr): reduction over the listed fork dimensions.
        "sum" = SumReduction(Box<[Id]>),
        // (array dim... expr): array comprehension over the listed dimensions.
        "array" = Comprehension(Box<[Id]>),
        // (read pos... collection): indexed read from a collection.
        "read" = Read(Box<[Id]>),
        "+" = Add([Id; 2]),
        "*" = Mul([Id; 2]),
        // Folded form recognized by the GEMM rewrite rule: (library_gemm A B).
        "library_gemm" = LibraryGemm([Id; 2]),
        // Escape hatch for arbitrary IR nodes, printed as symbols like "n42".
        Opaque(Symbol),
    }
}
/// Build the rewrite rules applied during equality saturation over
/// `MathLanguage` expressions: a few basic arithmetic identities plus a
/// pattern that recognizes a matrix-matrix multiply and collapses it into a
/// single `library_gemm` operator.
fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
    let mut rules: Vec<Rewrite<MathLanguage, ()>> = Vec::with_capacity(4);
    // Arithmetic identities that help normalize expressions before matching.
    rules.push(rewrite!("add-zero"; "(+ zero ?a)" => "?a"));
    rules.push(rewrite!("mul-zero"; "(* zero ?a)" => "zero"));
    rules.push(rewrite!("mul-one"; "(* one ?a)" => "?a"));
    // Recognize C[i,k] = sum_j A[i,j] * B[j,k] as a GEMM library call.
    rules.push(rewrite!("library-gemm"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_gemm ?A ?B)"));
    rules
}
/// Rewrite reduction math expressions in `editor`'s function into library
/// calls (currently just GEMM) targeting `device`.
///
/// Each reduce node's einsum expression (from `reduce_einsums`) is printed
/// into an egg `MathLanguage` term, saturated with the rules from
/// `make_rules`, and matched against known library patterns. On a match, a
/// `Node::LibraryCall` is inserted and uses of the reduce outside its
/// fork-join are redirected to it.
///
/// At most ONE rewrite is performed per invocation: the function returns as
/// soon as an edit succeeds, so callers should re-run it to fixpoint.
pub fn rewrite_math_expressions(
    editor: &mut FunctionEditor,
    device: Device,
    typing: &Vec<TypeID>,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
) {
    // Invert the fork->join map so the fork can be looked up from a reduce's
    // join control input.
    let join_fork_map: HashMap<_, _> = fork_join_map
        .into_iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();

    // Step 1: figure out how many fork-joins each reduce is in. We want to
    // rewrite outer reductions before inner ones.
    let mut depth: HashMap<NodeID, u32> = HashMap::new();
    for (_, inside) in nodes_in_fork_joins {
        for id in inside {
            if editor.func().nodes[id.idx()].is_reduce() {
                *depth.entry(*id).or_default() += 1;
            }
        }
    }
    // Sorting (depth, id) pairs visits shallower (outer) reduces first.
    let mut reduces: Vec<(u32, NodeID)> =
        depth.into_iter().map(|(id, depth)| (depth, id)).collect();
    reduces.sort();

    for (_, reduce) in reduces {
        let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap();
        let fork = join_fork_map[&join];

        // Step 2: convert the reduce to an expression in egg and rewrite it.
        let mut s = String::new();
        egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap();
        let expr: RecExpr<MathLanguage> = s.parse().unwrap();
        let egraph = Runner::default().with_expr(&expr).run(&make_rules()).egraph;

        // Step 3: match the smallest expression against patterns for known
        // library functions.
        let gemm_pattern: Pattern<MathLanguage> = "(library_gemm ?A ?B)".parse().unwrap();
        let mut matches = gemm_pattern.search(&egraph);
        // NOTE: `let m = matches.remove(0)` is an irrefutable binding inside
        // the condition chain; the guards before/after it do the filtering.
        if matches.len() > 0
            && let m = matches.remove(0)
            && m.substs.len() > 0
        {
            // Recover the IR nodes bound to the ?A / ?B pattern variables.
            let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph);
            let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph);
            let call = Node::LibraryCall {
                library_function: LibraryFunction::GEMM,
                args: vec![init, left_id, right_id].into_boxed_slice(),
                ty: typing[reduce.idx()],
                device,
            };
            // Only uses outside the fork-join are redirected; uses inside
            // (e.g. the reduce's own cycle) must keep referring to the reduce.
            let success = editor.edit(|mut edit| {
                let call = edit.add_node(call);
                edit.replace_all_uses_where(reduce, call, |id| {
                    !nodes_in_fork_joins[&fork].contains(id)
                })
            });
            if success {
                return;
            }
        }
    }
}
/// Recover the Hercules IR `NodeID` bound to the egg pattern variable `var`
/// (e.g. `"?A"`) in a substitution produced by a pattern match.
///
/// Opaque IR nodes are printed into egg expressions as symbols of the form
/// `n<idx>` (see `egg_print_math_expr`), so the e-class bound to `var` must
/// resolve to such an `Opaque` symbol; its numeric suffix is the node index.
///
/// Panics if `var` is not a valid egg variable or is unbound in `subst`, and
/// if the bound symbol is not of the `n<idx>` form.
fn get_node_ids_from_subst(var: &str, subst: &Subst, egraph: &EGraph<MathLanguage, ()>) -> NodeID {
    let id = *subst.get(var.parse::<Var>().unwrap()).unwrap();
    let expr = egraph.id_to_expr(id);
    // Only opaque-node bindings are supported; other shapes (e.g. a pattern
    // variable bound to a compound expression) are not handled yet.
    let MathLanguage::Opaque(sym) = expr.last().unwrap() else {
        todo!();
    };
    // Opaque symbols are always emitted as "n" followed by the node index.
    let idx_str = sym
        .as_str()
        .strip_prefix('n')
        .expect("opaque egg symbol must start with 'n'");
    NodeID::new(
        idx_str
            .parse()
            .expect("opaque egg symbol suffix must be a node index"),
    )
}
/// Print the math expression `id` from environment `env` as an s-expression
/// that egg's `MathLanguage` parser accepts (see `define_language!` above).
///
/// The output format must round-trip exactly through that parser, so the
/// operator spellings here mirror the language definition. Opaque IR nodes
/// become `n<idx>` symbols, which `get_node_ids_from_subst` later decodes.
///
/// Returns `Err` for binary operators other than add/mul and for any
/// `MathExpr` variant with no `MathLanguage` counterpart.
fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> {
    match env[id.idx()] {
        MathExpr::Zero(_) => write!(w, "zero"),
        MathExpr::One(_) => write!(w, "one"),
        // Arbitrary IR node: printed as a symbol carrying its node index.
        MathExpr::OpaqueNode(id) => write!(w, "n{}", id.idx()),
        // NOTE(review): prints only the bare dimension number, which parses
        // as a `ForkDim` atom rather than a `(tid ...)` term — presumably
        // intentional so dimension positions unify in patterns; confirm.
        MathExpr::ThreadID(dim) => write!(w, "{}", dim.0),
        MathExpr::SumReduction(id, ref dims) => {
            // (sum dim... expr)
            write!(w, "(sum")?;
            for dim in dims {
                write!(w, " {}", dim.0)?;
            }
            write!(w, " ")?;
            egg_print_math_expr(id, env, w)?;
            write!(w, ")")
        }
        MathExpr::Comprehension(id, ref dims) => {
            // (array dim... expr)
            write!(w, "(array")?;
            for dim in dims {
                write!(w, " {}", dim.0)?;
            }
            write!(w, " ")?;
            egg_print_math_expr(id, env, w)?;
            write!(w, ")")
        }
        MathExpr::Read(id, ref pos) => {
            // (read pos... collection) — positions first, collection last.
            write!(w, "(read")?;
            for pos in pos {
                write!(w, " ")?;
                egg_print_math_expr(*pos, env, w)?;
            }
            write!(w, " ")?;
            egg_print_math_expr(id, env, w)?;
            write!(w, ")")
        }
        MathExpr::Binary(op, left, right) => {
            write!(w, "(")?;
            match op {
                BinaryOperator::Add => write!(w, "+ "),
                BinaryOperator::Mul => write!(w, "* "),
                // Other binary ops have no MathLanguage spelling.
                _ => Err(Error::default()),
            }?;
            egg_print_math_expr(left, env, w)?;
            write!(w, " ")?;
            egg_print_math_expr(right, env, w)?;
            write!(w, ")")
        }
        // Remaining MathExpr variants are not representable in MathLanguage.
        _ => Err(Error::default()),
    }
}
......@@ -126,9 +126,7 @@ fn remove_useless_fork_joins(
// Third, get rid of fork-joins.
for (fork, join) in fork_join_map {
if editor.get_users(*join).len() == 1 {
assert_eq!(editor.get_users(*fork).len(), 1);
if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 {
let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0];
let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0];
......
......@@ -38,14 +38,33 @@ use crate::*;
*
* - Write: the write node writes primitive fields in product values - these get
* replaced by a direct def of the field value
*
* The allow_sroa_arrays variable controls whether products that contain arrays
* will be broken into pieces. This option is useful to have since breaking
* these products up can be expensive if it requires destructing and
* reconstructing the product at any point.
*
* TODO: Handle partial selections (i.e. immutable nodes). This will involve
* actually tracking each source and use of a product and verifying that all of
* the nodes involved are mutable.
*/
pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) {
pub fn sroa(
editor: &mut FunctionEditor,
reverse_postorder: &Vec<NodeID>,
types: &Vec<TypeID>,
allow_sroa_arrays: bool,
) {
let mut types: HashMap<NodeID, TypeID> = types
.iter()
.enumerate()
.map(|(i, t)| (NodeID::new(i), *t))
.collect();
let can_sroa_type = |editor: &FunctionEditor, typ: TypeID| {
editor.get_type(typ).is_product()
&& (allow_sroa_arrays || !type_contains_array(editor, typ))
};
// This map stores a map from NodeID to an index tree which can be used to lookup the NodeID
// that contains the corresponding fields of the original value
let mut field_map: HashMap<NodeID, IndexTree<NodeID>> = HashMap::new();
......@@ -67,7 +86,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
second: _,
third: _,
op: TernaryOperator::Select,
} if editor.get_type(types[&node]).is_product() => product_nodes.push(*node),
} if can_sroa_type(editor, types[&node]) => product_nodes.push(*node),
Node::Write {
collect,
......@@ -83,19 +102,23 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
let mut fields = vec![];
let mut remainder = vec![];
let mut indices = indices.iter();
while let Some(idx) = indices.next() {
if idx.is_field() {
fields.push(idx.clone());
} else {
remainder.push(idx.clone());
remainder.extend(indices.cloned());
break;
if can_sroa_type(editor, types[&node]) {
let mut indices = indices.iter();
while let Some(idx) = indices.next() {
if idx.is_field() {
fields.push(idx.clone());
} else {
remainder.push(idx.clone());
remainder.extend(indices.cloned());
break;
}
}
} else {
remainder.extend_from_slice(indices);
}
if fields.is_empty() {
if editor.get_type(types[&data]).is_product() {
if can_sroa_type(editor, types[&data]) {
(None, Some((*node, collect, remainder)))
} else {
(None, None)
......@@ -205,9 +228,13 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
// that information to the node map for the rest of SROA (this produces some reads
// that mix types of indices, since we only read leaves but that's okay since those
// reads are not handled by SROA)
let indices = indices
.chunk_by(|i, j| i.is_field() && j.is_field())
.collect::<Vec<_>>();
let indices = if can_sroa_type(editor, types[collect]) {
indices
.chunk_by(|i, j| i.is_field() == j.is_field())
.collect::<Vec<_>>()
} else {
vec![indices.as_ref()]
};
let (field_reads, non_fields_produce_prod) = {
if indices.len() == 0 {
......@@ -217,9 +244,9 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
} else if indices.len() == 1 {
// If once we perform chunking there's only one set of indices, we can just
// use the original node
if indices[0][0].is_field() {
if can_sroa_type(editor, types[collect]) {
(vec![*node], vec![])
} else if editor.get_type(types[node]).is_product() {
} else if can_sroa_type(editor, types[node]) {
(vec![], vec![*node])
} else {
(vec![], vec![])
......@@ -278,7 +305,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
// We add all calls to the call/return list and check their arguments later
Node::Call { .. } => call_return_nodes.push(*node),
Node::Return { control: _, data } if editor.get_type(types[&data]).is_product() => {
Node::Return { control: _, data } if can_sroa_type(editor, types[&data]) => {
call_return_nodes.push(*node)
}
......@@ -296,7 +323,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
for node in call_return_nodes {
match &editor.func().nodes[node.idx()] {
Node::Return { control, data } => {
assert!(editor.get_type(types[&data]).is_product());
assert!(can_sroa_type(editor, types[&data]));
let control = *control;
let new_data = reconstruct_product(editor, types[&data], *data, &mut product_nodes);
editor.edit(|mut edit| {
......@@ -319,8 +346,8 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
let dynamic_constants = dynamic_constants.clone();
let args = args.clone();
// If the call returns a product, we generate reads for each field
let fields = if editor.get_type(types[&node]).is_product() {
// If the call returns a product that we can sroa, we generate reads for each field
let fields = if can_sroa_type(editor, types[&node]) {
Some(generate_reads(editor, types[&node], node))
} else {
None
......@@ -328,7 +355,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
let mut new_args = vec![];
for arg in args {
if editor.get_type(types[&arg]).is_product() {
if can_sroa_type(editor, types[&arg]) {
new_args.push(reconstruct_product(
editor,
types[&arg],
......@@ -489,7 +516,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
indices,
} => {
if let Some(index_map) = field_map.get(collect) {
if editor.get_type(types[&data]).is_product() {
if can_sroa_type(editor, types[&data]) {
if let Some(data_idx) = field_map.get(data) {
field_map.insert(
node,
......@@ -698,6 +725,16 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
});
}
/// Returns whether `typ` is an array type or transitively contains one inside
/// a product or summation. Leaf (primitive) types never contain arrays.
fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool {
    match &*editor.get_type(typ) {
        // An array trivially contains an array.
        Type::Array(_, _) => true,
        // An aggregate contains an array iff any of its element types does.
        Type::Product(elem_tys) | Type::Summation(elem_tys) => elem_tys
            .iter()
            .copied()
            .any(|elem_ty| type_contains_array(editor, elem_ty)),
        // Everything else is a leaf type with no nested arrays.
        _ => false,
    }
}
// An index tree is used to store results at many index lists
#[derive(Clone, Debug)]
pub enum IndexTree<T> {
......