diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 8a41b08d42355fd6ea4d33b769f227e801c57069..cffed48ad91ed1d935cc5a7af1b29a596927eb7a 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -669,10 +669,22 @@ impl<'a> RTContext<'a> { } Node::DataProjection { data, selection } => { let block = &mut blocks.get_mut(&bb).unwrap().data; - write!(block, "{} = {}.{};", - self.get_value(id, bb, true), - self.get_value(data, bb, false), - selection)?; + let Node::Call { function: callee_id, .. } = func.nodes[data.idx()] else { + panic!() + }; + if self.module.functions[callee_id.idx()].return_types.len() == 1 { + assert!(selection == 0); + write!(block, "{} = {};", + self.get_value(id, bb, true), + self.get_value(data, bb, false), + )?; + } else { + write!(block, "{} = {}.{};", + self.get_value(id, bb, true), + self.get_value(data, bb, false), + selection, + )?; + } } Node::LibraryCall { library_function, diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index d61ff6e76716dd4ff4243a41deab0930e6a3dfdb..a5a05d0f0da6169acdbcfc7b6baca538c9205196 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -513,6 +513,8 @@ fn parse_return<'a>( ), parse_identifier, ).parse(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(')')(ir_text)?.0; let control = context.borrow_mut().get_node_id(control); let data = data .into_iter() diff --git a/hercules_samples/call/src/call.hir b/hercules_samples/call/src/call.hir index cecee343288370e9ce51564e8cc1fc40149bd94c..77f5db2de2b4fa0508048d8155d8745e34ca3168 100644 --- a/hercules_samples/call/src/call.hir +++ b/hercules_samples/call/src/call.hir @@ -2,8 +2,10 @@ fn myfunc(x: u64) -> u64 cr1 = region(start) cr2 = region(cr1) c = constant(u64, 24) - y = call<16>(add, cr1, x, x) - z = call<10>(add, cr2, x, c) + cy = call<16>(add, cr1, x, x) + y = data_projection(cy, 0) + cz = call<10>(add, cr2, x, c) + z = data_projection(cz, 0) w = add(y, z) r = return(cr2, w) diff --git a/hercules_samples/ccp/src/ccp.hir b/hercules_samples/ccp/src/ccp.hir index b8e939942be05f75d85411f05016ed3484ed1bf0..e07df1d37a97054a19605578fe2511525a0c229b 100644 --- a/hercules_samples/ccp/src/ccp.hir +++ b/hercules_samples/ccp/src/ccp.hir @@ -7,14 +7,14 @@ fn tricky(x: i32) -> i32 val = phi(loop, one, later_val) b = ne(one, val) if1 = if(loop, b) - if1_false = projection(if1, 0) - if1_true = projection(if1, 1) + if1_false = control_projection(if1, 0) + if1_true = control_projection(if1, 1) middle = region(if1_false, if1_true) inter_val = sub(two, val) later_val = phi(middle, inter_val, two) idx_dec = sub(idx, one) cond = gte(idx_dec, one) if2 = if(middle, cond) - if2_false = projection(if2, 0) - if2_true = projection(if2, 1) + if2_false = control_projection(if2, 0) + if2_true = control_projection(if2, 1) r = return(if2_false, later_val) diff --git a/hercules_samples/fac/src/fac.hir b/hercules_samples/fac/src/fac.hir index e43dd8cae1a605bca7c3ceac4eb7c029665e86e6..aaf55c1de38cca2c6b024be061eea22c50ad5e6d 100644 --- a/hercules_samples/fac/src/fac.hir +++ b/hercules_samples/fac/src/fac.hir @@ -8,6 +8,6 @@ fn fac(x: i32) -> i32 fac_acc = mul(fac, idx_inc) in_bounds = lt(idx_inc, x) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) r = return(if_false, fac_acc) diff --git a/juno_samples/rodinia/backprop/src/backprop.jn b/juno_samples/rodinia/backprop/src/backprop.jn index 2927dbb59ef55681160fae56dcea64a3b8329a27..c7f4345bc5dc89d2fb55ee96231e6b5f6604ef4f 100644 --- a/juno_samples/rodinia/backprop/src/backprop.jn +++ b/juno_samples/rodinia/backprop/src/backprop.jn @@ -18,7 +18,7 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f return result; } -fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> (f32, f32[n + 1]) { +fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] { let errsum = 0.0; let delta : f32[n + 1]; @@ -29,14 +29,14 @@ fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> (f32, f32[n errsum += abs!(delta[j]); } - return (errsum, delta); + return errsum, delta; } fn hidden_error<hidden_n, output_n: usize>( out_delta: f32[output_n + 1], hidden_weights: f32[hidden_n + 1, output_n + 1], hidden_vals: f32[hidden_n + 1], -) -> (f32, f32[hidden_n + 1]) { +) -> f32, f32[hidden_n + 1] { let errsum = 0.0; let delta : f32[hidden_n + 1]; @@ -52,7 +52,7 @@ fn hidden_error<hidden_n, output_n: usize>( errsum += abs!(delta[j]); } - return (errsum, delta); + return errsum, delta; } const ETA : f32 = 0.3; @@ -63,7 +63,7 @@ fn adjust_weights<n, m: usize>( vals: f32[n + 1], weights: f32[n + 1, m + 1], prev_weights: f32[n + 1, m + 1] -) -> (f32[n + 1, m + 1], f32[n + 1, m + 1]) { +) -> f32[n + 1, m + 1], f32[n + 1, m + 1] { for j in 1..=m { for k in 0..=n { let new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j]; @@ -72,7 +72,7 @@ fn adjust_weights<n, m: usize>( } } - return (weights, prev_weights); + return weights, prev_weights; } #[entry] @@ -83,21 +83,19 @@ fn backprop<input_n, hidden_n, output_n: usize>( target: f32[output_n + 1], input_prev_weights: f32[input_n + 1, hidden_n + 1], hidden_prev_weights: f32[hidden_n + 1, output_n + 1], -//) -> (f32, f32, -// f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1], -// f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1]) { -) -> (f32, f32, f32) { +) -> f32, f32, + f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1], + f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] { let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights); let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights); - let (out_err, out_delta) = output_error::<output_n>(target, output_vals); - let (hid_err, hid_delta) = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals); + let out_err, out_delta = output_error::<output_n>(target, output_vals); + let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals); - let (hidden_weights, hidden_prev_weights) + let hidden_weights, hidden_prev_weights = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights); - let (input_weights, input_prev_weights) + let input_weights, input_prev_weights = adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights); - return (out_err, hid_err, input_weights[0, 0] + input_prev_weights[0, 0] + hidden_weights[0, 0] + hidden_prev_weights[0, 0]); - //return (input_weights, input_prev_weights, hidden_weights, hidden_prev_weights); + return out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights; } diff --git a/juno_samples/rodinia/backprop/src/main.rs b/juno_samples/rodinia/backprop/src/main.rs index 848b0abb02cdca6d26d9aa2961e3e0206844fa19..23f78fe4f7783717bb3c3c89ecd8ee0f17270169 100644 --- a/juno_samples/rodinia/backprop/src/main.rs +++ b/juno_samples/rodinia/backprop/src/main.rs @@ -37,7 +37,14 @@ fn run_backprop( let mut hidden_prev_weights = HerculesMutBox::from(hidden_prev_weights.to_vec()); let mut runner = runner!(backprop); - let res = HerculesMutBox::from(async_std::task::block_on(async { + let ( + out_err, + hid_err, + input_weights, + input_prev_weights, + hidden_weights, + hidden_prev_weights + ) = async_std::task::block_on(async { runner .run( input_n, @@ -51,11 +58,11 @@ fn run_backprop( hidden_prev_weights.to(), ) .await - })) - .as_slice() - .to_vec(); - let out_err = res[0]; - let hid_err = res[1]; + }); + let mut input_weights = HerculesMutBox::from(input_weights); + let mut hidden_weights = HerculesMutBox::from(hidden_weights); + let mut input_prev_weights = HerculesMutBox::from(input_prev_weights); + let mut hidden_prev_weights = HerculesMutBox::from(hidden_prev_weights); ( out_err,