/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H
#define ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H

#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

namespace android {
namespace nn {
namespace fuzzing_test {

static const int kMaxValue = 10000;
static const int kInvalidValue = INT_MIN;

// Describe the search range for the value of a random variable.
class RandomVariableRange {
   public:
    RandomVariableRange() = default;
    explicit RandomVariableRange(int value) : mChoices({value}) {}
    RandomVariableRange(int lower, int upper) : mChoices(upper - lower + 1) {
        std::iota(mChoices.begin(), mChoices.end(), lower);
    }
    explicit RandomVariableRange(const std::vector<int>& vec) : mChoices(vec) {}
    explicit RandomVariableRange(const std::set<int>& st) : mChoices(st.begin(), st.end()) {}
    RandomVariableRange(const RandomVariableRange&) = default;
    RandomVariableRange& operator=(const RandomVariableRange&) = default;

    bool empty() const { return mChoices.empty(); }
    bool has(int value) const {
        return std::binary_search(mChoices.begin(), mChoices.end(), value);
    }
    size_t size() const { return mChoices.size(); }
    int min() const { return *mChoices.begin(); }
    int max() const { return *mChoices.rbegin(); }
    const std::vector<int>& getChoices() const { return mChoices; }

    // Narrow down the range to fit [lower, upper]. Use kInvalidValue to indicate unlimited bound.
    void setRange(int lower, int upper);
    // Narrow down the range to a random selected choice. Return the chosen value.
    int toConst();

    // Calculate the intersection of two ranges.
    friend RandomVariableRange operator&(const RandomVariableRange& lhs,
                                         const RandomVariableRange& rhs);

   private:
    // Always in ascending order.
    std::vector<int> mChoices;
};

// Defines the interface for an operation applying to RandomVariables.
class IRandomVariableOp {
   public:
    virtual ~IRandomVariableOp() {}
    // Forward evaluation of two values.
    virtual int eval(int lhs, int rhs) const = 0;
    // Gets the range of the operation outcomes. The returned range must include all possible
    // outcomes of this operation, but may contain invalid results.
    virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
                                             const RandomVariableRange& rhs) const;
    // Provides faster range evaluation for evalSubnetSingleOpHelper if possible.
    virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
                      const std::set<int>* childIn, std::set<int>* parent1Out,
                      std::set<int>* parent2Out, std::set<int>* childOut) const;
    // For debugging purpose.
    virtual const char* getName() const = 0;
};

enum class RandomVariableType { FREE = 0, CONST = 1, OP = 2 };

struct RandomVariableBase {
    // Each RandomVariableBase is assigned an unique index for debugging purpose.
    static unsigned int globalIndex;
    int index;

    RandomVariableType type;
    RandomVariableRange range;
    int value = 0;
    std::shared_ptr<const IRandomVariableOp> op = nullptr;

    // Network structural information.
    std::shared_ptr<RandomVariableBase> parent1 = nullptr;
    std::shared_ptr<RandomVariableBase> parent2 = nullptr;
    std::vector<std::weak_ptr<RandomVariableBase>> children;

    // The last time that this RandomVariableBase is modified.
    int timestamp;

    explicit RandomVariableBase(int value);
    RandomVariableBase(int lower, int upper);
    explicit RandomVariableBase(const std::vector<int>& choices);
    RandomVariableBase(const std::shared_ptr<RandomVariableBase>& lhs,
                       const std::shared_ptr<RandomVariableBase>& rhs,
                       const std::shared_ptr<const IRandomVariableOp>& op);
    RandomVariableBase(const RandomVariableBase&) = delete;
    RandomVariableBase& operator=(const RandomVariableBase&) = delete;

    // Freeze FREE RandomVariable to one valid choice.
    // Should only invoke on FREE RandomVariable.
    void freeze();

    // Get CONST value or calculate from parents.
    // Should not invoke on FREE RandomVariable.
    int getValue() const;

    // Update the timestamp to the latest global time.
    void updateTimestamp();
};

using RandomVariableNode = std::shared_ptr<RandomVariableBase>;

// A wrapper class of RandomVariableBase that manages RandomVariableBase with shared_ptr and
// provides useful methods and operator overloading to build the random variable network.
class RandomVariable {
   public:
    // Construct a placeholder RandomVariable with nullptr.
    RandomVariable() : mVar(nullptr) {}

    // Construct a CONST RandomVariable with specified value.
    /* implicit */ RandomVariable(int value);

    // Construct a FREE RandomVariable with range [lower, upper].
    RandomVariable(int lower, int upper);

    // Construct a FREE RandomVariable with specified value choices.
    explicit RandomVariable(const std::vector<int>& choices);

    // This is for RandomVariableType::FREE only.
    // Construct a FREE RandomVariable with default range [1, defaultValue].
    /* implicit */ RandomVariable(RandomVariableType type);

    // RandomVariables share the same RandomVariableBase if copied or copy-assigned.
    RandomVariable(const RandomVariable& other) = default;
    RandomVariable& operator=(const RandomVariable& other) = default;

    // Get the value of the RandomVariable, the value must be deterministic.
    int getValue() const { return mVar->getValue(); }

    // Get the underlying managed RandomVariableNode.
    RandomVariableNode get() const { return mVar; };

    bool operator==(std::nullptr_t) const { return mVar == nullptr; }
    bool operator!=(std::nullptr_t) const { return mVar != nullptr; }

    // Arithmetic operators and methods on RandomVariables.
    friend RandomVariable operator+(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable operator-(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable operator*(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable operator*(const RandomVariable& lhs, const float& rhs);
    friend RandomVariable operator/(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable operator%(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable max(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable min(const RandomVariable& lhs, const RandomVariable& rhs);
    RandomVariable exactDiv(const RandomVariable& other);

    // Set constraints on the RandomVariable. Use kInvalidValue to indicate unlimited bound.
    void setRange(int lower, int upper);
    RandomVariable setEqual(const RandomVariable& other) const;
    RandomVariable setGreaterThan(const RandomVariable& other) const;
    RandomVariable setGreaterEqual(const RandomVariable& other) const;

    // A FREE RandomVariable is constructed with default range [1, defaultValue].
    static int defaultValue;

   private:
    // Construct a RandomVariable as the result of an OP between two other RandomVariables.
    RandomVariable(const RandomVariable& lhs, const RandomVariable& rhs,
                   const std::shared_ptr<const IRandomVariableOp>& op);
    RandomVariableNode mVar;
};

using EvaluationOrder = std::vector<RandomVariableNode>;

// The base class of a network consisting of disjoint subnets.
class DisjointNetwork {
   public:
    // Add a node to the network, join the parent subnets if needed.
    void add(const RandomVariableNode& var);

    // Similar to join(int, int), but accept RandomVariableNodes.
    int join(const RandomVariableNode& var1, const RandomVariableNode& var2) {
        return DisjointNetwork::join(mIndexMap[var1], mIndexMap[var2]);
    }

   protected:
    DisjointNetwork() = default;
    DisjointNetwork(const DisjointNetwork&) = default;
    DisjointNetwork& operator=(const DisjointNetwork&) = default;

    // Join two subnets by appending every node in ind2 after ind1, return the resulting subnet
    // index. Use -1 for invalid subnet index.
    int join(int ind1, int ind2);

    // A map from the network node to the corresponding subnet index.
    std::unordered_map<RandomVariableNode, int> mIndexMap;

    // A map from the subnet index to the set of nodes within the subnet. The nodes are maintained
    // in a valid evaluation order, that is, a valid topological sort.
    std::map<int, EvaluationOrder> mEvalOrderMap;

    // The next index for a new disjoint subnet component.
    int mNextIndex = 0;
};

// Manages the active RandomVariable network. Only one instance of this class will exist.
class RandomVariableNetwork : public DisjointNetwork {
   public:
    // Returns the singleton network instance.
    static RandomVariableNetwork* get();

    // Re-initialization. Should be called every time a new random graph is being generated.
    void initialize(int defaultValue);

    // Set the elementwise equality of the two vectors of RandomVariables iff it results in a
    // soluble network.
    bool setEqualIfCompatible(const std::vector<RandomVariable>& lhs,
                              const std::vector<RandomVariable>& rhs);

    // Freeze all FREE RandomVariables in the network to a random valid combination.
    bool freeze();

    // Check if node2 is FREE and can be evaluated after node1.
    bool isSubordinate(const RandomVariableNode& node1, const RandomVariableNode& node2);

    // Get and then advance the current global timestamp.
    int getGlobalTime() { return mGlobalTime++; }

    // Add a special constraint on dimension product.
    void addDimensionProd(const std::vector<RandomVariable>& dims);

   private:
    RandomVariableNetwork() = default;
    RandomVariableNetwork(const RandomVariableNetwork&) = default;
    RandomVariableNetwork& operator=(const RandomVariableNetwork&) = default;

    // A class to revert all the changes made to RandomVariableNetwork since the Reverter object is
    // constructed. Only used when setEqualIfCompatible results in incompatible.
    class Reverter;

    // Find valid choices for all RandomVariables in the network. Update the RandomVariableRange
    // if the network is soluble, otherwise, return false and leave the ranges unchanged.
    bool evalRange();

    int mGlobalTime = 0;
    int mTimestamp = -1;

    std::vector<EvaluationOrder> mDimProd;
};

}  // namespace fuzzing_test
}  // namespace nn
}  // namespace android

#endif  // ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H
