// Copyright 2018 The Amber Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/buffer.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <type_traits>

#include "src/float16_helper.h"

namespace amber {
namespace {

// Reinterprets the raw byte pointer |values| as a pointer to |T|. No data is
// copied; the returned pointer refers to the same storage as |values|.
template <typename T>
T* ValuesAs(uint8_t* values) {
  auto* typed_values = reinterpret_cast<T*>(values);
  return typed_values;
}

// Integer overload of the subtraction helper: returns |lhs| - |rhs| as a
// double. The magnitude is computed in the matching unsigned type, where
// arithmetic is modular, so unsigned operands do not wrap around when
// |lhs| < |rhs| and signed operands cannot overflow.
template <typename T>
double SubValues(T lhs, T rhs, std::true_type /* is_integral */) {
  using Unsigned = typename std::make_unsigned<T>::type;
  const Unsigned lhs_bits = static_cast<Unsigned>(lhs);
  const Unsigned rhs_bits = static_cast<Unsigned>(rhs);
  // The inner cast folds the (possibly int-promoted) result back into the
  // unsigned type so narrow types also yield the true magnitude.
  if (lhs < rhs)
    return -static_cast<double>(static_cast<Unsigned>(rhs_bits - lhs_bits));
  return static_cast<double>(static_cast<Unsigned>(lhs_bits - rhs_bits));
}

// Non-integer overload of the subtraction helper: subtracts in the native
// type and widens the result to double.
template <typename T>
double SubValues(T lhs, T rhs, std::false_type /* is_integral */) {
  return static_cast<double>(lhs - rhs);
}

// Reads one value of type |T| from each of |buf1| and |buf2| and returns the
// signed difference (buf1 - buf2) as a double.
//
// The values are loaded with memcpy so the buffers do not need to be aligned
// for |T|.
template <typename T>
double Sub(const uint8_t* buf1, const uint8_t* buf2) {
  T lhs;
  T rhs;
  std::memcpy(&lhs, buf1, sizeof(T));
  std::memcpy(&rhs, buf2, sizeof(T));
  return SubValues(lhs, rhs, std::is_integral<T>());
}

// Returns the signed difference (|buf1| - |buf2|) of the single component
// described by |seg|. The segment's format mode and bit width select the type
// used to interpret the bytes at |buf1| and |buf2|. Asserts (and returns 0.0
// when asserts are compiled out) if no supported type matches.
double CalculateDiff(const Format::Segment* seg,
                     const uint8_t* buf1,
                     const uint8_t* buf2) {
  const FormatMode fmt = seg->GetFormatMode();
  const uint32_t bits = seg->GetNumBits();

  // Signed integers.
  if (type::Type::IsInt8(fmt, bits)) {
    return Sub<int8_t>(buf1, buf2);
  }
  if (type::Type::IsInt16(fmt, bits)) {
    return Sub<int16_t>(buf1, buf2);
  }
  if (type::Type::IsInt32(fmt, bits)) {
    return Sub<int32_t>(buf1, buf2);
  }
  if (type::Type::IsInt64(fmt, bits)) {
    return Sub<int64_t>(buf1, buf2);
  }

  // Unsigned integers.
  if (type::Type::IsUint8(fmt, bits)) {
    return Sub<uint8_t>(buf1, buf2);
  }
  if (type::Type::IsUint16(fmt, bits)) {
    return Sub<uint16_t>(buf1, buf2);
  }
  if (type::Type::IsUint32(fmt, bits)) {
    return Sub<uint32_t>(buf1, buf2);
  }
  if (type::Type::IsUint64(fmt, bits)) {
    return Sub<uint64_t>(buf1, buf2);
  }

  // Floating point. 16-bit floats are first decoded into a float through
  // float16::HexFloatToFloat and then subtracted.
  if (type::Type::IsFloat16(fmt, bits)) {
    const float lhs = float16::HexFloatToFloat(buf1, 16);
    const float rhs = float16::HexFloatToFloat(buf2, 16);
    return static_cast<double>(lhs - rhs);
  }
  if (type::Type::IsFloat32(fmt, bits)) {
    return Sub<float>(buf1, buf2);
  }
  if (type::Type::IsFloat64(fmt, bits)) {
    return Sub<double>(buf1, buf2);
  }

  assert(false && "NOTREACHED");
  return 0.0;
}

}  // namespace

// Compiler-generated default constructor, defined out of line.
Buffer::Buffer() = default;

// Compiler-generated destructor, defined out of line.
Buffer::~Buffer() = default;

// Copies this buffer's bytes into |buffer|. The copy is only performed when
// both buffers agree on width, height and element count; otherwise a Result
// carrying the first mismatch found is returned and |buffer| is untouched.
//
// NOTE(review): the messages name CopyBaseFields() rather than CopyTo() --
// presumably left over from a rename. Confirm nothing matches on the text
// before changing them.
Result Buffer::CopyTo(Buffer* buffer) const {
  if (width_ != buffer->width_) {
    return Result("Buffer::CopyBaseFields() buffers have a different width");
  }
  if (height_ != buffer->height_) {
    return Result("Buffer::CopyBaseFields() buffers have a different height");
  }
  if (element_count_ != buffer->element_count_) {
    return Result("Buffer::CopyBaseFields() buffers have a different size");
  }

  buffer->bytes_ = bytes_;
  return {};
}

// Compares this buffer with |buffer| byte by byte. Returns the failure from
// CheckCompability() if the buffers are not compatible. Otherwise returns an
// error describing how many bytes differ and the first differing byte, or an
// empty Result when every byte matches.
Result Buffer::IsEqual(Buffer* buffer) const {
  Result compatibility = CheckCompability(buffer);
  if (!compatibility.IsSuccess())
    return compatibility;

  uint32_t mismatch_count = 0;
  uint32_t first_mismatch_index = 0;
  uint8_t first_mismatch_lhs = 0;
  uint8_t first_mismatch_rhs = 0;
  for (uint32_t idx = 0; idx < bytes_.size(); ++idx) {
    const uint8_t lhs = bytes_[idx];
    const uint8_t rhs = buffer->bytes_[idx];
    if (lhs == rhs)
      continue;

    // Only the first mismatch is remembered for the report; the rest are
    // just counted.
    if (mismatch_count == 0) {
      first_mismatch_index = idx;
      first_mismatch_lhs = lhs;
      first_mismatch_rhs = rhs;
    }
    ++mismatch_count;
  }

  if (mismatch_count == 0)
    return {};

  std::string message = "Buffers have different values. ";
  message += std::to_string(mismatch_count);
  message += " values differed, first difference at byte ";
  message += std::to_string(first_mismatch_index);
  message += " values ";
  message += std::to_string(first_mismatch_lhs);
  message += " != ";
  message += std::to_string(first_mismatch_rhs);
  return Result{message};
}

std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const {
  std::vector<double> diffs;

  auto* buf_1_ptr = GetValues<uint8_t>();
  auto* buf_2_ptr = buffer->GetValues<uint8_t>();
  const auto& segments = format_->GetSegments();
  for (size_t i = 0; i < ElementCount(); ++i) {
    for (const auto& seg : segments) {
      if (seg.IsPadding()) {
        buf_1_ptr += seg.PaddingBytes();
        buf_2_ptr += seg.PaddingBytes();
        continue;
      }

      diffs.push_back(CalculateDiff(&seg, buf_1_ptr, buf_2_ptr));

      buf_1_ptr += seg.SizeInBytes();
      buf_2_ptr += seg.SizeInBytes();
    }
  }

  return diffs;
}

// Verifies that |buffer| can be compared against this buffer. The checks run
// in order -- format, element count, width, height, value count -- and the
// first mismatch is reported as an error Result. Returns an empty Result when
// all of them agree.
Result Buffer::CheckCompability(Buffer* buffer) const {
  const bool same_format = buffer->format_->Equal(format_);
  if (!same_format) {
    return Result{"Buffers have a different format"};
  }
  if (element_count_ != buffer->element_count_) {
    return Result{"Buffers have a different size"};
  }
  if (width_ != buffer->width_) {
    return Result{"Buffers have a different width"};
  }
  if (height_ != buffer->height_) {
    return Result{"Buffers have a different height"};
  }
  if (ValueCount() != buffer->ValueCount()) {
    return Result{"Buffers have a different number of values"};
  }

  return {};
}

// Compares this buffer with |buffer| using the root mean square error of the
// per-component differences. Returns the failure from CheckCompability() if
// the buffers are not compatible, an error if the RMSE is greater than
// |tolerance|, and an empty Result otherwise.
Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
  Result compatibility = CheckCompability(buffer);
  if (!compatibility.IsSuccess())
    return compatibility;

  const std::vector<double> diffs = CalculateDiffs(buffer);

  double squared_sum = 0.0;
  for (size_t idx = 0; idx < diffs.size(); ++idx)
    squared_sum += diffs[idx] * diffs[idx];

  // With no diffs this is 0.0 / 0.0, i.e. NaN. NaN never compares greater
  // than the tolerance, so that case falls through to the empty Result.
  const double mean_square = squared_sum / static_cast<double>(diffs.size());
  const double rmse = std::sqrt(mean_square);

  if (rmse > static_cast<double>(tolerance)) {
    return Result("Root Mean Square Error of " + std::to_string(rmse) +
                  " is greater than tolerance of " + std::to_string(tolerance));
  }

  return {};
}

// Builds a histogram of the byte values stored in channel |channel| across
// all elements of the buffer. |num_bins| must be 256 (asserted), and the
// selected channel is asserted to be an 8-bit unsigned segment. Padding
// segments are skipped and do not count as channels.
std::vector<uint64_t> Buffer::GetHistogramForChannel(uint32_t channel,
                                                     uint32_t num_bins) const {
  assert(num_bins == 256);
  std::vector<uint64_t> histogram(num_bins, 0);

  const auto& segments = format_->GetSegments();
  const auto channel_count = format_->InputNeededPerElement();
  auto* cursor = GetValues<uint8_t>();
  uint32_t current_channel = 0;

  for (size_t element = 0; element < ElementCount(); ++element) {
    for (const auto& seg : segments) {
      if (seg.IsPadding()) {
        cursor += seg.PaddingBytes();
        continue;
      }

      if (current_channel == channel) {
        assert(type::Type::IsUint8(seg.GetFormatMode(), seg.GetNumBits()));
        // The byte value itself is the bin index.
        const uint8_t value = *reinterpret_cast<const uint8_t*>(cursor);
        ++histogram[value];
      }

      cursor += seg.SizeInBytes();
      // Wrap around so the channel index restarts with each element.
      current_channel = (current_channel + 1) % channel_count;
    }
  }

  return histogram;
}

// Compares this buffer with |buffer| using the Earth Mover's Distance (EMD)
// between their per-channel histograms. Returns the failure from
// CheckCompability() if the buffers are not compatible, an error if any
// segment is not an 8-bit unsigned value or the format does not take four
// inputs per element, and an error if the largest per-channel EMD is greater
// than |tolerance|. Returns an empty Result otherwise.
Result Buffer::CompareHistogramEMD(Buffer* buffer, float tolerance) const {
  auto result = CheckCompability(buffer);
  if (!result.IsSuccess())
    return result;

  const uint32_t num_bins = 256;
  const auto num_channels = format_->InputNeededPerElement();
  // Iterate by reference: the segments are only inspected, not modified.
  for (const auto& segment : format_->GetSegments()) {
    if (!type::Type::IsUint8(segment.GetFormatMode(), segment.GetNumBits()) ||
        num_channels != 4) {
      return Result(
          "EMD comparison only supports 8bit unorm format with four channels.");
    }
  }

  std::vector<std::vector<uint64_t>> histogram1;
  std::vector<std::vector<uint64_t>> histogram2;
  histogram1.reserve(num_channels);
  histogram2.reserve(num_channels);
  for (uint32_t c = 0; c < num_channels; ++c) {
    histogram1.push_back(GetHistogramForChannel(c, num_bins));
    histogram2.push_back(buffer->GetHistogramForChannel(c, num_bins));
  }

  // Earth mover's distance: Calculate the minimal cost of moving "earth" to
  // transform the first histogram into the second, where each bin of the
  // histogram can be thought of as a column of units of earth. The cost is the
  // amount of earth moved times the distance carried (the distance is the
  // number of adjacent bins over which the earth is carried). Calculate this
  // using the cumulative difference of the bins, which works as long as both
  // histograms have the same amount of earth. Sum the absolute values of the
  // cumulative difference to get the final cost of how much (and how far) the
  // earth was moved.
  double max_emd = 0;

  for (uint32_t c = 0; c < num_channels; ++c) {
    double diff_total = 0;
    double diff_accum = 0;

    for (uint32_t i = 0; i < num_bins; ++i) {
      // Normalize each bin by the element count so both histograms sum to the
      // same total.
      double hist_normalized_1 =
          static_cast<double>(histogram1[c][i]) / element_count_;
      double hist_normalized_2 =
          static_cast<double>(histogram2[c][i]) / buffer->element_count_;
      diff_accum += hist_normalized_1 - hist_normalized_2;
      diff_total += std::fabs(diff_accum);
    }
    // Normalize to range 0..1
    double emd = diff_total / num_bins;
    max_emd = std::max(max_emd, emd);
  }

  if (max_emd > static_cast<double>(tolerance)) {
    return Result("Histogram EMD value of " + std::to_string(max_emd) +
                  " is greater than tolerance of " + std::to_string(tolerance));
  }

  return {};
}

// Writes |data| at the very start of the buffer by forwarding to
// SetDataWithOffset() with a byte offset of zero.
Result Buffer::SetData(const std::vector<Value>& data) {
  const uint32_t kStartOfBuffer = 0;
  return SetDataWithOffset(data, kStartOfBuffer);
}

// Raises the recorded maximum size so that it can hold |data| written at byte
// |offset|. The maximum is only ever increased here, never lowered. Always
// returns an empty Result.
Result Buffer::RecalculateMaxSizeInBytes(const std::vector<Value>& data,
                                         uint32_t offset) {
  const auto element_size = format_->SizeInBytes();
  const auto inputs_per_element = format_->InputNeededPerElement();

  // The elements that lie before |offset| are converted to a value count by
  // multiplying with the needed input, since the value count uses the needed
  // input as its multiplier.
  const uint32_t value_count = ((offset / element_size) * inputs_per_element) +
                               static_cast<uint32_t>(data.size());

  // For non-packed formats, divide by the needed input values rather than the
  // values per element: incoming values are assumed to be read from the
  // input, where components are specified, and the needed values may be fewer
  // than the values per element. Packed formats use the value count directly.
  uint32_t element_count = value_count;
  if (!format_->IsPacked()) {
    element_count = value_count / inputs_per_element;
  }

  const auto required_bytes = element_count * element_size;
  if (GetMaxSizeInBytes() < required_bytes) {
    SetMaxSizeInBytes(required_bytes);
  }
  return {};
}

// Writes |data| into the buffer starting at byte |offset|, following the
// segment layout of the buffer's format. The value count is grown first if the
// write needs more values than ValueCount() currently reports. Returns an
// error Result when |data| holds more items than the buffer's elements accept,
// and a default-constructed Result otherwise.
//
// NOTE(review): the resize and the zero-fill both happen before the
// "Mismatched number of items" check, so a call that fails may already have
// modified the buffer.
Result Buffer::SetDataWithOffset(const std::vector<Value>& data,
                                 uint32_t offset) {
  // Multiply by the input needed because the value count will use the needed
  // input as the multiplier
  uint32_t value_count =
      ((offset / format_->SizeInBytes()) * format_->InputNeededPerElement()) +
      static_cast<uint32_t>(data.size());

  // The buffer should only be resized to become bigger. This means that if a
  // command was run to set the buffer size we'll honour that size until a
  // request happens to make the buffer bigger.
  if (value_count > ValueCount())
    SetValueCount(value_count);

  // Even if the value count doesn't change, the buffer is still resized because
  // this may be the first time data is set into the buffer.
  bytes_.resize(GetSizeInBytes());

  // Set the new memory to zero to be on the safe side.
  uint32_t new_space =
      (static_cast<uint32_t>(data.size()) / format_->InputNeededPerElement()) *
      format_->SizeInBytes();
  // Debug-only bounds check for the memset below.
  assert(new_space + offset <= GetSizeInBytes());

  if (new_space > 0)
    memset(bytes_.data() + offset, 0, new_space);

  if (data.size() > (ElementCount() * format_->InputNeededPerElement()))
    return Result("Mismatched number of items in buffer");

  uint8_t* ptr = bytes_.data() + offset;
  const auto& segments = format_->GetSegments();
  // The outer loop has no increment expression: |i| only advances when a value
  // is consumed in the segment loop, so each pass over |segments| fills (at
  // most) one element.
  // NOTE(review): if every segment were padding, |i| would never advance --
  // presumably a format always has at least one non-padding segment; confirm.
  for (uint32_t i = 0; i < data.size();) {
    for (const auto& seg : segments) {
      // Padding consumes no input value; only the write pointer moves.
      if (seg.IsPadding()) {
        ptr += seg.PaddingBytes();
        continue;
      }

      Value v = data[i++];
      // WriteValueFromComponent() returns the number of bytes it wrote.
      ptr += WriteValueFromComponent(v, seg.GetFormatMode(), seg.GetNumBits(),
                                     ptr);
      // Stop mid-element once the input is exhausted.
      if (i >= data.size())
        break;
    }
  }
  return {};
}

uint32_t Buffer::WriteValueFromComponent(const Value& value,
                                         FormatMode mode,
                                         uint32_t num_bits,
                                         uint8_t* ptr) {
  if (type::Type::IsInt8(mode, num_bits)) {
    *(ValuesAs<int8_t>(ptr)) = value.AsInt8();
    return sizeof(int8_t);
  }
  if (type::Type::IsInt16(mode, num_bits)) {
    *(ValuesAs<int16_t>(ptr)) = value.AsInt16();
    return sizeof(int16_t);
  }
  if (type::Type::IsInt32(mode, num_bits)) {
    *(ValuesAs<int32_t>(ptr)) = value.AsInt32();
    return sizeof(int32_t);
  }
  if (type::Type::IsInt64(mode, num_bits)) {
    *(ValuesAs<int64_t>(ptr)) = value.AsInt64();
    return sizeof(int64_t);
  }
  if (type::Type::IsUint8(mode, num_bits)) {
    *(ValuesAs<uint8_t>(ptr)) = value.AsUint8();
    return sizeof(uint8_t);
  }
  if (type::Type::IsUint16(mode, num_bits)) {
    *(ValuesAs<uint16_t>(ptr)) = value.AsUint16();
    return sizeof(uint16_t);
  }
  if (type::Type::IsUint32(mode, num_bits)) {
    *(ValuesAs<uint32_t>(ptr)) = value.AsUint32();
    return sizeof(uint32_t);
  }
  if (type::Type::IsUint64(mode, num_bits)) {
    *(ValuesAs<uint64_t>(ptr)) = value.AsUint64();
    return sizeof(uint64_t);
  }
  if (type::Type::IsFloat16(mode, num_bits)) {
    *(ValuesAs<uint16_t>(ptr)) = float16::FloatToHexFloat16(value.AsFloat());
    return sizeof(uint16_t);
  }
  if (type::Type::IsFloat32(mode, num_bits)) {
    *(ValuesAs<float>(ptr)) = value.AsFloat();
    return sizeof(float);
  }
  if (type::Type::IsFloat64(mode, num_bits)) {
    *(ValuesAs<double>(ptr)) = value.AsDouble();
    return sizeof(double);
  }

  // The float 10 and float 11 sizes are only used in PACKED formats.
  assert(false && "Not reached");
  return 0;
}

// Sets the element count to |element_count| and resizes the byte storage to
// |element_count| times the format's size in bytes.
void Buffer::SetSizeInElements(uint32_t element_count) {
  element_count_ = element_count;

  const auto byte_count = element_count * format_->SizeInBytes();
  bytes_.resize(byte_count);
}

// Resizes the byte storage to |size_in_bytes| and derives the element count
// from it. |size_in_bytes| must be a multiple of the format's size in bytes
// (asserted).
void Buffer::SetSizeInBytes(uint32_t size_in_bytes) {
  const auto element_size = format_->SizeInBytes();
  assert(size_in_bytes % element_size == 0);

  element_count_ = size_in_bytes / element_size;
  bytes_.resize(size_in_bytes);
}

// Records |max_size_in_bytes| as the buffer's maximum size. A stored value of
// 0 makes GetMaxSizeInBytes() fall back to GetSizeInBytes().
void Buffer::SetMaxSizeInBytes(uint32_t max_size_in_bytes) {
  max_size_in_bytes_ = max_size_in_bytes;
}

// Returns the recorded maximum size in bytes, or the current size in bytes
// when no maximum has been recorded (the stored value is 0).
uint32_t Buffer::GetMaxSizeInBytes() const {
  if (max_size_in_bytes_ == 0)
    return GetSizeInBytes();
  return max_size_in_bytes_;
}

// Copies all bytes of |src| into this buffer starting at byte |offset|,
// growing the byte storage first if it is too small, and then recomputes the
// element count from the resulting size. Always returns an empty Result.
Result Buffer::SetDataFromBuffer(const Buffer* src, uint32_t offset) {
  if (bytes_.size() < offset + src->bytes_.size())
    bytes_.resize(offset + src->bytes_.size());

  // An empty vector's data() may be null, and passing a null pointer to
  // memcpy is undefined behaviour even for a zero-byte copy, so skip the call
  // when there is nothing to copy.
  if (!src->bytes_.empty()) {
    std::memcpy(bytes_.data() + offset, src->bytes_.data(),
                src->bytes_.size());
  }

  element_count_ =
      static_cast<uint32_t>(bytes_.size()) / format_->SizeInBytes();
  return {};
}

}  // namespace amber
