//
// Copyright (c) 2017 The Khronos Group Inc.
// 
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "testBase.h"
#include "testHarness.h"
#include "harness/conversions.h"
#include "harness/typeWrappers.h"
#include <math.h>
#include <float.h>
#include <vector>

#if !defined (__APPLE__)
    #include <CL/cl_gl.h>
#endif

// printf-style template for the test kernel source. Every "%s%s" pair is
// filled with an element type name followed by a vector-width suffix, and
// there are five such pairs: the three pointer arguments plus the two casts
// applied to the constants. For each work-item the kernel writes source + 1
// to clDest and source + 2 to glDest.
static const char *bufferKernelPattern =
"__kernel void sample_test( __global %s%s *source, __global %s%s *clDest, __global %s%s *glDest )\n"
"{\n"
"    int  tid = get_global_id(0);\n"
"     clDest[ tid ] = source[ tid ] + (%s%s)(1);\n"
"     glDest[ tid ] = source[ tid ] + (%s%s)(2);\n"
"}\n";

// Expands to one switch case of gen_input_data() for an integer type of at
// most 32 bits: each element is a random 32-bit value masked with `range`,
// shifted down by `offset`, then converted to cl_<type>.
// Relies on `i`, `count`, `d` and `outData` being in scope at the expansion
// site.
#define TYPE_CASE( enum, type, range, offset )    \
    case enum:    \
    {                \
        cl_##type *ptr = (cl_##type *)outData; \
        for( i = 0; i < count; i++ ) \
            ptr[ i ] = (cl_##type)( ( genrand_int32(d) & (range) ) - (offset) ); \
        break; \
    }

/*
 * Fills outData with `count` pseudo-random elements of the given type.
 *
 *   type    - element type to generate; unsupported types log an error and
 *             leave outData untouched.
 *   count   - number of elements (not bytes) to write.
 *   d       - random number generator state consumed by the generator calls.
 *   outData - destination buffer; must hold at least `count` elements of the
 *             requested type.
 */
void gen_input_data( ExplicitType type, size_t count, MTdata d, void *outData )
{
    size_t i;

    switch( type )
    {
        case kBool:
        {
            bool *boolPtr = (bool *)outData;
            for( i = 0; i < count; i++ )
            {
                boolPtr[i] = ( genrand_int32(d) & 1 ) ? true : false;
            }
            break;
        }

        TYPE_CASE( kChar, char, 250, 127 )
        TYPE_CASE( kUChar, uchar, 250, 0 )
        TYPE_CASE( kShort, short, 65530, 32767 )
        TYPE_CASE( kUShort, ushort, 65530, 0 )
        TYPE_CASE( kInt, int, 0x0fffffff, 0x70000000 )
        TYPE_CASE( kUInt, uint, 0x0fffffff, 0 )

        case kLong:
        {
            cl_long *longPtr = (cl_long *)outData;
            for( i = 0; i < count; i++ )
            {
                // Draw the two halves in separate statements: calling the
                // generator twice inside one expression leaves the call order
                // unspecified, which would make the seeded data differ
                // between compilers.
                const cl_ulong lowBits = (cl_ulong)genrand_int32(d);
                const cl_ulong highBits = (cl_ulong)genrand_int32(d);
                longPtr[i] = (cl_long)( lowBits | ( highBits << 32 ) );
            }
            break;
        }

        case kULong:
        {
            cl_ulong *ulongPtr = (cl_ulong *)outData;
            for( i = 0; i < count; i++ )
            {
                // Sequenced draws for the same reason as the kLong case.
                const cl_ulong lowBits = (cl_ulong)genrand_int32(d);
                const cl_ulong highBits = (cl_ulong)genrand_int32(d);
                ulongPtr[i] = lowBits | ( highBits << 32 );
            }
            break;
        }

        case kFloat:
        {
            cl_float *floatPtr = (float *)outData;
            for( i = 0; i < count; i++ )
                floatPtr[i] = get_random_float( -100000.f, 100000.f, d );
            break;
        }

        default:
            log_error( "ERROR: Invalid type passed in to gen_input_data!\n" );
            break;
    }
}

#define INC_CASE( enum, type )    \
    case enum:    \
    {                \
        cl_##type *src = (cl_##type *)inData; \
        cl_##type *dst = (cl_##type *)outData; \
        *dst = *src + 1; \
        break; \
    }

void get_incremented_value( void *inData, void *outData, ExplicitType type )
{
    switch( type )
    {
        INC_CASE( kChar, char )
        INC_CASE( kUChar, uchar )
        INC_CASE( kShort, short )
        INC_CASE( kUShort, ushort )
        INC_CASE( kInt, int )
        INC_CASE( kUInt, uint )
        INC_CASE( kLong, long )
        INC_CASE( kULong, ulong )
        INC_CASE( kFloat, float )
        default:
            break;
    }
}

/*
 * Runs the buffer-sharing test for one element type and vector width.
 *
 * Two GL buffers are created and wrapped as CL memory objects (input and
 * "GL" output) next to one plain CL buffer ("CL" output). The kernel built
 * from bufferKernelPattern is run over the input, and both outputs are
 * compared with host-computed expectations (input + 1 and input + 2).
 *
 *   context, queue - CL context and command queue used for all CL calls.
 *   vecType        - element type under test.
 *   vecSize        - vector width; 1 selects the scalar type name.
 *   numElements    - number of work-items / vectors processed; must be > 0.
 *   validate_only  - when non-zero, only the GL object info of the two
 *                    shared buffers is checked and the kernel is not run.
 *   d              - random number generator state for the input data.
 *
 * Returns 0 on success and a non-zero value on failure. In validate-only
 * mode the return value is the OR of the two CheckGLObjectInfo results.
 */
int test_buffer_kernel(cl_context context, cl_command_queue queue, ExplicitType vecType, size_t vecSize, int numElements, int validate_only, MTdata d)
{
    clProgramWrapper program;
    clKernelWrapper kernel;
    clMemWrapper streams[ 3 ];
    glBufferWrapper inGLBuffer, outGLBuffer;

    int error;
    size_t threads[1], localThreads[1];
    char kernelSource[10240];
    char *programPtr;
    char sizeName[4];

    if( numElements <= 0 )
    {
        log_error( "ERROR: Invalid element count (%d) passed in to test_buffer_kernel!\n", numElements );
        return -1;
    }

    // Host staging storage, sized for the worst case of 16 cl_long values per
    // element. std::vector zero-fills the storage and releases it on every
    // return path, including the early returns hidden inside test_error.
    const size_t hostValueCount = (size_t)numElements * 16;
    std::vector<cl_long> inStorage( hostValueCount );
    std::vector<cl_long> outStorageCL( hostValueCount );
    std::vector<cl_long> outStorageGL( hostValueCount );
    cl_long *inData = &inStorage[ 0 ];
    cl_long *outDataCL = &outStorageCL[ 0 ];
    cl_long *outDataGL = &outStorageGL[ 0 ];

    /* Create the source */
    if( vecSize == 1 )
        sizeName[ 0 ] = 0;
    else
        sprintf( sizeName, "%d", (int)vecSize );

    sprintf( kernelSource, bufferKernelPattern, get_explicit_type_name( vecType ), sizeName,
                                                get_explicit_type_name( vecType ), sizeName,
                                                get_explicit_type_name( vecType ), sizeName,
                                                get_explicit_type_name( vecType ), sizeName,
                                                get_explicit_type_name( vecType ), sizeName );

    /* Create kernels */
    programPtr = kernelSource;
    if( create_single_kernel_helper( context, &program, &kernel, 1, (const char **)&programPtr, "sample_test" ) )
    {
        return -1;
    }

    const size_t typeSize = get_explicit_type_size( vecType );
    const size_t bufferSize = numElements * vecSize * typeSize;

    /* Generate some almost-random input data */
    gen_input_data( vecType, vecSize * numElements, d, inData );

    /* Generate some GL buffers to go against */
    glGenBuffers( 1, &inGLBuffer );
    glGenBuffers( 1, &outGLBuffer );

    glBindBuffer( GL_ARRAY_BUFFER, inGLBuffer );
    glBufferData( GL_ARRAY_BUFFER, bufferSize, inData, GL_STATIC_DRAW );

    // Note: we need to bind the output buffer, even though we don't care about its values yet,
    // because CL needs it to get the buffer size
    glBindBuffer( GL_ARRAY_BUFFER, outGLBuffer );
    glBufferData( GL_ARRAY_BUFFER, bufferSize, outDataGL, GL_STATIC_DRAW );

    glBindBuffer( GL_ARRAY_BUFFER, 0 );
    glFlush();

    /* Generate some streams. The first and last ones are GL, middle one just vanilla CL */
    streams[ 0 ] = (*clCreateFromGLBuffer_ptr)( context, CL_MEM_READ_ONLY, inGLBuffer, &error );
    test_error( error, "Unable to create input GL buffer" );

    streams[ 1 ] = clCreateBuffer( context, CL_MEM_READ_WRITE, bufferSize, NULL, &error );
    test_error( error, "Unable to create output CL buffer" );

    streams[ 2 ] = (*clCreateFromGLBuffer_ptr)( context, CL_MEM_WRITE_ONLY, outGLBuffer, &error );
    test_error( error, "Unable to create output GL buffer" );

    /* Validate the info */
    if( validate_only )
    {
        int result = (CheckGLObjectInfo(streams[0], CL_GL_OBJECT_BUFFER, (GLuint)inGLBuffer, (GLenum)0, 0) |
                      CheckGLObjectInfo(streams[2], CL_GL_OBJECT_BUFFER, (GLuint)outGLBuffer, (GLenum)0, 0) );

        // Drop the CL objects before deleting the GL buffers they wrap.
        for( int s = 0; s < 3; s++ )
        {
            streams[ s ].reset();
        }

        glDeleteBuffers(1, &inGLBuffer);    inGLBuffer = 0;
        glDeleteBuffers(1, &outGLBuffer);    outGLBuffer = 0;

        return result;
    }

    /* Assign streams and execute */
    for( int arg = 0; arg < 3; arg++ )
    {
        error = clSetKernelArg( kernel, arg, sizeof( streams[ arg ] ), &streams[ arg ] );
        test_error( error, "Unable to set kernel arguments" );
    }
    error = (*clEnqueueAcquireGLObjects_ptr)( queue, 1, &streams[ 0 ], 0, NULL, NULL);
    test_error( error, "Unable to acquire GL objects");
    error = (*clEnqueueAcquireGLObjects_ptr)( queue, 1, &streams[ 2 ], 0, NULL, NULL);
    test_error( error, "Unable to acquire GL objects");

    /* Run the kernel */
    threads[0] = numElements;

    error = get_max_common_work_group_size( context, kernel, threads[0], &localThreads[0] );
    test_error( error, "Unable to get work group size to use" );

    error = clEnqueueNDRangeKernel( queue, kernel, 1, NULL, threads, localThreads, 0, NULL, NULL );
    test_error( error, "Unable to execute test kernel" );

    error = (*clEnqueueReleaseGLObjects_ptr)( queue, 1, &streams[ 0 ], 0, NULL, NULL );
    test_error(error, "clEnqueueReleaseGLObjects failed");
    error = (*clEnqueueReleaseGLObjects_ptr)( queue, 1, &streams[ 2 ], 0, NULL, NULL );
    test_error(error, "clEnqueueReleaseGLObjects failed");

    // Get the results from both CL and GL and make sure everything looks correct
    error = clEnqueueReadBuffer( queue, streams[ 1 ], CL_TRUE, 0, bufferSize, outDataCL, 0, NULL, NULL );
    test_error( error, "Unable to read output CL array!" );
    glBindBuffer( GL_ARRAY_BUFFER, outGLBuffer );
    void *glMem = glMapBufferRange(GL_ARRAY_BUFFER, 0, bufferSize, GL_MAP_READ_BIT );
    if( glMem == NULL )
    {
        // A failed mapping must not be handed to memcpy.
        log_error( "ERROR: Unable to map the output GL buffer for reading!\n" );
        return -1;
    }
    memcpy( outDataGL, glMem, bufferSize );
    glUnmapBuffer( GL_ARRAY_BUFFER );

    // Walk the three host arrays in lock-step, one scalar element at a time.
    char *inP = (char *)inData, *glP = (char *)outDataGL, *clP = (char *)outDataCL;
    const size_t sampleCount = numElements * vecSize;
    error = 0;
    for( size_t sample = 0; sample < sampleCount; sample++ )
    {
        // Only the first typeSize bytes of each expected value are written
        // and compared.
        cl_long expectedCLValue, expectedGLValue;
        get_incremented_value( inP, &expectedCLValue, vecType );
        get_incremented_value( &expectedCLValue, &expectedGLValue, vecType );

        if( memcmp( clP, &expectedCLValue, typeSize ) != 0 )
        {
            char scratch[ 64 ];
            log_error( "ERROR: Data sample %d from the CL output did not validate!\n", (int)sample );
            log_error( "\t   Input: %s\n", GetDataVectorString( inP, typeSize, 1, scratch ) );
            log_error( "\tExpected: %s\n", GetDataVectorString( &expectedCLValue, typeSize, 1, scratch ) );
            log_error( "\t  Actual: %s\n", GetDataVectorString( clP, typeSize, 1, scratch ) );
            error = -1;
        }

        if( memcmp( glP, &expectedGLValue, typeSize ) != 0 )
        {
            char scratch[ 64 ];
            log_error( "ERROR: Data sample %d from the GL output did not validate!\n", (int)sample );
            log_error( "\t   Input: %s\n", GetDataVectorString( inP, typeSize, 1, scratch ) );
            log_error( "\tExpected: %s\n", GetDataVectorString( &expectedGLValue, typeSize, 1, scratch ) );
            log_error( "\t  Actual: %s\n", GetDataVectorString( glP, typeSize, 1, scratch ) );
            error = -1;
        }

        // Stop at the first bad sample, after reporting both outputs for it.
        if( error )
            return error;

        inP += typeSize;
        glP += typeSize;
        clP += typeSize;
    }

    for( int s = 0; s < 3; s++ )
    {
        streams[ s ].reset();
    }

    glDeleteBuffers(1, &inGLBuffer);    inGLBuffer = 0;
    glDeleteBuffers(1, &outGLBuffer);    outGLBuffer = 0;

    return 0;
}

/*
 * Shared driver for test_buffers() and test_buffers_getinfo(): runs
 * test_buffer_kernel() for every supported element type and vector width.
 *
 *   device, context, queue - CL objects the tests run against.
 *   numElements            - element count forwarded to each sub-test.
 *   validate_only          - forwarded unchanged to test_buffer_kernel().
 *
 * Returns the number of failed type/width combinations, or -1 when the
 * device profile cannot be queried.
 */
static int run_buffer_tests_for_all_types( cl_device_id device, cl_context context, cl_command_queue queue, int numElements, int validate_only )
{
    static const ExplicitType vecTypes[] = { kChar, kUChar, kShort, kUShort, kInt, kUInt, kLong, kULong, kFloat };
    static const unsigned int vecSizes[] = { 1, 2, 4, 8, 16 };
    // Suffixes printed after the type name, parallel to vecSizes.
    static const char *vecSizeNames[] = { "", "2", "4", "8", "16" };
    const size_t typeCount = sizeof( vecTypes ) / sizeof( vecTypes[ 0 ] );
    const size_t sizeCount = sizeof( vecSizes ) / sizeof( vecSizes[ 0 ] );

    RandomSeed seed( gRandomSeed );

    /* 64-bit ints optional in embedded profile */
    char profile[1024] = "";
    int queryStatus = clGetDeviceInfo(device, CL_DEVICE_PROFILE, sizeof(profile), profile, NULL);
    if (queryStatus)
    {
        print_error(queryStatus, "clGetDeviceInfo for CL_DEVICE_PROFILE failed\n" );
        return -1;
    }

    int hasLong = 1;
    const int isEmbedded = NULL != strstr(profile, "EMBEDDED_PROFILE");
    if(isEmbedded && !is_extension_available(device, "cles_khr_int64"))
    {
        hasLong = 0;
    }

    int failureCount = 0;
    for( size_t typeIndex = 0; typeIndex < typeCount; typeIndex++ )
    {
        const ExplicitType type = vecTypes[ typeIndex ];

        /* Fix bug 7124: skip the 64-bit types when the device lacks them */
        if ( (type == kLong || type == kULong) && !hasLong )
            continue;

        for( size_t sizeIndex = 0; sizeIndex < sizeCount; sizeIndex++ )
        {
            // Test!
            if( test_buffer_kernel( context, queue, type, vecSizes[ sizeIndex ], numElements, validate_only, seed ) != 0 )
            {
                log_error( "   Buffer test %s%s FAILED\n", get_explicit_type_name( type ), vecSizeNames[ sizeIndex ] );
                failureCount++;
            }
        }
    }

    return failureCount;
}

/*
 * Runs the full CL/GL buffer test (kernel execution plus result
 * verification) for every supported type and vector width.
 * Returns the number of failed combinations, or -1 if the device profile
 * query fails.
 */
int test_buffers( cl_device_id device, cl_context context, cl_command_queue queue, int numElements )
{
    return run_buffer_tests_for_all_types( device, context, queue, numElements, 0 );
}


/*
 * Runs the validate-only variant of the CL/GL buffer test, which checks the
 * GL object info reported for the shared buffers, for every supported type
 * and vector width.
 * Returns the number of failed combinations, or -1 if the device profile
 * query fails.
 */
int test_buffers_getinfo( cl_device_id device, cl_context context, cl_command_queue queue, int numElements )
{
    return run_buffer_tests_for_all_types( device, context, queue, numElements, 1 );
}



