//
// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/Descriptors.hpp>
#include "armnn/INetwork.hpp"
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#include "armnn/Types.hpp"

#include <schema_generated.h>
#include <functional>
#include <unordered_map>
#include <vector>

#include <tensorflow/lite/version.h>

#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3)
#define ARMNN_POST_TFLITE_2_3
#endif

namespace armnnTfLiteParser
{

class TfLiteParserImpl
{
public:
    // Shorthands for TfLite types
    using ModelPtr = std::unique_ptr<tflite::ModelT>;
    using SubgraphPtr = std::unique_ptr<tflite::SubGraphT>;
    using OperatorPtr = std::unique_ptr<tflite::OperatorT>;
    using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>;
    using TensorPtr = std::unique_ptr<tflite::TensorT>;
    using TensorRawPtr = const tflite::TensorT *;
    using TensorRawPtrVector = std::vector<TensorRawPtr>;
    using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>;
    using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>;
    using BufferPtr = std::unique_ptr<tflite::BufferT>;
    using BufferRawPtr = const tflite::BufferT *;

public:
    /// Create the network from a flatbuffers binary file on disk
    armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);

    /// Create the network from a flatbuffers binary
    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent);


    /// Retrieve binding info (layer id and tensor info) for the network input identified by
    /// the given layer name and subgraph id
    BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
                                                const std::string& name) const;

    /// Retrieve binding info (layer id and tensor info) for the network output identified by
    /// the given layer name and subgraph id
    BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
                                                         const std::string& name) const;

    /// Return the number of subgraphs in the parsed model
    size_t GetSubgraphCount() const;

    /// Return the input tensor names for a given subgraph
    std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const;

    /// Return the output tensor names for a given subgraph
    std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const;

    TfLiteParserImpl(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional());
    ~TfLiteParserImpl() = default;

public:
    // testable helpers
    armnn::INetworkPtr CreateNetworkFromBinaryAsDynamic(const std::vector<uint8_t>& binaryContent);

    armnn::INetworkPtr LoadModel(std::unique_ptr<tflite::ModelT> model);

    static ModelPtr LoadModelFromFile(const char* fileName);
    static ModelPtr LoadModelFromBinary(const uint8_t* binaryContent, size_t len);
    static TensorRawPtrVector GetInputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
    static TensorRawPtrVector GetOutputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
    static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr& model, size_t subgraphIndex);
    static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr& model, size_t subgraphIndex);
    static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
    static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);

    static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
    static armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
                                                  const armnn::TensorInfo& inputTensorInfo);
    static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
                                                  const std::vector<int32_t>& targetDimsIn);

    /// Retrieve version in X.Y.Z form
    static const std::string GetVersion();

private:

    // No copying allowed until it is wanted and properly implemented
    TfLiteParserImpl(const TfLiteParserImpl &) = delete;
    TfLiteParserImpl & operator=(const TfLiteParserImpl &) = delete;

    /// Create the network from an already loaded flatbuffers model
    armnn::INetworkPtr CreateNetworkFromModel();

    // signature for the parser functions
    using OperatorParsingFunction = void(TfLiteParserImpl::*)(size_t subgraphIndex, size_t operatorIndex);

    void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex);
    void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);

    void ParseAbs(size_t subgraphIndex, size_t operatorIndex);
    void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
    void ParseAdd(size_t subgraphIndex, size_t operatorIndex);
    void ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, armnn::ArgMinMaxFunction argMinMaxFunction);
    void ParseArgMin(size_t subgraphIndex, size_t operatorIndex);
    void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
    void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
    void ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex);
    void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
    void ParseCast(size_t subgraphIndex, size_t operatorIndex);
    void ParseCeil(size_t subgraphIndex, size_t operatorIndex);
    void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation);
    void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
    void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
    // Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed.
    #if defined(ARMNN_POST_TFLITE_2_4)
    void ParseConv3D(size_t subgraphIndex, size_t operatorIndex);
    #endif
    void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex);
    void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
    void ParseDequantize(size_t subgraphIndex, size_t operatorIndex);
    void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex);
    void ParseDiv(size_t subgraphIndex, size_t operatorIndex);
    void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation);
    void ParseElu(size_t subgraphIndex, size_t operatorIndex);
    void ParseEqual(size_t subgraphIndex, size_t operatorIndex);
    void ParseExp(size_t subgraphIndex, size_t operatorIndex);
    void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex);
    void ParseFloorDiv(size_t subgraphIndex, size_t operatorIndex);
    void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
    void ParseGather(size_t subgraphIndex, size_t operatorIndex);
    void ParseGatherNd(size_t subgraphIndex, size_t operatorIndex);
    void ParseGreater(size_t subgraphIndex, size_t operatorIndex);
    void ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex);
    void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
    void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex);
    void ParseLess(size_t subgraphIndex, size_t operatorIndex);
    void ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex);
    void ParseLog(size_t subgraphIndex, size_t operatorIndex);
    void ParseLocalResponseNormalization(size_t subgraphIndex, size_t operatorIndex);
    void ParseLogicalNot(size_t subgraphIndex, size_t operatorIndex);
    void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
    void ParseLogSoftmax(size_t subgraphIndex, size_t operatorIndex);
    void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex);
    void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex);
    void ParseMaximum(size_t subgraphIndex, size_t operatorIndex);
    void ParseMean(size_t subgraphIndex, size_t operatorIndex);
    void ParseMinimum(size_t subgraphIndex, size_t operatorIndex);
    void ParseMirrorPad(size_t subgraphIndex, size_t operatorIndex);
    void ParseMul(size_t subgraphIndex, size_t operatorIndex);
    void ParseNeg(size_t subgraphIndex, size_t operatorIndex);
    void ParseNotEqual(size_t subgraphIndex, size_t operatorIndex);
    void ParsePack(size_t subgraphIndex, size_t operatorIndex);
    void ParsePad(size_t subgraphIndex, size_t operatorIndex);
    void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm);
    void ParsePrelu(size_t subgraphIndex, size_t operatorIndex);
    void ParseQuantize(size_t subgraphIndex, size_t operatorIndex);
    void ParseReduce(size_t subgraphIndex, size_t operatorIndex, armnn::ReduceOperation reduceOperation);
    void ParseReduceMax(size_t subgraphIndex, size_t operatorIndex);
    void ParseReduceMin(size_t subgraphIndex, size_t operatorIndex);
    void ParseReduceProd(size_t subgraphIndex, size_t operatorIndex);
    void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
    void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
    void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
    void ParseResize(size_t subgraphIndex, size_t operatorIndex, armnn::ResizeMethod resizeMethod);
    void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex);
    void ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex);
    void ParseRsqrt(size_t subgraphIndex, size_t operatorIndex);
    void ParseShape(size_t subgraphIndex, size_t operatorIndex);
    void ParseSin(size_t subgraphIndex, size_t operatorIndex);
    void ParseSlice(size_t subgraphIndex, size_t operatorIndex);
    void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
    void ParseSqrt(size_t subgraphIndex, size_t operatorIndex);
    void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex);
    void ParseSpaceToDepth(size_t subgraphIndex, size_t operatorIndex);
    void ParseSplit(size_t subgraphIndex, size_t operatorIndex);
    void ParseSplitV(size_t subgraphIndex, size_t operatorIndex);
    void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
    void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex);
    void ParseSub(size_t subgraphIndex, size_t operatorIndex);
    void ParseSum(size_t subgraphIndex, size_t operatorIndex);
    void ParseTanH(size_t subgraphIndex, size_t operatorIndex);
    void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
    void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
    void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex);
    void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);

    void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
    void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
    void RegisterInputSlots(size_t subgraphIndex,
                            size_t operatorIndex,
                            armnn::IConnectableLayer* layer,
                            const std::vector<unsigned int>& tensorIndexes,
                            unsigned int startingSlotIndex = 0);
    void RegisterOutputSlots(size_t subgraphIndex,
                             size_t operatorIndex,
                             armnn::IConnectableLayer* layer,
                             const std::vector<unsigned int>& tensorIndexes);

    void SetupInputLayerTensorInfos(size_t subgraphIndex);
    void SetupConstantLayerTensorInfos(size_t subgraphIndex);

    void SetupInputLayers(size_t subgraphIndex);
    void SetupOutputLayers(size_t subgraphIndex);
    void SetupConstantLayers(size_t subgraphIndex);

    void ResetParser();

    void AddBroadcastReshapeLayer(size_t subgraphIndex,
                                  size_t operatorIndex,
                                  armnn::IConnectableLayer* layer);

    /// Attach an reshape layer to the one passed as a parameter
    armnn::IConnectableLayer* AddReshapeLayer(armnn::IConnectableLayer* layer,
                                              unsigned int outputSlot,
                                              std::string reshapeLayerName,
                                              armnn::TensorInfo outputShape);

    /// Attach an activation layer to the one passed as a parameter
    armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
                                                      unsigned int outputSlot,
                                                      tflite::ActivationFunctionType activationType);

    /// Attach a floor layer to the one passed as a parameter
    armnn::IConnectableLayer* AddFusedFloorLayer(armnn::IConnectableLayer* layer, unsigned int outputSlot);

    // SupportedDataStorage's purpose is to hold data till we pass over to the network.
    // We don't care about the content, and we want a single datatype to simplify the code.
    struct SupportedDataStorage
    {
    public:
        // Convenience constructors
        SupportedDataStorage(std::unique_ptr<float[]>&&   data);
        SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
        SupportedDataStorage(std::unique_ptr<int8_t[]>&&  data);
        SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);

    private:
        // Pointers to the data buffers
        std::unique_ptr<float[]>   m_FloatData;
        std::unique_ptr<uint8_t[]> m_Uint8Data;
        std::unique_ptr<int8_t[]>  m_Int8Data;
        std::unique_ptr<int32_t[]> m_Int32Data;
    };

    bool ShouldConstantTensorBeCreated(unsigned int tensorIndex);

    bool IsConstTensor(TensorRawPtr tensorPtr);

    bool ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr,
                                         armnn::DataType inputDataType,
                                         armnn::DataType filterDataType);

    armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
                                                    armnn::TensorInfo& tensorInfo);

    std::pair<armnn::ConstTensor, SupportedDataStorage>
    CreateConstTensorPermuted(TensorRawPtr tensorPtr,
                              armnn::TensorInfo& tensorInfo,
                              armnn::Optional<armnn::PermutationVector&> permutationVector);

    std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
    CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
                                 armnn::TensorInfo& tensorInfo,
                                 armnn::DataType inputDataType);

    template<typename T>
    std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
    CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr,
                                  TfLiteParserImpl::TensorRawPtr tensorPtr,
                                  armnn::TensorInfo& tensorInfo,
                                  armnn::Optional<armnn::PermutationVector&> permutationVector);

    std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
    CreateConstTensorPtr(TensorRawPtr tensorPtr,
                         armnn::TensorInfo& inputTensorInfo);

    armnn::TensorInfo InputTensorInfo(size_t subgraphIndex,
                                      size_t operatorIndex,
                                      int input);

    armnn::TensorInfo OutputTensorInfoFromInputs(size_t subgraphIndex,
                                       size_t operatorIndex,
                                       armnn::IConnectableLayer* layer,
                                       int output,
                                       std::vector<int> inputs);

    armnn::TensorInfo OutputTensorInfoFromShapes(size_t subgraphIndex,
                                       size_t operatorIndex,
                                       armnn::IConnectableLayer* layer,
                                       int output = 0,
                                       std::vector<armnn::TensorShape> inputShapes = {});

    /// Settings for configuring the TfLiteParser
    armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options;

    /// The network we're building. Gets cleared after it is passed to the user
    armnn::INetworkPtr                    m_Network;
    ModelPtr                              m_Model;

    std::vector<OperatorParsingFunction>                     m_ParserFunctions;
    std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions;

    /// A mapping of an output slot to each of the input slots it should be connected to
    /// The outputSlot is from the layer that creates this tensor as one of its ouputs
    /// The inputSlots are from the layers that use this tensor as one of their inputs
    struct TensorSlots
    {
        armnn::IOutputSlot* outputSlot;
        std::vector<armnn::IInputSlot*> inputSlots;

        TensorSlots() : outputSlot(nullptr) { }
    };
    typedef std::vector<TensorSlots> TensorConnections;
    /// Connections for tensors in each subgraph
    /// The first index is the subgraph ID, the second index is the tensor ID
    std::vector<TensorConnections> m_SubgraphConnections;

    /// This is used in case that the model does not specify the output.
    /// The shape can be calculated from the options.
    std::vector<std::vector<unsigned int>> m_OverriddenOutputShapes;

    std::vector<unsigned int> m_ConstantsToDequantize;
    std::vector<unsigned int> m_ConstantsToBeCreated;
    std::map<size_t, armnn::TensorInfo> m_TensorInfos;
};

}
