Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
H
Hercules
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
llvm
Hercules
Commits
32ad1e82
Commit
32ad1e82
authored
6 months ago
by
Praneet Rathi
Browse files
Options
Downloads
Patches
Plain Diff
untested
parent
6af2e9ec
No related branches found
No related tags found
1 merge request
!115
GPU backend
Pipeline
#201084
passed
6 months ago
Stage: test
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
hercules_cg/src/gpu.rs
+57
-92
57 additions, 92 deletions
hercules_cg/src/gpu.rs
with
57 additions
and
92 deletions
hercules_cg/src/gpu.rs
+
57
−
92
View file @
32ad1e82
...
@@ -80,17 +80,17 @@ pub fn gpu_codegen<W: Write>(
...
@@ -80,17 +80,17 @@ pub fn gpu_codegen<W: Write>(
.collect
();
.collect
();
let
fork_join_map
=
&
fork_join_map
(
function
,
control_subgraph
);
let
fork_join_map
=
fork_join_map
(
function
,
control_subgraph
);
let
join_fork_map
:
&
HashMap
<
NodeID
,
NodeID
>
=
&
fork_join_map
let
join_fork_map
:
HashMap
<
NodeID
,
NodeID
>
=
fork_join_map
.
into_
iter
()
.iter
()
.map
(|(
fork
,
join
)|
(
*
join
,
*
fork
))
.map
(|(
fork
,
join
)|
(
*
join
,
*
fork
))
.collect
();
.collect
();
// Fork Reduce map should have all reduces contained in some key
// Fork Reduce map should have all reduces contained in some key
let
fork_reduce_map
:
&
mut
HashMap
<
NodeID
,
Vec
<
NodeID
>>
=
&
mut
HashMap
::
new
();
let
mut
fork_reduce_map
:
HashMap
<
NodeID
,
Vec
<
NodeID
>>
=
HashMap
::
new
();
// Reduct Reduce map should have all non-parallel and non-associative reduces
// Reduct Reduce map should have all non-parallel and non-associative reduces
// contained in some key. Unlike Fork, Reduct is not involved in any assertions.
// contained in some key. Unlike Fork, Reduct is not involved in any assertions.
// It's placed here for convenience but can be moved.
// It's placed here for convenience but can be moved.
let
reduct_reduce_map
:
&
mut
HashMap
<
NodeID
,
Vec
<
NodeID
>>
=
&
mut
HashMap
::
new
();
let
mut
reduct_reduce_map
:
HashMap
<
NodeID
,
Vec
<
NodeID
>>
=
HashMap
::
new
();
for
reduce_node
in
&
reduce_nodes
{
for
reduce_node
in
&
reduce_nodes
{
if
let
Node
::
Reduce
{
if
let
Node
::
Reduce
{
control
,
control
,
...
@@ -124,11 +124,13 @@ pub fn gpu_codegen<W: Write>(
...
@@ -124,11 +124,13 @@ pub fn gpu_codegen<W: Write>(
}
}
}
}
for
idx
in
0
..
function
.nodes
.len
()
{
for
idx
in
0
..
function
.nodes
.len
()
{
if
function
.nodes
[
idx
]
.is_fork
()
if
function
.nodes
[
idx
]
.is_fork
()
{
&&
fork_reduce_map
assert!
(
fork_reduce_map
.get
(
&
NodeID
::
new
(
idx
))
.is_none_or
(|
reduces
|
reduces
.is_empty
())
.get
(
&
NodeID
::
new
(
idx
))
{
.is_none_or
(|
reduces
|
reduces
.is_empty
()),
panic!
(
"Fork node {} has no reduce nodes"
,
idx
);
"Fork node {} has no reduce nodes"
,
idx
);
}
}
}
}
...
@@ -155,7 +157,7 @@ pub fn gpu_codegen<W: Write>(
...
@@ -155,7 +157,7 @@ pub fn gpu_codegen<W: Write>(
(
NodeID
::
new
(
pos
),
*
data
)
(
NodeID
::
new
(
pos
),
*
data
)
};
};
let
return_type_id
=
&
typing
[
data_node_id
.idx
()];
let
return_type_id
=
typing
[
data_node_id
.idx
()];
let
return_type
=
&
types
[
return_type_id
.idx
()];
let
return_type
=
&
types
[
return_type_id
.idx
()];
let
return_param_idx
=
if
!
return_type
.is_primitive
()
{
let
return_param_idx
=
if
!
return_type
.is_primitive
()
{
let
objects
=
&
collection_objects
.objects
(
data_node_id
);
let
objects
=
&
collection_objects
.objects
(
data_node_id
);
...
@@ -186,7 +188,7 @@ pub fn gpu_codegen<W: Write>(
...
@@ -186,7 +188,7 @@ pub fn gpu_codegen<W: Write>(
// Map from control to pairs of data to update phi
// Map from control to pairs of data to update phi
// For each phi, we go to its region and get region's controls
// For each phi, we go to its region and get region's controls
let
control_data_phi_map
:
&
mut
HashMap
<
NodeID
,
Vec
<
(
NodeID
,
NodeID
)
>>
=
&
mut
HashMap
::
new
();
let
mut
control_data_phi_map
:
HashMap
<
NodeID
,
Vec
<
(
NodeID
,
NodeID
)
>>
=
HashMap
::
new
();
for
(
idx
,
node
)
in
function
.nodes
.iter
()
.enumerate
()
{
for
(
idx
,
node
)
in
function
.nodes
.iter
()
.enumerate
()
{
if
let
Node
::
Phi
{
control
,
data
}
=
node
{
if
let
Node
::
Phi
{
control
,
data
}
=
node
{
let
Node
::
Region
{
preds
}
=
&
function
.nodes
[
control
.idx
()]
else
{
let
Node
::
Region
{
preds
}
=
&
function
.nodes
[
control
.idx
()]
else
{
...
@@ -237,12 +239,12 @@ struct GPUContext<'a> {
...
@@ -237,12 +239,12 @@ struct GPUContext<'a> {
bbs
:
&
'a
BasicBlocks
,
bbs
:
&
'a
BasicBlocks
,
kernel_params
:
&
'a
GPUKernelParams
,
kernel_params
:
&
'a
GPUKernelParams
,
def_use_map
:
&
'a
ImmutableDefUseMap
,
def_use_map
:
&
'a
ImmutableDefUseMap
,
fork_join_map
:
&
'a
HashMap
<
NodeID
,
NodeID
>
,
fork_join_map
:
HashMap
<
NodeID
,
NodeID
>
,
join_fork_map
:
&
'a
HashMap
<
NodeID
,
NodeID
>
,
join_fork_map
:
HashMap
<
NodeID
,
NodeID
>
,
fork_reduce_map
:
&
'a
HashMap
<
NodeID
,
Vec
<
NodeID
>>
,
fork_reduce_map
:
HashMap
<
NodeID
,
Vec
<
NodeID
>>
,
reduct_reduce_map
:
&
'a
HashMap
<
NodeID
,
Vec
<
NodeID
>>
,
reduct_reduce_map
:
HashMap
<
NodeID
,
Vec
<
NodeID
>>
,
control_data_phi_map
:
&
'a
HashMap
<
NodeID
,
Vec
<
(
NodeID
,
NodeID
)
>>
,
control_data_phi_map
:
HashMap
<
NodeID
,
Vec
<
(
NodeID
,
NodeID
)
>>
,
return_type_id
:
&
'a
TypeID
,
return_type_id
:
TypeID
,
return_param_idx
:
Option
<
usize
>
,
return_param_idx
:
Option
<
usize
>
,
}
}
...
@@ -318,7 +320,7 @@ impl GPUContext<'_> {
...
@@ -318,7 +320,7 @@ impl GPUContext<'_> {
(
1
,
1
)
(
1
,
1
)
}
else
{
}
else
{
// Create structures and determine block and thread parallelization strategy
// Create structures and determine block and thread parallelization strategy
let
(
fork_tree
,
fork_control_map
)
=
self
.make_fork_structures
(
self
.fork_join_map
);
let
(
fork_tree
,
fork_control_map
)
=
self
.make_fork_structures
(
&
self
.fork_join_map
);
let
(
root_forks
,
num_blocks
)
=
let
(
root_forks
,
num_blocks
)
=
self
.get_root_forks_and_num_blocks
(
&
fork_tree
,
self
.kernel_params.max_num_blocks
);
self
.get_root_forks_and_num_blocks
(
&
fork_tree
,
self
.kernel_params.max_num_blocks
);
let
(
thread_root_root_fork
,
thread_root_forks
)
=
self
.get_thread_root_forks
(
&
root_forks
,
&
fork_tree
,
num_blocks
);
let
(
thread_root_root_fork
,
thread_root_forks
)
=
self
.get_thread_root_forks
(
&
root_forks
,
&
fork_tree
,
num_blocks
);
...
@@ -422,7 +424,7 @@ namespace cg = cooperative_groups;
...
@@ -422,7 +424,7 @@ namespace cg = cooperative_groups;
write!
(
write!
(
w
,
w
,
"{} __restrict__ ret"
,
"{} __restrict__ ret"
,
self
.get_type
(
*
self
.return_type_id
,
true
)
self
.get_type
(
self
.return_type_id
,
true
)
)
?
;
)
?
;
}
}
...
@@ -536,7 +538,7 @@ namespace cg = cooperative_groups;
...
@@ -536,7 +538,7 @@ namespace cg = cooperative_groups;
// need to pass arguments to kernel, so we keep track of the arguments here.
// need to pass arguments to kernel, so we keep track of the arguments here.
let
mut
pass_args
=
String
::
new
();
let
mut
pass_args
=
String
::
new
();
let
ret_primitive
=
self
.types
[
self
.return_type_id
.idx
()]
.is_primitive
();
let
ret_primitive
=
self
.types
[
self
.return_type_id
.idx
()]
.is_primitive
();
let
ret_type
=
self
.get_type
(
*
self
.return_type_id
,
false
);
let
ret_type
=
self
.get_type
(
self
.return_type_id
,
false
);
write!
(
w
,
"
write!
(
w
,
"
extern
\"
C
\"
{} {}("
,
ret_type
.clone
(),
self
.function.name
)
?
;
extern
\"
C
\"
{} {}("
,
ret_type
.clone
(),
self
.function.name
)
?
;
// The first set of parameters are dynamic constants.
// The first set of parameters are dynamic constants.
...
@@ -566,7 +568,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -566,7 +568,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
write!
(
w
,
") {{
\n
"
)
?
;
write!
(
w
,
") {{
\n
"
)
?
;
// Pull primitive return as pointer parameter for kernel
// Pull primitive return as pointer parameter for kernel
if
ret_primitive
{
if
ret_primitive
{
let
ret_type_pnt
=
self
.get_type
(
*
self
.return_type_id
,
true
);
let
ret_type_pnt
=
self
.get_type
(
self
.return_type_id
,
true
);
write!
(
w
,
"
\t
{} ret;
\n
"
,
ret_type_pnt
)
?
;
write!
(
w
,
"
\t
{} ret;
\n
"
,
ret_type_pnt
)
?
;
write!
(
w
,
"
\t
cudaMalloc((void**)&ret, sizeof({}));
\n
"
,
ret_type
)
?
;
write!
(
w
,
"
\t
cudaMalloc((void**)&ret, sizeof({}));
\n
"
,
ret_type
)
?
;
if
!
first_param
{
if
!
first_param
{
...
@@ -1267,16 +1269,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1267,16 +1269,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
// If we read collection, distribute elements among threads with cg
// If we read collection, distribute elements among threads with cg
// sync after. If we read primitive, copy read on all threads.
// sync after. If we read primitive, copy read on all threads.
Node
::
Read
{
collect
,
indices
}
=>
{
Node
::
Read
{
collect
,
indices
}
=>
{
let
is_char
=
self
.is_char
(
self
.typing
[
collect
.idx
()]);
let
collect_with_indices
=
self
.codegen_collect
(
*
collect
,
indices
,
extra_dim_collects
.contains
(
&
self
.typing
[
collect
.idx
()]));
let
collect_with_indices
=
self
.codegen_collect
(
*
collect
,
indices
,
is_char
,
extra_dim_collects
.contains
(
&
self
.typing
[
collect
.idx
()]));
let
data_type_id
=
self
.typing
[
id
.idx
()];
let
data_type_id
=
self
.typing
[
id
.idx
()];
if
self
.types
[
data_type_id
.idx
()]
.is_primitive
()
{
if
self
.types
[
data_type_id
.idx
()]
.is_primitive
()
{
if
is_char
{
let
type_name
=
self
.get_type
(
data_type_id
,
true
);
let
type_name
=
self
.get_type
(
data_type_id
,
true
);
write!
(
w
,
"{}{} = *reinterpret_cast<{}>({});
\n
"
,
tabs
,
define_variable
,
type_name
,
collect_with_indices
)
?
;
write!
(
w
,
"{}{} = *reinterpret_cast<{}>({});
\n
"
,
tabs
,
define_variable
,
type_name
,
collect_with_indices
)
?
;
}
else
{
write!
(
w
,
"{}{} = *({});
\n
"
,
tabs
,
define_variable
,
collect_with_indices
)
?
;
}
}
else
{
}
else
{
if
KernelState
::
OutBlock
==
state
&&
num_blocks
.unwrap
()
>
1
{
if
KernelState
::
OutBlock
==
state
&&
num_blocks
.unwrap
()
>
1
{
panic!
(
"GPU can't guarantee correctness for multi-block collection reads"
);
panic!
(
"GPU can't guarantee correctness for multi-block collection reads"
);
...
@@ -1287,13 +1284,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1287,13 +1284,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
};
};
// Divide up "elements", which are collection size divided
// Divide up "elements", which are collection size divided
// by element size, among threads.
// by element size, among threads.
let
data_size
=
self
.get_size
(
data_type_id
,
None
,
Some
(
extra_dim_collects
),
Some
(
true
));
let
data_size
=
self
.get_size
(
data_type_id
,
None
,
Some
(
extra_dim_collects
));
let
num_elements
=
format!
(
"({})"
,
data_size
);
write!
(
w
,
"{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{
\n
"
,
tabs
,
cg_tile
,
data_size
,
cg_tile
)
?
;
write!
(
w
,
"{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{
\n
"
,
tabs
,
cg_tile
,
num_elements
,
cg_tile
)
?
;
write!
(
w
,
"{}
\t
*({} + i) = *({} + i);
\n
"
,
tabs
,
define_variable
,
collect_with_indices
)
?
;
write!
(
w
,
"{}
\t
*({} + i) = *({} + i);
\n
"
,
tabs
,
define_variable
,
collect_with_indices
)
?
;
write!
(
w
,
"{}}}
\n
"
,
tabs
)
?
;
write!
(
w
,
"{}}}
\n
"
,
tabs
)
?
;
write!
(
w
,
"{}if ({}.thread_rank() < {} % {}.size()) {{
\n
"
,
tabs
,
cg_tile
,
num_elements
,
cg_tile
)
?
;
write!
(
w
,
"{}if ({}.thread_rank() < {} % {}.size()) {{
\n
"
,
tabs
,
cg_tile
,
data_size
,
cg_tile
)
?
;
write!
(
w
,
"{}
\t
*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());
\n
"
,
tabs
,
define_variable
,
cg_tile
,
num_elements
,
cg_tile
,
cg_tile
,
collect_with_indices
,
cg_tile
,
num_elements
,
cg_tile
,
cg_tile
)
?
;
write!
(
w
,
"{}
\t
*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());
\n
"
,
tabs
,
define_variable
,
cg_tile
,
data_size
,
cg_tile
,
cg_tile
,
collect_with_indices
,
cg_tile
,
data_size
,
cg_tile
,
cg_tile
)
?
;
write!
(
w
,
"{}}}
\n
"
,
tabs
)
?
;
write!
(
w
,
"{}}}
\n
"
,
tabs
)
?
;
write!
(
w
,
"{}{}.sync();
\n
"
,
tabs
,
cg_tile
)
?
;
write!
(
w
,
"{}{}.sync();
\n
"
,
tabs
,
cg_tile
)
?
;
}
}
...
@@ -1305,8 +1301,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1305,8 +1301,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
data
,
data
,
indices
,
indices
,
}
=>
{
}
=>
{
let
is_char
=
self
.is_char
(
self
.typing
[
collect
.idx
()]);
let
collect_with_indices
=
self
.codegen_collect
(
*
collect
,
indices
,
extra_dim_collects
.contains
(
&
self
.typing
[
collect
.idx
()]));
let
collect_with_indices
=
self
.codegen_collect
(
*
collect
,
indices
,
is_char
,
extra_dim_collects
.contains
(
&
self
.typing
[
collect
.idx
()]));
let
data_variable
=
self
.get_value
(
*
data
,
false
,
false
);
let
data_variable
=
self
.get_value
(
*
data
,
false
,
false
);
let
data_type_id
=
self
.typing
[
data
.idx
()];
let
data_type_id
=
self
.typing
[
data
.idx
()];
if
KernelState
::
OutBlock
==
state
&&
num_blocks
.unwrap
()
>
1
{
if
KernelState
::
OutBlock
==
state
&&
num_blocks
.unwrap
()
>
1
{
...
@@ -1318,21 +1313,16 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1318,21 +1313,16 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
};
};
if
self
.types
[
data_type_id
.idx
()]
.is_primitive
()
{
if
self
.types
[
data_type_id
.idx
()]
.is_primitive
()
{
write!
(
w
,
"{}if ({}.thread_rank() == 0) {{
\n
"
,
tabs
,
cg_tile
)
?
;
write!
(
w
,
"{}if ({}.thread_rank() == 0) {{
\n
"
,
tabs
,
cg_tile
)
?
;
if
is_char
{
let
type_name
=
self
.get_type
(
data_type_id
,
true
);
let
type_name
=
self
.get_type
(
data_type_id
,
true
);
write!
(
w
,
"{}
\t
*reinterpret_cast<{}>({}) = {};
\n
"
,
tabs
,
type_name
,
collect_with_indices
,
data_variable
)
?
;
write!
(
w
,
"{}
\t
*reinterpret_cast<{}>({}) = {};
\n
"
,
tabs
,
type_name
,
collect_with_indices
,
data_variable
)
?
;
}
else
{
write!
(
w
,
"{}
\t
*({}) = {};
\n
"
,
tabs
,
collect_with_indices
,
data_variable
)
?
;
}
write!
(
w
,
"{}}}
\n
"
,
tabs
)
?
;
write!
(
w
,
"{}}}
\n
"
,
tabs
)
?
;
}
else
{
}
else
{
let
data_size
=
self
.get_size
(
data_type_id
,
None
,
Some
(
extra_dim_collects
),
Some
(
true
));
let
data_size
=
self
.get_size
(
data_type_id
,
None
,
Some
(
extra_dim_collects
));
let
num_elements
=
format!
(
"({})"
,
data_size
);
write!
(
w
,
"{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{
\n
"
,
tabs
,
cg_tile
,
data_size
,
cg_tile
)
?
;
write!
(
w
,
"{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{
\n
"
,
tabs
,
cg_tile
,
num_elements
,
cg_tile
)
?
;
write!
(
w
,
"{}
\t
*({} + i) = *({} + i);
\n
"
,
tabs
,
collect_with_indices
,
data_variable
)
?
;
write!
(
w
,
"{}
\t
*({} + i) = *({} + i);
\n
"
,
tabs
,
collect_with_indices
,
data_variable
)
?
;
write!
(
w
,
"{}}}
\n
"
,
tabs
)
?
;
write!
(
w
,
"{}}}
\n
"
,
tabs
)
?
;
write!
(
w
,
"{}if ({}.thread_rank() < {} % {}.size()) {{
\n
"
,
tabs
,
cg_tile
,
num_elements
,
cg_tile
)
?
;
write!
(
w
,
"{}if ({}.thread_rank() < {} % {}.size()) {{
\n
"
,
tabs
,
cg_tile
,
data_size
,
cg_tile
)
?
;
write!
(
w
,
"{}
\t
*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());
\n
"
,
tabs
,
collect_with_indices
,
cg_tile
,
num_elements
,
cg_tile
,
cg_tile
,
data_variable
,
cg_tile
,
num_elements
,
cg_tile
,
cg_tile
)
?
;
write!
(
w
,
"{}
\t
*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());
\n
"
,
tabs
,
collect_with_indices
,
cg_tile
,
data_size
,
cg_tile
,
cg_tile
,
data_variable
,
cg_tile
,
data_size
,
cg_tile
,
cg_tile
)
?
;
write!
(
w
,
"{}}}
\n
"
,
tabs
)
?
;
write!
(
w
,
"{}}}
\n
"
,
tabs
)
?
;
}
}
write!
(
w
,
"{}{}.sync();
\n
"
,
tabs
,
cg_tile
)
?
;
write!
(
w
,
"{}{}.sync();
\n
"
,
tabs
,
cg_tile
)
?
;
...
@@ -1508,18 +1498,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1508,18 +1498,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
/*
/*
* This function emits collection name + pointer math for the provided indices.
* This function emits collection name + pointer math for the provided indices.
* One nuance is whether the collection is represented as char pointer or
* All collection types use char pointers.
* the original primitive pointer. For Field, it's always char, for Variant,
* it doesn't matter here, and for Array, it depends- so we may need to tack
* on the element size to the index math.
*/
*/
fn
codegen_collect
(
&
self
,
collect
:
NodeID
,
indices
:
&
[
Index
],
is_char
:
bool
,
has_extra_dim
:
bool
)
->
String
{
fn
codegen_collect
(
&
self
,
collect
:
NodeID
,
indices
:
&
[
Index
],
has_extra_dim
:
bool
)
->
String
{
let
mut
index_ptr
=
"0"
.to_string
();
let
mut
index_ptr
=
"0"
.to_string
();
let
type_id
=
self
.typing
[
collect
.idx
()];
let
type_id
=
self
.typing
[
collect
.idx
()];
for
index
in
indices
{
for
index
in
indices
{
match
index
{
match
index
{
Index
::
Field
(
field
)
=>
{
Index
::
Field
(
field
)
=>
{
self
.get_size
(
type_id
,
Some
(
*
field
),
None
,
None
);
self
.get_size
(
type_id
,
Some
(
*
field
),
None
);
}
}
// Variants of summations have zero offset
// Variants of summations have zero offset
Index
::
Variant
(
_
)
=>
{}
Index
::
Variant
(
_
)
=>
{}
...
@@ -1550,10 +1537,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1550,10 +1537,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
cumulative_offset
,
cumulative_offset
,
")"
.repeat
(
array_indices
.len
()
-
if
has_extra_dim
{
1
}
else
{
0
})
")"
.repeat
(
array_indices
.len
()
-
if
has_extra_dim
{
1
}
else
{
0
})
));
));
if
is_char
{
let
element_size
=
self
.get_size
(
*
element_type
,
None
,
None
);
let
element_size
=
self
.get_size
(
*
element_type
,
None
,
None
,
None
);
index_ptr
.push_str
(
&
format!
(
" * {}"
,
element_size
));
index_ptr
.push_str
(
&
format!
(
" * {}"
,
element_size
));
}
}
}
}
}
}
}
...
@@ -1600,7 +1585,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1600,7 +1585,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
Constant
::
Product
(
type_id
,
constant_fields
)
=>
{
Constant
::
Product
(
type_id
,
constant_fields
)
=>
{
if
allow_allocate
{
if
allow_allocate
{
let
alignment
=
self
.get_alignment
(
*
type_id
);
let
alignment
=
self
.get_alignment
(
*
type_id
);
let
size
=
self
.get_size
(
*
type_id
,
None
,
extra_dim_collects
,
None
);
let
size
=
self
.get_size
(
*
type_id
,
None
,
extra_dim_collects
);
*
dynamic_shared_offset
=
format!
(
"(({} + {} - 1) / {}) * {}"
,
dynamic_shared_offset
,
alignment
,
alignment
,
alignment
);
*
dynamic_shared_offset
=
format!
(
"(({} + {} - 1) / {}) * {}"
,
dynamic_shared_offset
,
alignment
,
alignment
,
alignment
);
write!
(
w
,
"{}dynamic_shared_offset = {};
\n
"
,
tabs
,
dynamic_shared_offset
)
?
;
write!
(
w
,
"{}dynamic_shared_offset = {};
\n
"
,
tabs
,
dynamic_shared_offset
)
?
;
write!
(
w
,
"{}{} = dynamic_shared + dynamic_shared_offset;
\n
"
,
tabs
,
name
)
?
;
write!
(
w
,
"{}{} = dynamic_shared + dynamic_shared_offset;
\n
"
,
tabs
,
name
)
?
;
...
@@ -1612,7 +1597,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1612,7 +1597,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
for
i
in
0
..
constant_fields
.len
()
{
for
i
in
0
..
constant_fields
.len
()
{
// For each field update offset and issue recursive call
// For each field update offset and issue recursive call
let
field_type
=
self
.get_type
(
type_fields
[
i
],
true
);
let
field_type
=
self
.get_type
(
type_fields
[
i
],
true
);
let
offset
=
self
.get_size
(
type_fields
[
i
],
Some
(
i
),
extra_dim_collects
,
None
);
let
offset
=
self
.get_size
(
type_fields
[
i
],
Some
(
i
),
extra_dim_collects
);
let
field_constant
=
&
self
.constants
[
constant_fields
[
i
]
.idx
()];
let
field_constant
=
&
self
.constants
[
constant_fields
[
i
]
.idx
()];
if
field_constant
.is_scalar
()
{
if
field_constant
.is_scalar
()
{
self
.codegen_constant
(
self
.codegen_constant
(
...
@@ -1632,7 +1617,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1632,7 +1617,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
Constant
::
Summation
(
type_id
,
variant
,
field
)
=>
{
Constant
::
Summation
(
type_id
,
variant
,
field
)
=>
{
if
allow_allocate
{
if
allow_allocate
{
let
alignment
=
self
.get_alignment
(
*
type_id
);
let
alignment
=
self
.get_alignment
(
*
type_id
);
let
size
=
self
.get_size
(
*
type_id
,
None
,
extra_dim_collects
,
None
);
let
size
=
self
.get_size
(
*
type_id
,
None
,
extra_dim_collects
);
*
dynamic_shared_offset
=
format!
(
"(({} + {} - 1) / {}) * {}"
,
dynamic_shared_offset
,
alignment
,
alignment
,
alignment
);
*
dynamic_shared_offset
=
format!
(
"(({} + {} - 1) / {}) * {}"
,
dynamic_shared_offset
,
alignment
,
alignment
,
alignment
);
write!
(
w
,
"{}dynamic_shared_offset = {};
\n
"
,
tabs
,
dynamic_shared_offset
)
?
;
write!
(
w
,
"{}dynamic_shared_offset = {};
\n
"
,
tabs
,
dynamic_shared_offset
)
?
;
write!
(
w
,
"{}{} = dynamic_shared + dynamic_shared_offset;
\n
"
,
tabs
,
name
)
?
;
write!
(
w
,
"{}{} = dynamic_shared + dynamic_shared_offset;
\n
"
,
tabs
,
name
)
?
;
...
@@ -1660,18 +1645,14 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1660,18 +1645,14 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
};
};
}
}
Constant
::
Array
(
type_id
)
=>
{
Constant
::
Array
(
type_id
)
=>
{
let
Type
::
Array
(
element_type
,
_
)
=
&
self
.types
[
type_id
.idx
()]
else
{
panic!
(
"Expected array type"
)
};
if
!
allow_allocate
{
if
!
allow_allocate
{
panic!
(
"Nested array constant should not be re-allocated"
);
panic!
(
"Nested array constant should not be re-allocated"
);
}
}
let
alignment
=
self
.get_alignment
(
*
type_id
);
let
alignment
=
self
.get_alignment
(
*
type_id
);
let
size
=
self
.get_size
(
*
type_id
,
None
,
extra_dim_collects
,
None
);
let
size
=
self
.get_size
(
*
type_id
,
None
,
extra_dim_collects
);
let
element_type
=
self
.get_type
(
*
element_type
,
true
);
*
dynamic_shared_offset
=
format!
(
"(({} + {} - 1) / {}) * {}"
,
dynamic_shared_offset
,
alignment
,
alignment
,
alignment
);
*
dynamic_shared_offset
=
format!
(
"(({} + {} - 1) / {}) * {}"
,
dynamic_shared_offset
,
alignment
,
alignment
,
alignment
);
write!
(
w
,
"{}dynamic_shared_offset = {};
\n
"
,
tabs
,
dynamic_shared_offset
)
?
;
write!
(
w
,
"{}dynamic_shared_offset = {};
\n
"
,
tabs
,
dynamic_shared_offset
)
?
;
write!
(
w
,
"{}{} =
reinterpret_cast<{}>(
dynamic_shared + dynamic_shared_offset
)
;
\n
"
,
tabs
,
name
,
element_type
)
?
;
write!
(
w
,
"{}{} = dynamic_shared + dynamic_shared_offset;
\n
"
,
tabs
,
name
)
?
;
*
dynamic_shared_offset
=
format!
(
"{} + {}"
,
dynamic_shared_offset
,
size
);
*
dynamic_shared_offset
=
format!
(
"{} + {}"
,
dynamic_shared_offset
,
size
);
}
}
}
}
...
@@ -1684,15 +1665,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1684,15 +1665,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
* and offset to 2nd field. This is useful for constant initialization and read/write
* and offset to 2nd field. This is useful for constant initialization and read/write
* index math.
* index math.
*/
*/
fn
get_size
(
&
self
,
type_id
:
TypeID
,
num_fields
:
Option
<
usize
>
,
extra_dim_collects
:
Option
<&
HashSet
<
TypeID
>>
,
exclude_element_size
:
Option
<
bool
>
)
->
String
{
fn
get_size
(
&
self
,
type_id
:
TypeID
,
num_fields
:
Option
<
usize
>
,
extra_dim_collects
:
Option
<&
HashSet
<
TypeID
>>
)
->
String
{
match
&
self
.types
[
type_id
.idx
()]
{
match
&
self
.types
[
type_id
.idx
()]
{
Type
::
Array
(
element_type
,
extents
)
=>
{
Type
::
Array
(
element_type
,
extents
)
=>
{
let
array_size
=
multiply_dcs
(
if
extra_dim_collects
.is_some
()
&&
extra_dim_collects
.unwrap
()
.contains
(
&
type_id
)
{
&
extents
[
1
..
]
}
else
{
extents
});
let
array_size
=
multiply_dcs
(
if
extra_dim_collects
.is_some
()
&&
extra_dim_collects
.unwrap
()
.contains
(
&
type_id
)
{
&
extents
[
1
..
]
}
else
{
extents
});
if
exclude_element_size
.unwrap_or
(
false
)
{
format!
(
"{} * {}"
,
self
.get_alignment
(
*
element_type
),
array_size
)
array_size
}
else
{
format!
(
"{} * {}"
,
self
.get_alignment
(
*
element_type
),
array_size
)
}
}
}
Type
::
Product
(
fields
)
=>
{
Type
::
Product
(
fields
)
=>
{
let
num_fields
=
&
num_fields
.unwrap_or
(
fields
.len
());
let
num_fields
=
&
num_fields
.unwrap_or
(
fields
.len
());
...
@@ -1700,7 +1677,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1700,7 +1677,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
.iter
()
.iter
()
.enumerate
()
.enumerate
()
.filter
(|(
i
,
_
)|
i
<
num_fields
)
.filter
(|(
i
,
_
)|
i
<
num_fields
)
.map
(|(
_
,
id
)|
(
self
.get_size
(
*
id
,
None
,
extra_dim_collects
,
None
),
self
.get_alignment
(
*
id
)))
.map
(|(
_
,
id
)|
(
self
.get_size
(
*
id
,
None
,
extra_dim_collects
),
self
.get_alignment
(
*
id
)))
.fold
(
String
::
from
(
"0"
),
|
acc
,
(
size
,
align
)|
{
.fold
(
String
::
from
(
"0"
),
|
acc
,
(
size
,
align
)|
{
if
acc
==
"0"
{
if
acc
==
"0"
{
size
size
...
@@ -1715,7 +1692,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1715,7 +1692,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
format!
(
format!
(
"{} - {}"
,
"{} - {}"
,
with_field
,
with_field
,
self
.get_size
(
fields
[
*
num_fields
],
None
,
extra_dim_collects
,
None
)
self
.get_size
(
fields
[
*
num_fields
],
None
,
extra_dim_collects
)
)
)
}
else
{
}
else
{
with_field
with_field
...
@@ -1725,7 +1702,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1725,7 +1702,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
// The argmax variant by size is not guaranteed to be same as
// The argmax variant by size is not guaranteed to be same as
// argmax variant by alignment, eg product of 3 4-byte primitives
// argmax variant by alignment, eg product of 3 4-byte primitives
// vs 1 8-byte primitive, so we need to calculate both.
// vs 1 8-byte primitive, so we need to calculate both.
let
max_size
=
variants
.iter
()
.map
(|
id
|
self
.get_size
(
*
id
,
None
,
extra_dim_collects
,
None
))
.fold
(
let
max_size
=
variants
.iter
()
.map
(|
id
|
self
.get_size
(
*
id
,
None
,
extra_dim_collects
))
.fold
(
String
::
from
(
"0"
),
String
::
from
(
"0"
),
|
acc
,
x
|
{
|
acc
,
x
|
{
if
acc
==
"0"
{
if
acc
==
"0"
{
...
@@ -1880,16 +1857,6 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1880,16 +1857,6 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
func_name
.to_string
()
func_name
.to_string
()
}
}
// Check if a type should be represented as char*. Must be a product,
// summation, or array of product/summation types.
fn
is_char
(
&
self
,
type_id
:
TypeID
)
->
bool
{
match
&
self
.types
[
type_id
.idx
()]
{
Type
::
Product
(
_
)
|
Type
::
Summation
(
_
)
=>
true
,
Type
::
Array
(
element_type
,
_
)
=>
self
.is_char
(
*
element_type
),
_
=>
false
,
}
}
fn
get_cg_tile
(
&
self
,
fork
:
NodeID
,
cg_type
:
CGType
)
->
String
{
fn
get_cg_tile
(
&
self
,
fork
:
NodeID
,
cg_type
:
CGType
)
->
String
{
format!
(
"cg_{}{}"
,
self
.get_value
(
fork
,
false
,
false
),
if
cg_type
==
CGType
::
Use
{
"_use"
}
else
if
cg_type
==
CGType
::
Available
{
"_available"
}
else
{
""
})
format!
(
"cg_{}{}"
,
self
.get_value
(
fork
,
false
,
false
),
if
cg_type
==
CGType
::
Use
{
"_use"
}
else
if
cg_type
==
CGType
::
Available
{
"_available"
}
else
{
""
})
}
}
...
@@ -1938,12 +1905,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
...
@@ -1938,12 +1905,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
}
}
fn
get_type
(
&
self
,
id
:
TypeID
,
make_pointer
:
bool
)
->
String
{
fn
get_type
(
&
self
,
id
:
TypeID
,
make_pointer
:
bool
)
->
String
{
match
&
self
.types
[
id
.idx
()]
{
if
self
.types
[
id
.idx
()]
.is_primitive
()
{
// Product and summation collections are char* for 1 byte-addressability
convert_type
(
&
self
.types
[
id
.idx
()],
make_pointer
)
// since we can have variable type fields
}
else
{
Type
::
Product
(
_
)
|
Type
::
Summation
(
_
)
=>
"char*"
.to_string
(),
"char*"
.to_string
()
Type
::
Array
(
element_type
,
_
)
=>
self
.get_type
(
*
element_type
,
true
),
_
=>
convert_type
(
&
self
.types
[
id
.idx
()],
make_pointer
),
}
}
}
}
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment