Skip to content
Snippets Groups Projects
Commit 5837cfe5 authored by Elizabeth's avatar Elizabeth
Browse files

Added code to handle local paths for fp32 row

parent 26eb0e32
No related branches found
No related tags found
No related merge requests found
...@@ -113,19 +113,19 @@ def get_new_function_calls(complete_line, knob_config): ...@@ -113,19 +113,19 @@ def get_new_function_calls(complete_line, knob_config):
return ''.join(new_line) return ''.join(new_line)
def generate_fp32_source(new_file, source_file): def convert_local_paths(file_contents, orig_source_dir):
# Copy the source code over '''
new_file.write(source_file.read()) Converts all local paths wrt the original source file's directory to paths compatible
with the current source code directory
def generate_fp16_source(knob_config, new_file, source_file, orig_source_dir):
file_contents = source_file.read()
# Fix all local paths Args:
file_contents: String containing source code read from file
orig_source_dir: Path of original source code dir wrt the current directory
'''
last_include_ind = file_contents.rfind("#include") last_include_ind = file_contents.rfind("#include")
last_include_newline_ind = file_contents.find("\n", last_include_ind) last_include_newline_ind = file_contents.find("\n", last_include_ind)
include_lines = file_contents[ : last_include_newline_ind].split("\n") include_lines = file_contents[ : last_include_newline_ind].split("\n")
new_file_contents = [] new_file_contents = []
for line in include_lines: for line in include_lines:
if line.startswith("#"): if line.startswith("#"):
...@@ -136,7 +136,19 @@ def generate_fp16_source(knob_config, new_file, source_file, orig_source_dir): ...@@ -136,7 +136,19 @@ def generate_fp16_source(knob_config, new_file, source_file, orig_source_dir):
else: else:
new_file_contents.append(line) new_file_contents.append(line)
new_file_contents.append(file_contents[last_include_newline_ind : ]) new_file_contents.append(file_contents[last_include_newline_ind : ])
new_file_contents = '\n'.join(new_file_contents) return '\n'.join(new_file_contents)
def generate_fp32_source(new_file, source_file, orig_source_dir):
# Copy the source code over
new_file_contents = convert_local_paths(source_file.read(), orig_source_dir)
new_file.write(new_file_contents)
def generate_fp16_source(knob_config, new_file, source_file, orig_source_dir):
file_contents = source_file.read()
new_file_contents = convert_local_paths(file_contents, orig_source_dir)
# Replace all tensorOperation calls with tensorHalfOperation calls # Replace all tensorOperation calls with tensorHalfOperation calls
# Derived from ../bin/replace_half_calls.py # Derived from ../bin/replace_half_calls.py
...@@ -207,7 +219,7 @@ def generate_source_code(table, dir_name, filename, source_name): ...@@ -207,7 +219,7 @@ def generate_source_code(table, dir_name, filename, source_name):
if knob_config.approx == Approx.FP16: if knob_config.approx == Approx.FP16:
generate_fp16_source(knob_config, new_file, source_file, orig_source_dir) generate_fp16_source(knob_config, new_file, source_file, orig_source_dir)
elif knob_config.approx == Approx.FP32: elif knob_config.approx == Approx.FP32:
generate_fp32_source(new_file, source_file) generate_fp32_source(new_file, source_file, orig_source_dir)
else: else:
generate_approx_source(knob_config, new_file, source_file, orig_source_dir) generate_approx_source(knob_config, new_file, source_file, orig_source_dir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment