/*
 *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/transient/transient_suppressor_impl.h"

#include <string.h>

#include <algorithm>
#include <cmath>
#include <complex>
#include <deque>
#include <limits>
#include <set>
#include <string>

#include "common_audio/include/audio_util.h"
#include "common_audio/signal_processing/include/signal_processing_library.h"
#include "common_audio/third_party/ooura/fft_size_256/fft4g.h"
#include "modules/audio_processing/transient/common.h"
#include "modules/audio_processing/transient/transient_detector.h"
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/windows_private.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"

namespace webrtc {

// Weight given to the newest magnitude when updating the per-bin running
// spectral mean (first-order IIR smoothing).
constexpr float kMeanIIRCoefficient = 0.5f;

// FFT bin range [kMinVoiceBin, kMaxVoiceBin) treated as the voice band when
// computing the block mean and the restoration mean factors.
// TODO(aluebs): Check if these values work also for 48kHz.
constexpr size_t kMinVoiceBin = 3;
constexpr size_t kMaxVoiceBin = 60;

namespace {

// Cheap magnitude estimate for the complex value a + jb: the L1 norm
// |a| + |b| instead of the Euclidean sqrt(a^2 + b^2). It is never smaller
// than the true magnitude and needs no square root.
float ComplexMagnitude(float a, float b) {
  const float abs_real = std::fabs(a);
  const float abs_imag = std::fabs(b);
  return abs_real + abs_imag;
}

// Returns a human-readable label for `vad_mode`, used in log messages.
std::string GetVadModeLabel(TransientSuppressor::VadMode vad_mode) {
  switch (vad_mode) {
    case TransientSuppressor::VadMode::kDefault:
      return "default";
    case TransientSuppressor::VadMode::kRnnVad:
      return "RNN VAD";
    case TransientSuppressor::VadMode::kNoVad:
      return "no VAD";
  }
  // Only reachable when `vad_mode` holds a value outside the enumerators
  // (e.g. produced by a cast). Without this return, control would flow off
  // the end of a non-void function, which is undefined behavior.
  return "unknown";
}

}  // namespace

// Creates a suppressor for `num_channels` channels of `sample_rate_hz` audio,
// with the transient detector running at `detector_rate_hz`. The member
// initializers only give every field a neutral value; Initialize() then
// allocates the buffers and sets the real configuration.
TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode,
                                                 int sample_rate_hz,
                                                 int detector_rate_hz,
                                                 int num_channels)
    : vad_mode_(vad_mode),
      voice_probability_delay_unit_(/*delay_num_samples=*/0, sample_rate_hz),
      analyzed_audio_is_silent_(false),
      data_length_(0),
      detection_length_(0),
      analysis_length_(0),
      buffer_delay_(0),
      complex_analysis_length_(0),
      num_channels_(0),
      window_(nullptr),
      detector_smoothed_(0.f),
      keypress_counter_(0),
      chunks_since_keypress_(0),
      detection_enabled_(false),
      suppression_enabled_(false),
      use_hard_restoration_(false),
      chunks_since_voice_change_(0),
      seed_(182),
      using_reference_(false) {
  RTC_LOG(LS_INFO) << "VAD mode: " << GetVadModeLabel(vad_mode_);
  Initialize(sample_rate_hz, detector_rate_hz, num_channels);
}

// The buffers are owned through the `reset()`-able pointers set up in
// Initialize(), so there is nothing to release by hand.
TransientSuppressorImpl::~TransientSuppressorImpl() = default;

void TransientSuppressorImpl::Initialize(int sample_rate_hz,
                                         int detection_rate_hz,
                                         int num_channels) {
  RTC_DCHECK(sample_rate_hz == ts::kSampleRate8kHz ||
             sample_rate_hz == ts::kSampleRate16kHz ||
             sample_rate_hz == ts::kSampleRate32kHz ||
             sample_rate_hz == ts::kSampleRate48kHz);
  RTC_DCHECK(detection_rate_hz == ts::kSampleRate8kHz ||
             detection_rate_hz == ts::kSampleRate16kHz ||
             detection_rate_hz == ts::kSampleRate32kHz ||
             detection_rate_hz == ts::kSampleRate48kHz);
  RTC_DCHECK_GT(num_channels, 0);

  switch (sample_rate_hz) {
    case ts::kSampleRate8kHz:
      analysis_length_ = 128u;
      window_ = kBlocks80w128;
      break;
    case ts::kSampleRate16kHz:
      analysis_length_ = 256u;
      window_ = kBlocks160w256;
      break;
    case ts::kSampleRate32kHz:
      analysis_length_ = 512u;
      window_ = kBlocks320w512;
      break;
    case ts::kSampleRate48kHz:
      analysis_length_ = 1024u;
      window_ = kBlocks480w1024;
      break;
    default:
      RTC_DCHECK_NOTREACHED();
      return;
  }

  detector_.reset(new TransientDetector(detection_rate_hz));
  data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000;
  RTC_DCHECK_LE(data_length_, analysis_length_);
  buffer_delay_ = analysis_length_ - data_length_;

  voice_probability_delay_unit_.Initialize(/*delay_num_samples=*/buffer_delay_,
                                           sample_rate_hz);

  complex_analysis_length_ = analysis_length_ / 2 + 1;
  RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin);
  num_channels_ = num_channels;
  in_buffer_.reset(new float[analysis_length_ * num_channels_]);
  memset(in_buffer_.get(), 0,
         analysis_length_ * num_channels_ * sizeof(in_buffer_[0]));
  detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000;
  detection_buffer_.reset(new float[detection_length_]);
  memset(detection_buffer_.get(), 0,
         detection_length_ * sizeof(detection_buffer_[0]));
  out_buffer_.reset(new float[analysis_length_ * num_channels_]);
  memset(out_buffer_.get(), 0,
         analysis_length_ * num_channels_ * sizeof(out_buffer_[0]));
  // ip[0] must be zero to trigger initialization using rdft().
  size_t ip_length = 2 + sqrtf(analysis_length_);
  ip_.reset(new size_t[ip_length]());
  memset(ip_.get(), 0, ip_length * sizeof(ip_[0]));
  wfft_.reset(new float[complex_analysis_length_ - 1]);
  memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0]));
  spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]);
  memset(spectral_mean_.get(), 0,
         complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0]));
  fft_buffer_.reset(new float[analysis_length_ + 2]);
  memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0]));
  magnitudes_.reset(new float[complex_analysis_length_]);
  memset(magnitudes_.get(), 0,
         complex_analysis_length_ * sizeof(magnitudes_[0]));
  mean_factor_.reset(new float[complex_analysis_length_]);

  static const float kFactorHeight = 10.f;
  static const float kLowSlope = 1.f;
  static const float kHighSlope = 0.3f;
  for (size_t i = 0; i < complex_analysis_length_; ++i) {
    mean_factor_[i] =
        kFactorHeight /
            (1.f + std::exp(kLowSlope * static_cast<int>(i - kMinVoiceBin))) +
        kFactorHeight /
            (1.f + std::exp(kHighSlope * static_cast<int>(kMaxVoiceBin - i)));
  }
  detector_smoothed_ = 0.f;
  keypress_counter_ = 0;
  chunks_since_keypress_ = 0;
  detection_enabled_ = false;
  suppression_enabled_ = false;
  use_hard_restoration_ = false;
  chunks_since_voice_change_ = 0;
  seed_ = 182;
  using_reference_ = false;
}

// Processes one chunk of audio in place. `data` holds `num_channels` channels
// stored back to back, `data_length` samples each. `detection_data` (may be
// null) and `reference_data` are forwarded to the transient detector.
// Returns the voice probability delayed to stay aligned with the delayed
// output audio. When `data` is left unmodified (invalid arguments, or a
// negative detector result), returns `voice_probability` unchanged instead.
float TransientSuppressorImpl::Suppress(float* data,
                                        size_t data_length,
                                        int num_channels,
                                        const float* detection_data,
                                        size_t detection_length,
                                        const float* reference_data,
                                        size_t reference_length,
                                        float voice_probability,
                                        bool key_pressed) {
  // Without dedicated detection data, the detector is fed the newest chunk of
  // the first channel in `in_buffer_`, which holds only `data_length_`
  // samples. A longer detection chunk would read past it: into the second
  // channel's samples, or past the end of the buffer with a single channel.
  const bool fallback_detection_too_long =
      detection_data == nullptr && detection_length > data_length_;
  if (!data || data_length != data_length_ || num_channels != num_channels_ ||
      detection_length != detection_length_ || fallback_detection_too_long ||
      voice_probability < 0 || voice_probability > 1) {
    // The audio is not modified, so the voice probability is returned as is
    // (delay not applied).
    return voice_probability;
  }

  UpdateKeypress(key_pressed);
  UpdateBuffers(data);

  if (detection_enabled_) {
    UpdateRestoration(voice_probability);

    if (!detection_data) {
      // Use the input data of the first channel if special detection data is
      // not supplied.
      detection_data = &in_buffer_[buffer_delay_];
    }

    float detector_result = detector_->Detect(detection_data, detection_length,
                                              reference_data, reference_length);
    if (detector_result < 0) {
      // The audio is not modified, so the voice probability is returned as is
      // (delay not applied).
      return voice_probability;
    }

    using_reference_ = detector_->using_reference();

    // `detector_smoothed_` follows the `detector_result` when this last one is
    // increasing, but has an exponential decaying tail to be able to suppress
    // the ringing of keyclicks.
    const float smooth_factor = using_reference_ ? 0.6f : 0.1f;
    detector_smoothed_ = detector_result >= detector_smoothed_
                             ? detector_result
                             : smooth_factor * detector_smoothed_ +
                                   (1 - smooth_factor) * detector_result;

    for (int i = 0; i < num_channels_; ++i) {
      Suppress(&in_buffer_[i * analysis_length_],
               &spectral_mean_[i * complex_analysis_length_],
               &out_buffer_[i * analysis_length_]);
    }
  }

  // If the suppression isn't enabled, we use the in buffer to delay the signal
  // appropriately. This also gives time for the out buffer to be refreshed with
  // new data between detection and suppression getting enabled.
  for (int i = 0; i < num_channels_; ++i) {
    memcpy(&data[i * data_length_],
           suppression_enabled_ ? &out_buffer_[i * analysis_length_]
                                : &in_buffer_[i * analysis_length_],
           data_length_ * sizeof(*data));
  }

  // The audio has been modified, return the delayed voice probability.
  return voice_probability_delay_unit_.Delay(voice_probability);
}

// Processes one channel's analysis frame. Call this only while detection is
// enabled, and only after UpdateBuffers().
// - `in_ptr`: the channel's `analysis_length_`-sample input frame.
// - `spectral_mean`: the channel's running per-bin mean, updated here.
// - `out_ptr`: receives the windowed, resynthesized frame by overlap-add.
//   When suppression is enabled, the frame is restored before resynthesis.
void TransientSuppressorImpl::Suppress(float* in_ptr,
                                       float* spectral_mean,
                                       float* out_ptr) {
  float* const spectrum = fft_buffer_.get();
  const size_t frame_length = analysis_length_;

  // Window the frame and take it to the frequency domain.
  // TODO(aluebs): Rename windows
  for (size_t k = 0; k < frame_length; ++k) {
    spectrum[k] = in_ptr[k] * window_[k];
  }
  WebRtc_rdft(frame_length, 1, spectrum, ip_.get(), wfft_.get());

  // WebRtc_rdft packs the real-valued R[n/2] into slot 1. Move it to the end
  // so every bin k sits at (2k, 2k + 1).
  spectrum[frame_length] = spectrum[1];
  spectrum[frame_length + 1] = 0.f;
  spectrum[1] = 0.f;

  for (size_t bin = 0; bin < complex_analysis_length_; ++bin) {
    magnitudes_[bin] =
        ComplexMagnitude(spectrum[2 * bin], spectrum[2 * bin + 1]);
  }

  if (suppression_enabled_) {
    if (use_hard_restoration_) {
      HardRestoration(spectral_mean);
    } else {
      SoftRestoration(spectral_mean);
    }
  }

  // First-order IIR update of the running spectral mean.
  for (size_t bin = 0; bin < complex_analysis_length_; ++bin) {
    spectral_mean[bin] = (1 - kMeanIIRCoefficient) * spectral_mean[bin] +
                         kMeanIIRCoefficient * magnitudes_[bin];
  }

  // Restore the packed layout expected by the inverse transform.
  spectrum[1] = spectrum[frame_length];
  WebRtc_rdft(frame_length, -1, spectrum, ip_.get(), wfft_.get());

  const float fft_scaling = 2.f / frame_length;
  for (size_t k = 0; k < frame_length; ++k) {
    out_ptr[k] += spectrum[k] * window_[k] * fft_scaling;
  }
}

// Tracks typing activity from the per-chunk `key_pressed` flag.
// - A key press enables detection and adds a penalty worth one second of
//   chunks to `keypress_counter_`, which then decays by one per chunk.
// - Suppression turns on once the counter exceeds one second's worth of
//   chunks.
// - Detection and suppression both turn off after more than four seconds of
//   chunks without a key press.
void TransientSuppressorImpl::UpdateKeypress(bool key_pressed) {
  const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
  const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
  const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs;  // 4 seconds.

  if (key_pressed) {
    detection_enabled_ = true;
    chunks_since_keypress_ = 0;
    keypress_counter_ += kKeypressPenalty;
  }
  // Decay by one chunk, never going below zero.
  keypress_counter_ = keypress_counter_ > 0 ? keypress_counter_ - 1 : 0;

  const bool is_typing = keypress_counter_ > kIsTypingThreshold;
  if (is_typing) {
    if (!suppression_enabled_) {
      RTC_LOG(LS_INFO) << "[ts] Transient suppression is now enabled.";
      suppression_enabled_ = true;
    }
    keypress_counter_ = 0;
  }

  if (!detection_enabled_) {
    return;
  }
  ++chunks_since_keypress_;
  if (chunks_since_keypress_ <= kChunksUntilNotTyping) {
    return;
  }
  if (suppression_enabled_) {
    RTC_LOG(LS_INFO) << "[ts] Transient suppression is now disabled.";
  }
  detection_enabled_ = false;
  suppression_enabled_ = false;
  keypress_counter_ = 0;
}

// Chooses between hard and soft restoration from `voice_probability`, with
// hysteresis.
// - Switching to hard restoration (not voiced) needs more than
//   kHardRestorationOnsetDelay consecutive chunks that disagree with the
//   current mode.
// - Switching back to soft restoration needs more than
//   kHardRestorationOffsetDelay such chunks.
void TransientSuppressorImpl::UpdateRestoration(float voice_probability) {
  // Default to "voiced", which is also the kNoVad behavior. This way an
  // out-of-range `vad_mode_` can never leave the flag uninitialized.
  bool not_voiced = false;
  switch (vad_mode_) {
    case TransientSuppressor::VadMode::kDefault: {
      constexpr float kVoiceThreshold = 0.02f;
      not_voiced = voice_probability < kVoiceThreshold;
      break;
    }
    case TransientSuppressor::VadMode::kRnnVad: {
      constexpr float kVoiceThreshold = 0.7f;
      not_voiced = voice_probability < kVoiceThreshold;
      break;
    }
    case TransientSuppressor::VadMode::kNoVad:
      // Always assume that voice is detected.
      not_voiced = false;
      break;
  }

  if (not_voiced == use_hard_restoration_) {
    chunks_since_voice_change_ = 0;
  } else {
    ++chunks_since_voice_change_;

    // Number of 10 ms frames to wait to transition to and from hard
    // restoration.
    constexpr int kHardRestorationOffsetDelay = 3;
    constexpr int kHardRestorationOnsetDelay = 80;

    if ((use_hard_restoration_ &&
         chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
        (!use_hard_restoration_ &&
         chunks_since_voice_change_ > kHardRestorationOnsetDelay)) {
      use_hard_restoration_ = not_voiced;
      chunks_since_voice_change_ = 0;
    }
  }
}

// Shifts every channel's input history left by one chunk and appends the new
// chunk from `data`. `data` stores channels back to back, `data_length_`
// samples each. When detection is enabled, `out_buffer_` is shifted the same
// way and the slot for the newest chunk is zeroed, so Suppress() can
// overlap-add into it. Must be called after UpdateKeypress() has updated
// `detection_enabled_`.
void TransientSuppressorImpl::UpdateBuffers(float* data) {
  // TODO(aluebs): Change to ring buffer.
  // Each buffer holds `num_channels_` back-to-back frames of
  // `analysis_length_` samples. Shifting the whole allocation left by
  // `data_length_` keeps everything except its first chunk. The samples that
  // slide from one channel into the tail of the previous one are overwritten
  // right afterwards.
  const size_t kept_bytes =
      (buffer_delay_ + (num_channels_ - 1) * analysis_length_) * sizeof(float);

  memmove(in_buffer_.get(), in_buffer_.get() + data_length_, kept_bytes);
  for (int ch = 0; ch < num_channels_; ++ch) {
    float* const newest_chunk =
        in_buffer_.get() + buffer_delay_ + ch * analysis_length_;
    memcpy(newest_chunk, data + ch * data_length_,
           data_length_ * sizeof(float));
  }

  if (!detection_enabled_) {
    return;
  }
  memmove(out_buffer_.get(), out_buffer_.get() + data_length_, kept_bytes);
  for (int ch = 0; ch < num_channels_; ++ch) {
    memset(out_buffer_.get() + buffer_delay_ + ch * analysis_length_, 0,
           data_length_ * sizeof(float));
  }
}

// Restores the unvoiced signal if a click is present.
// Attenuates by a certain factor every peak in the `fft_buffer_` that exceeds
// the spectral mean. The attenuation depends on `detector_smoothed_`.
// If a restoration takes place, the `magnitudes_` are updated to the new value.
void TransientSuppressorImpl::HardRestoration(float* spectral_mean) {
  const float detector_result =
      1.f - std::pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
  // To restore, we get the peaks in the spectrum. If higher than the previous
  // spectral mean we adjust them.
  for (size_t i = 0; i < complex_analysis_length_; ++i) {
    if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) {
      // RandU() generates values on [0, int16::max()]
      const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) /
                          std::numeric_limits<int16_t>::max();
      const float scaled_mean = detector_result * spectral_mean[i];

      fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] +
                           scaled_mean * cosf(phase);
      fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] +
                               scaled_mean * sinf(phase);
      magnitudes_[i] = magnitudes_[i] -
                       detector_result * (magnitudes_[i] - spectral_mean[i]);
    }
  }
}

// Restores the voiced signal if a click is present.
// Attenuates by a certain factor every peak in the `fft_buffer_` that exceeds
// the spectral mean and that is lower than some function of the current block
// frequency mean. The attenuation depends on `detector_smoothed_`.
// If a restoration takes place, the `magnitudes_` are updated to the new value.
void TransientSuppressorImpl::SoftRestoration(float* spectral_mean) {
  // Get the spectral magnitude mean of the current block.
  float block_frequency_mean = 0;
  for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {
    block_frequency_mean += magnitudes_[i];
  }
  block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin);

  // To restore, we get the peaks in the spectrum. If higher than the
  // previous spectral mean and lower than a factor of the block mean
  // we adjust them. The factor is a double sigmoid that has a minimum in the
  // voice frequency range (300Hz - 3kHz).
  for (size_t i = 0; i < complex_analysis_length_; ++i) {
    if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 &&
        (using_reference_ ||
         magnitudes_[i] < block_frequency_mean * mean_factor_[i])) {
      const float new_magnitude =
          magnitudes_[i] -
          detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]);
      const float magnitude_ratio = new_magnitude / magnitudes_[i];

      fft_buffer_[i * 2] *= magnitude_ratio;
      fft_buffer_[i * 2 + 1] *= magnitude_ratio;
      magnitudes_[i] = new_magnitude;
    }
  }
}

}  // namespace webrtc
