diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm_inspect.cpp.in b/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm_inspect.cpp.in index 015b0aea3c67ff038b45ee0021388408162ab41a..4fb25a557401a9a238bf22f0a097b8b1d0228e9a 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm_inspect.cpp.in +++ b/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm_inspect.cpp.in @@ -118,10 +118,11 @@ int main(){ {% endfor %} RootIn* args = static_cast<RootIn*>(malloc(sizeof(RootIn))); - void* {{input_name}} = create4DTensor(0, nchw, batch_size, {{input_shape|join(', ')}}); {% for n in root_inputs %} +{% if n != input_name %} args->{{n}} = {{n}}; args->{{n}}_bytes = 0; +{% endif %} {% endfor %} int ret = 0; @@ -135,7 +136,9 @@ int main(){ auto* fp = open_fifo("{{fifo_path_w}}", "wb"); for (int i = 0; i < batch_count; i++){ int start = i * batch_size, end = start + batch_size; - copyInputBatch(input_pth, start, end, {{input_shape|join(', ')}}, {{input_name}}); + void *{{input_name}} = readInputBatch(input_pth, 0, start, end, {{input_shape|join(', ')}}); + args->input = {{input_name}}; + args->input_bytes = 0; void* dfg = __hpvm__launch(0, root, (void*) args); __hpvm__wait(dfg);