"""Convert a SavedModel to a TFLite model.

Usage: python3 saved-model-to-tflite.py <mlgo saved_model_dir> <tflite dest_dir>

The <tflite dest_dir> will contain:
  model.tflite: the converted SavedModel.
  output_spec.json: the output spec, copied from <mlgo saved_model_dir> if it
    exists there.
"""

import tensorflow as tf
import os
import sys
from tf_agents.policies import greedy_policy


def main(argv):
  """Converts an MLGO SavedModel into a TFLite model.

  Creates the destination directory if needed and writes `model.tflite` into
  it. If the SavedModel directory contains `output_spec.json`, that file is
  copied alongside the converted model.

  Args:
    argv: the full command line, `[program, saved_model_dir, tflite_dir]`.

  Raises:
    SystemExit: with a usage message if `argv` does not have exactly three
      entries.
  """
  # Explicit check rather than `assert`: asserts are stripped under
  # `python -O`, which would let a bad command line fail later with an
  # IndexError (or silently ignore extra arguments).
  if len(argv) != 3:
    prog = argv[0] if argv else 'saved-model-to-tflite.py'
    sys.exit(f'Usage: {prog} <mlgo saved_model_dir> <tflite dest_dir>')
  _, sm_dir, tfl_dir = argv

  tf.io.gfile.makedirs(tfl_dir)
  converter = tf.lite.TFLiteConverter.from_saved_model(sm_dir)
  # Only TFLite builtin ops are allowed (SELECT_TF_OPS is not listed), so the
  # converter should reject models needing ops outside that set.
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,
  ]
  tfl_model = converter.convert()
  with tf.io.gfile.GFile(os.path.join(tfl_dir, 'model.tflite'), 'wb') as f:
    f.write(tfl_model)

  # output_spec.json is optional: copy it only when the SavedModel has one.
  # overwrite=True lets the script be re-run into an existing destination,
  # consistent with model.tflite being overwritten above.
  json_file = 'output_spec.json'
  src_json = os.path.join(sm_dir, json_file)
  if tf.io.gfile.exists(src_json):
    tf.io.gfile.copy(
        src_json, os.path.join(tfl_dir, json_file), overwrite=True)

if __name__ == '__main__':
  # Pass the complete command line (program name included) to main(),
  # which checks the argument count itself.
  main(argv=sys.argv)
