/*-------------------------------------------------------------------------
 * drawElements Quality Program OpenGL ES 3.1 Module
 * -------------------------------------------------
 *
 * Copyright 2014 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *//*!
 * \file
 * \brief Compute Shader Built-in variable tests.
 *//*--------------------------------------------------------------------*/

#include "es31fComputeShaderBuiltinVarTests.hpp"
#include "gluShaderProgram.hpp"
#include "gluShaderUtil.hpp"
#include "gluRenderContext.hpp"
#include "gluObjectWrapper.hpp"
#include "gluProgramInterfaceQuery.hpp"
#include "tcuVector.hpp"
#include "tcuTestLog.hpp"
#include "tcuVectorUtil.hpp"
#include "deSharedPtr.hpp"
#include "deStringUtil.hpp"
#include "glwFunctions.hpp"
#include "glwEnums.hpp"

#include <map>

namespace deqp
{
namespace gles31
{
namespace Functional
{

using std::map;
using std::string;
using std::vector;
using tcu::IVec3;
using tcu::TestLog;
using tcu::UVec3;

using namespace glu;

template <typename T, int Size>
struct LexicalCompareVec
{
    inline bool operator()(const tcu::Vector<T, Size> &a, const tcu::Vector<T, Size> &b) const
    {
        for (int ndx = 0; ndx < Size; ndx++)
        {
            if (a[ndx] < b[ndx])
                return true;
            else if (a[ndx] > b[ndx])
                return false;
        }
        return false;
    }
};

typedef de::SharedPtr<glu::ShaderProgram> ShaderProgramSp;
typedef std::map<tcu::UVec3, ShaderProgramSp, LexicalCompareVec<uint32_t, 3>> LocalSizeProgramMap;

class ComputeBuiltinVarCase : public TestCase
{
public:
    ComputeBuiltinVarCase(Context &context, const char *name, const char *varName, DataType varType);
    ~ComputeBuiltinVarCase(void);

    void init(void);
    void deinit(void);
    IterateResult iterate(void);

    virtual UVec3 computeReference(const UVec3 &numWorkGroups, const UVec3 &workGroupSize, const UVec3 &workGroupID,
                                   const UVec3 &localInvocationID) const = 0;

protected:
    struct SubCase
    {
        UVec3 localSize;
        UVec3 numWorkGroups;

        SubCase(void)
        {
        }
        SubCase(const UVec3 &localSize_, const UVec3 &numWorkGroups_)
            : localSize(localSize_)
            , numWorkGroups(numWorkGroups_)
        {
        }
    };

    vector<SubCase> m_subCases;

private:
    ComputeBuiltinVarCase(const ComputeBuiltinVarCase &other);
    ComputeBuiltinVarCase &operator=(const ComputeBuiltinVarCase &other);

    uint32_t getProgram(const UVec3 &localSize);

    const string m_varName;
    const DataType m_varType;

    LocalSizeProgramMap m_progMap;
    int m_subCaseNdx;
};

ComputeBuiltinVarCase::ComputeBuiltinVarCase(Context &context, const char *name, const char *varName, DataType varType)
    : TestCase(context, name, varName)
    , m_varName(varName)
    , m_varType(varType)
    , m_subCaseNdx(0)
{
}

ComputeBuiltinVarCase::~ComputeBuiltinVarCase(void)
{
    ComputeBuiltinVarCase::deinit();
}

void ComputeBuiltinVarCase::init(void)
{
    m_testCtx.setTestResult(QP_TEST_RESULT_PASS, "Pass");
    m_subCaseNdx = 0;
}

void ComputeBuiltinVarCase::deinit(void)
{
    m_progMap.clear();
}

static string genBuiltinVarSource(const string &varName, DataType varType, const UVec3 &localSize)
{
    std::ostringstream src;

    src << "#version 310 es\n"
        << "layout (local_size_x = " << localSize.x() << ", local_size_y = " << localSize.y()
        << ", local_size_z = " << localSize.z() << ") in;\n"
        << "uniform highp uvec2 u_stride;\n"
        << "layout(binding = 0) buffer Output\n"
        << "{\n"
        << "    " << glu::getDataTypeName(varType) << " result[];\n"
        << "} sb_out;\n"
        << "\n"
        << "void main (void)\n"
        << "{\n"
        << "    highp uint offset = u_stride.x*gl_GlobalInvocationID.z + u_stride.y*gl_GlobalInvocationID.y + "
           "gl_GlobalInvocationID.x;\n"
        << "    sb_out.result[offset] = " << varName << ";\n"
        << "}\n";

    return src.str();
}

uint32_t ComputeBuiltinVarCase::getProgram(const UVec3 &localSize)
{
    LocalSizeProgramMap::const_iterator cachePos = m_progMap.find(localSize);
    if (cachePos != m_progMap.end())
        return cachePos->second->getProgram();
    else
    {
        ShaderProgramSp program(
            new ShaderProgram(m_context.getRenderContext(),
                              ProgramSources() << ComputeSource(genBuiltinVarSource(m_varName, m_varType, localSize))));

        // Log all compiled programs.
        m_testCtx.getLog() << *program;
        if (!program->isOk())
            throw tcu::TestError("Compile failed");

        m_progMap[localSize] = program;
        return program->getProgram();
    }
}

static inline UVec3 readResultVec(const uint32_t *ptr, int numComps)
{
    UVec3 res;
    for (int ndx = 0; ndx < numComps; ndx++)
        res[ndx] = ptr[ndx];
    return res;
}

static inline bool compareComps(const UVec3 &a, const UVec3 &b, int numComps)
{
    DE_ASSERT(numComps == 1 || numComps == 3);
    return numComps == 3 ? tcu::allEqual(a, b) : a.x() == b.x();
}

struct LogComps
{
    const UVec3 &v;
    int numComps;

    LogComps(const UVec3 &v_, int numComps_) : v(v_), numComps(numComps_)
    {
    }
};

static inline std::ostream &operator<<(std::ostream &str, const LogComps &c)
{
    DE_ASSERT(c.numComps == 1 || c.numComps == 3);
    return c.numComps == 3 ? str << c.v : str << c.v.x();
}

ComputeBuiltinVarCase::IterateResult ComputeBuiltinVarCase::iterate(void)
{
    const tcu::ScopedLogSection section(m_testCtx.getLog(), string("Iteration") + de::toString(m_subCaseNdx),
                                        string("Iteration ") + de::toString(m_subCaseNdx));
    const glw::Functions &gl = m_context.getRenderContext().getFunctions();
    const SubCase &subCase   = m_subCases[m_subCaseNdx];
    const uint32_t program   = getProgram(subCase.localSize);

    const tcu::UVec3 globalSize = subCase.localSize * subCase.numWorkGroups;
    const tcu::UVec2 stride(globalSize[0] * globalSize[1], globalSize[0]);
    const uint32_t numInvocations = subCase.localSize[0] * subCase.localSize[1] * subCase.localSize[2] *
                                    subCase.numWorkGroups[0] * subCase.numWorkGroups[1] * subCase.numWorkGroups[2];

    const uint32_t outVarIndex = gl.getProgramResourceIndex(program, GL_BUFFER_VARIABLE, "Output.result");
    const InterfaceVariableInfo outVarInfo =
        getProgramInterfaceVariableInfo(gl, program, GL_BUFFER_VARIABLE, outVarIndex);
    const uint32_t bufferSize = numInvocations * outVarInfo.arrayStride;
    Buffer outputBuffer(m_context.getRenderContext());

    TCU_CHECK(outVarInfo.arraySize == 0); // Unsized variable.

    m_testCtx.getLog() << TestLog::Message << "Number of work groups = " << subCase.numWorkGroups << TestLog::EndMessage
                       << TestLog::Message << "Work group size = " << subCase.localSize << TestLog::EndMessage;

    gl.bindBuffer(GL_SHADER_STORAGE_BUFFER, *outputBuffer);
    gl.bufferData(GL_SHADER_STORAGE_BUFFER, (glw::GLsizeiptr)bufferSize, DE_NULL, GL_STREAM_READ);
    gl.bindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, *outputBuffer);
    GLU_EXPECT_NO_ERROR(gl.getError(), "Buffer setup failed");

    gl.useProgram(program);
    gl.uniform2uiv(gl.getUniformLocation(program, "u_stride"), 1, stride.getPtr());
    GLU_EXPECT_NO_ERROR(gl.getError(), "Program setup failed");

    gl.dispatchCompute(subCase.numWorkGroups[0], subCase.numWorkGroups[1], subCase.numWorkGroups[2]);
    GLU_EXPECT_NO_ERROR(gl.getError(), "glDispatchCompute() failed");

    {
        const void *ptr        = gl.mapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bufferSize, GL_MAP_READ_BIT);
        int numFailed          = 0;
        const int numScalars   = getDataTypeScalarSize(m_varType);
        const int maxLogPrints = 10;

        GLU_EXPECT_NO_ERROR(gl.getError(), "glMapBufferRange() failed");
        TCU_CHECK(ptr);

        for (uint32_t groupZ = 0; groupZ < subCase.numWorkGroups.z(); groupZ++)
            for (uint32_t groupY = 0; groupY < subCase.numWorkGroups.y(); groupY++)
                for (uint32_t groupX = 0; groupX < subCase.numWorkGroups.x(); groupX++)
                    for (uint32_t localZ = 0; localZ < subCase.localSize.z(); localZ++)
                        for (uint32_t localY = 0; localY < subCase.localSize.y(); localY++)
                            for (uint32_t localX = 0; localX < subCase.localSize.x(); localX++)
                            {
                                const UVec3 refGroupID(groupX, groupY, groupZ);
                                const UVec3 refLocalID(localX, localY, localZ);
                                const UVec3 refGlobalID = refGroupID * subCase.localSize + refLocalID;
                                const uint32_t refOffset =
                                    stride.x() * refGlobalID.z() + stride.y() * refGlobalID.y() + refGlobalID.x();
                                const UVec3 refValue =
                                    computeReference(subCase.numWorkGroups, subCase.localSize, refGroupID, refLocalID);

                                const uint32_t *resPtr =
                                    (const uint32_t *)((const uint8_t *)ptr + refOffset * outVarInfo.arrayStride);
                                const UVec3 resValue = readResultVec(resPtr, numScalars);

                                if (!compareComps(refValue, resValue, numScalars))
                                {
                                    if (numFailed < maxLogPrints)
                                        m_testCtx.getLog()
                                            << TestLog::Message << "ERROR: comparison failed at offset " << refOffset
                                            << ": expected " << LogComps(refValue, numScalars) << ", got "
                                            << LogComps(resValue, numScalars) << TestLog::EndMessage;
                                    else if (numFailed == maxLogPrints)
                                        m_testCtx.getLog() << TestLog::Message << "..." << TestLog::EndMessage;

                                    numFailed += 1;
                                }
                            }

        m_testCtx.getLog() << TestLog::Message << (numInvocations - numFailed) << " / " << numInvocations
                           << " values passed" << TestLog::EndMessage;

        if (numFailed > 0)
            m_testCtx.setTestResult(QP_TEST_RESULT_FAIL, "Comparison failed");

        gl.unmapBuffer(GL_SHADER_STORAGE_BUFFER);
    }

    m_subCaseNdx += 1;
    return (m_subCaseNdx < (int)m_subCases.size() && m_testCtx.getTestResult() == QP_TEST_RESULT_PASS) ? CONTINUE :
                                                                                                         STOP;
}

// Test cases

class NumWorkGroupsCase : public ComputeBuiltinVarCase
{
public:
    NumWorkGroupsCase(Context &context)
        : ComputeBuiltinVarCase(context, "num_work_groups", "gl_NumWorkGroups", TYPE_UINT_VEC3)
    {
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(52, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 39, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 1, 78)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(4, 7, 11)));
        m_subCases.push_back(SubCase(UVec3(2, 3, 4), UVec3(4, 7, 11)));
    }

    UVec3 computeReference(const UVec3 &numWorkGroups, const UVec3 &workGroupSize, const UVec3 &workGroupID,
                           const UVec3 &localInvocationID) const
    {
        DE_UNREF(numWorkGroups);
        DE_UNREF(workGroupSize);
        DE_UNREF(workGroupID);
        DE_UNREF(localInvocationID);
        return numWorkGroups;
    }
};

class WorkGroupSizeCase : public ComputeBuiltinVarCase
{
public:
    WorkGroupSizeCase(Context &context)
        : ComputeBuiltinVarCase(context, "work_group_size", "gl_WorkGroupSize", TYPE_UINT_VEC3)
    {
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(2, 7, 3)));
        m_subCases.push_back(SubCase(UVec3(2, 1, 1), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(2, 1, 1), UVec3(1, 3, 5)));
        m_subCases.push_back(SubCase(UVec3(1, 3, 1), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 7), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 7), UVec3(3, 3, 1)));
        m_subCases.push_back(SubCase(UVec3(10, 3, 4), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(10, 3, 4), UVec3(3, 1, 2)));
    }

    UVec3 computeReference(const UVec3 &numWorkGroups, const UVec3 &workGroupSize, const UVec3 &workGroupID,
                           const UVec3 &localInvocationID) const
    {
        DE_UNREF(numWorkGroups);
        DE_UNREF(workGroupID);
        DE_UNREF(localInvocationID);
        return workGroupSize;
    }
};

class WorkGroupIDCase : public ComputeBuiltinVarCase
{
public:
    WorkGroupIDCase(Context &context)
        : ComputeBuiltinVarCase(context, "work_group_id", "gl_WorkGroupID", TYPE_UINT_VEC3)
    {
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(52, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 39, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 1, 78)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(4, 7, 11)));
        m_subCases.push_back(SubCase(UVec3(2, 3, 4), UVec3(4, 7, 11)));
    }

    UVec3 computeReference(const UVec3 &numWorkGroups, const UVec3 &workGroupSize, const UVec3 &workGroupID,
                           const UVec3 &localInvocationID) const
    {
        DE_UNREF(numWorkGroups);
        DE_UNREF(workGroupSize);
        DE_UNREF(localInvocationID);
        return workGroupID;
    }
};

class LocalInvocationIDCase : public ComputeBuiltinVarCase
{
public:
    LocalInvocationIDCase(Context &context)
        : ComputeBuiltinVarCase(context, "local_invocation_id", "gl_LocalInvocationID", TYPE_UINT_VEC3)
    {
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(2, 7, 3)));
        m_subCases.push_back(SubCase(UVec3(2, 1, 1), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(2, 1, 1), UVec3(1, 3, 5)));
        m_subCases.push_back(SubCase(UVec3(1, 3, 1), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 7), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 7), UVec3(3, 3, 1)));
        m_subCases.push_back(SubCase(UVec3(10, 3, 4), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(10, 3, 4), UVec3(3, 1, 2)));
    }

    UVec3 computeReference(const UVec3 &numWorkGroups, const UVec3 &workGroupSize, const UVec3 &workGroupID,
                           const UVec3 &localInvocationID) const
    {
        DE_UNREF(numWorkGroups);
        DE_UNREF(workGroupSize);
        DE_UNREF(workGroupID);
        return localInvocationID;
    }
};

class GlobalInvocationIDCase : public ComputeBuiltinVarCase
{
public:
    GlobalInvocationIDCase(Context &context)
        : ComputeBuiltinVarCase(context, "global_invocation_id", "gl_GlobalInvocationID", TYPE_UINT_VEC3)
    {
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(52, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 39, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 1, 78)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(4, 7, 11)));
        m_subCases.push_back(SubCase(UVec3(2, 3, 4), UVec3(4, 7, 11)));
        m_subCases.push_back(SubCase(UVec3(10, 3, 4), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(10, 3, 4), UVec3(3, 1, 2)));
    }

    UVec3 computeReference(const UVec3 &numWorkGroups, const UVec3 &workGroupSize, const UVec3 &workGroupID,
                           const UVec3 &localInvocationID) const
    {
        DE_UNREF(numWorkGroups);
        return workGroupID * workGroupSize + localInvocationID;
    }
};

class LocalInvocationIndexCase : public ComputeBuiltinVarCase
{
public:
    LocalInvocationIndexCase(Context &context)
        : ComputeBuiltinVarCase(context, "local_invocation_index", "gl_LocalInvocationIndex", TYPE_UINT)
    {
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(1, 39, 1)));
        m_subCases.push_back(SubCase(UVec3(1, 1, 1), UVec3(4, 7, 11)));
        m_subCases.push_back(SubCase(UVec3(2, 3, 4), UVec3(4, 7, 11)));
        m_subCases.push_back(SubCase(UVec3(10, 3, 4), UVec3(1, 1, 1)));
        m_subCases.push_back(SubCase(UVec3(10, 3, 4), UVec3(3, 1, 2)));
    }

    UVec3 computeReference(const UVec3 &numWorkGroups, const UVec3 &workGroupSize, const UVec3 &workGroupID,
                           const UVec3 &localInvocationID) const
    {
        DE_UNREF(workGroupID);
        DE_UNREF(numWorkGroups);
        return UVec3(localInvocationID.z() * workGroupSize.x() * workGroupSize.y() +
                         localInvocationID.y() * workGroupSize.x() + localInvocationID.x(),
                     0, 0);
    }
};

ComputeShaderBuiltinVarTests::ComputeShaderBuiltinVarTests(Context &context)
    : TestCaseGroup(context, "compute", "Compute Shader Builtin Variables")
{
}

ComputeShaderBuiltinVarTests::~ComputeShaderBuiltinVarTests(void)
{
}

void ComputeShaderBuiltinVarTests::init(void)
{
    addChild(new NumWorkGroupsCase(m_context));
    addChild(new WorkGroupSizeCase(m_context));
    addChild(new WorkGroupIDCase(m_context));
    addChild(new LocalInvocationIDCase(m_context));
    addChild(new GlobalInvocationIDCase(m_context));
    addChild(new LocalInvocationIndexCase(m_context));
}

} // namespace Functional
} // namespace gles31
} // namespace deqp
