LLVM  14.0.0git
generate_mock_model.py
Go to the documentation of this file.
1 """Generate a mock model for LLVM tests.
2 
3 The generated model is not a neural net - it is just a tf.function with the
4 correct input and output parameters. By construction, the mock model outputs
5 float('inf') plus the sum of its inputs, i.e. +inf whenever that sum is finite.
6 """
7 
8 import os
9 import importlib.util
10 import sys
11 
12 import tensorflow as tf
13 
14 
def get_output_spec_path(path):
  """Return the path of the `output_spec.json` file inside directory `path`."""
  return os.path.join(path, 'output_spec.json')
17 
18 
def build_mock_model(path, signature):
  """Build and save the mock model with the given signature.

  Args:
    path: Directory the SavedModel is saved to. The output spec file is also
      written into this directory (see get_output_spec_path).
    signature: Dict with keys 'inputs' (passed as the single argument to
      get_concrete_function; presumably a nested structure of tf.TensorSpec
      - confirm against the config), 'output' (name of the single output)
      and 'output_spec' (text written verbatim to the output spec file).
  """
  module = tf.Module()

  # We have to set this useless variable in order for the TF C API to correctly
  # intake it
  module.var = tf.Variable(0.)

  def action(*inputs):
    # Summing every (flattened) input makes the result depend on all of them;
    # adding it to +inf yields +inf whenever that sum is finite.
    s = tf.reduce_sum([tf.cast(x, tf.float32) for x in tf.nest.flatten(inputs)])
    return {signature['output']: float('inf') + s + module.var}

  module.action = tf.function(action)
  # Use a distinct name so the inner `action` function is not shadowed.
  signatures = {
      'action': module.action.get_concrete_function(signature['inputs'])
  }
  tf.saved_model.save(module, path, signatures=signatures)

  output_spec_path = get_output_spec_path(path)
  with open(output_spec_path, 'w', encoding='utf-8') as f:
    print(f'Writing output spec to {output_spec_path}.')
    f.write(signature['output_spec'])
39 
40 
def get_external_signature(config_path):
  """Get the signature for the desired model.

  We manually import the python file at config_path to avoid adding a gin
  dependency to the LLVM build.

  Args:
    config_path: Path to a Python source file (with a suffix importlib
      recognizes, e.g. `.py`) defining `get_input_signature()`,
      `get_output_signature()` and `get_output_spec()`. The file is
      executed, so only pass trusted files.

  Returns:
    A dict with keys 'inputs', 'output' and 'output_spec' holding the return
    values of those three config functions.

  Raises:
    ImportError: if no module loader can be determined for `config_path`.
  """
  spec = importlib.util.spec_from_file_location('config', config_path)
  # spec_from_file_location returns None (instead of raising) when it cannot
  # pick a loader, e.g. for an unrecognized file suffix. Without this check the
  # failure would surface as an opaque AttributeError in module_from_spec.
  if spec is None or spec.loader is None:
    raise ImportError(f'Cannot load config module from {config_path!r}')
  config = importlib.util.module_from_spec(spec)
  spec.loader.exec_module(config)

  return {
      'inputs': config.get_input_signature(),
      'output': config.get_output_signature(),
      'output_spec': config.get_output_spec()
  }
56 
57 
def main(argv):
  """Generate a mock model from a config file.

  Args:
    argv: Command line in the form [program, config_path, model_path].
      Exits with a usage message if it has any other length.
  """
  # Validate explicitly: `assert` is stripped when Python runs with -O.
  if len(argv) != 3:
    prog = argv[0] if argv else 'generate_mock_model.py'
    sys.exit(f'Usage: {prog} <config_path> <model_path>')
  config_path = argv[1]
  model_path = argv[2]

  print(f'Using config file at [{config_path}]')
  signature = get_external_signature(config_path)
  build_mock_model(model_path, signature)
66 
67 
if __name__ == '__main__':
  # Script entry point: forward the full command line, program name included.
  main(argv=sys.argv)
print
static void print(raw_ostream &Out, object::Archive::Kind Kind, T Val)
Definition: ArchiveWriter.cpp:147
config.get_output_spec
def get_output_spec()
Definition: config.py:86
generate_mock_model.get_output_spec_path
def get_output_spec_path(path)
Definition: generate_mock_model.py:15
config.get_input_signature
def get_input_signature()
Definition: config.py:24
generate_mock_model.build_mock_model
def build_mock_model(path, signature)
Definition: generate_mock_model.py:19
config.get_output_signature
def get_output_signature()
Definition: config.py:82
generate_mock_model.get_external_signature
def get_external_signature(config_path)
Definition: generate_mock_model.py:41
generate_mock_model.main
def main(argv)
Definition: generate_mock_model.py:58