1"""Generate a mock model for LLVM tests.
3The generated model is not a neural net - it is just a tf.function with the
4correct input and output parameters. By construction, the mock model will always
12import tensorflow
as tf
14POLICY_DECISION_LABEL =
"inlining_decision"
15POLICY_OUTPUT_SPEC =
"""
18 "logging_name":
"inlining_decision",
20 "name":
"StatefulPartitionedCall",
32# pylint: disable=g-complex-comprehension
33def get_input_signature():
34 """Returns the list of features for LLVM inlining."""
37 tf.TensorSpec(dtype=tf.int64, shape=(), name=key)
39 "caller_basic_block_count",
40 "caller_conditionally_executed_blocks",
42 "callee_basic_block_count",
43 "callee_conditionally_executed_blocks",
54 "call_argument_setup",
55 "load_relative_intrinsic",
56 "lowered_call_arg_setup",
57 "indirect_call_penalty",
59 "case_cluster_penalty",
61 "unsimplified_common_instructions",
64 "simplified_instructions",
66 "constant_offset_ptr_args",
69 "last_call_to_static_bonus",
72 "nested_inline_cost_estimate",
74 "is_callee_avail_external",
75 "is_caller_avail_external",
82 tf.TensorSpec(dtype=tf.float32, shape=(), name=key)
83 for key
in [
"discount",
"reward"]
89 [tf.TensorSpec(dtype=tf.int32, shape=(), name=key)
for key
in [
"step_type"]]
95 return POLICY_DECISION_LABEL
99 return POLICY_OUTPUT_SPEC
103 return os.path.join(path,
"output_spec.json")
107 """Build and save the mock model with the given signature"""
111 return {signature[
"output"]: tf.constant(value=advice, dtype=tf.int64)}
113 module.action = tf.function()(action)
114 action = {
"action": module.action.get_concrete_function(signature[
"inputs"])}
115 tf.saved_model.save(module, path, signatures=action)
118 with open(output_spec_path,
"w")
as f:
119 print(f
"Writing output spec to {output_spec_path}.")
120 f.write(signature[
"output_spec"])
132 assert len(argv) == 2
or (len(argv) == 3
and argv[2] ==
"never")
135 print(f
"Output model to: [{argv[1]}]")
140 print(f
"The model will always return: {constant_advice}")
146if __name__ ==
"__main__":
static void print(raw_ostream &Out, object::Archive::Kind Kind, T Val)
def get_output_spec_path(path)
def get_input_signature()
def get_output_signature()
def build_mock_model(path, signature, advice)