/*
 * Copyright 2014 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc;

import static com.google.common.truth.Truth.assertThat;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import io.grpc.ClientInterceptors.CheckedForwardingClientCall;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.testing.TestMethodDescriptors;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

/**
 * Unit tests for {@link ClientInterceptors}.
 *
 * <p>Most tests put one or more interceptors in front of a mocked {@link Channel} whose
 * {@code newCall} always returns the shared fake {@link BaseClientCall} {@code call}. The tests
 * then inspect that fake call to see what the interceptor chain forwarded to it.
 */
@RunWith(JUnit4.class)
public class ClientInterceptorsTest {

  @Rule
  public final MockitoRule mocks = MockitoJUnit.rule();

  /** Mock channel at the bottom of the interceptor chain; stubbed in {@link #setUp()}. */
  @Mock
  private Channel channel;

  /** Fake call handed out by {@link #channel} (and by the recording channel) for every call. */
  private final BaseClientCall call = new BaseClientCall();

  private final MethodDescriptor<Void, Void> method = TestMethodDescriptors.voidMethod();

  /**
   * Stubs {@link #channel} so that every {@code newCall}, whatever the method descriptor or
   * call options, returns {@link #call}.
   */
  @Before public void setUp() {
    when(channel.newCall(
            ArgumentMatchers.<MethodDescriptor<String, Integer>>any(), any(CallOptions.class)))
        .thenReturn(call);
  }

  @Test(expected = NullPointerException.class)
  public void npeForNullChannel() {
    ClientInterceptors.intercept(null, Arrays.<ClientInterceptor>asList());
  }

  @Test(expected = NullPointerException.class)
  public void npeForNullInterceptorList() {
    ClientInterceptors.intercept(channel, (List<ClientInterceptor>) null);
  }

  @Test(expected = NullPointerException.class)
  public void npeForNullInterceptor() {
    ClientInterceptors.intercept(channel, (ClientInterceptor) null);
  }

  @Test
  public void noop() {
    // With no interceptors, the original channel is returned unwrapped.
    assertSame(channel, ClientInterceptors.intercept(channel, Arrays.<ClientInterceptor>asList()));
  }

  @Test
  public void channelAndInterceptorCalled() {
    ClientInterceptor interceptor =
        mock(ClientInterceptor.class, delegatesTo(new NoopInterceptor()));
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    CallOptions callOptions = CallOptions.DEFAULT;
    // First call
    assertSame(call, intercepted.newCall(method, callOptions));
    verify(channel).newCall(same(method), same(callOptions));
    verify(interceptor)
        .interceptCall(same(method), same(callOptions), ArgumentMatchers.<Channel>any());
    verifyNoMoreInteractions(channel, interceptor);
    // Second call: the interceptor runs again for every new call.
    assertSame(call, intercepted.newCall(method, callOptions));
    verify(channel, times(2)).newCall(same(method), same(callOptions));
    verify(interceptor, times(2))
        .interceptCall(same(method), same(callOptions), ArgumentMatchers.<Channel>any());
    verifyNoMoreInteractions(channel, interceptor);
  }

  @Test
  public void callNextTwice() {
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        // Calling next twice is permitted, although should only rarely be useful.
        assertSame(call, next.newCall(method, callOptions));
        return next.newCall(method, callOptions);
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
    verify(channel, times(2)).newCall(same(method), same(CallOptions.DEFAULT));
    verifyNoMoreInteractions(channel);
  }

  @Test
  public void ordered() {
    final List<String> order = new ArrayList<>();
    Channel intercepted = ClientInterceptors.intercept(
        recordingChannel(order),
        recordingInterceptor(order, "i1"),
        recordingInterceptor(order, "i2"));
    assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
    // intercept() runs the interceptors from last to first before reaching the channel.
    assertEquals(Arrays.asList("i2", "i1", "channel"), order);
  }

  @Test
  public void orderedForward() {
    final List<String> order = new ArrayList<>();
    Channel intercepted = ClientInterceptors.interceptForward(
        recordingChannel(order),
        recordingInterceptor(order, "i1"),
        recordingInterceptor(order, "i2"));
    assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
    // interceptForward() runs the interceptors in the order given before reaching the channel.
    assertEquals(Arrays.asList("i1", "i2", "channel"), order);
  }

  @Test
  public void callOptions() {
    final CallOptions initialCallOptions = CallOptions.DEFAULT.withDeadlineAfter(100, NANOSECONDS);
    final CallOptions newCallOptions = initialCallOptions.withDeadlineAfter(300, NANOSECONDS);
    assertNotSame(initialCallOptions, newCallOptions);
    // The interceptor replaces the options; the channel must see the replacement.
    ClientInterceptor interceptor =
        mock(ClientInterceptor.class, delegatesTo(new ClientInterceptor() {
          @Override
          public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
              MethodDescriptor<ReqT, RespT> method,
              CallOptions callOptions,
              Channel next) {
            return next.newCall(method, newCallOptions);
          }
        }));
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    intercepted.newCall(method, initialCallOptions);
    verify(interceptor)
        .interceptCall(same(method), same(initialCallOptions), ArgumentMatchers.<Channel>any());
    verify(channel).newCall(same(method), same(newCallOptions));
  }

  @Test
  public void addOutboundHeaders() {
    final Metadata.Key<String> credKey = Metadata.Key.of("Cred", Metadata.ASCII_STRING_MARSHALLER);
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        ClientCall<ReqT, RespT> nextCall = next.newCall(method, callOptions);
        return new SimpleForwardingClientCall<ReqT, RespT>(nextCall) {
          @Override
          public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
            headers.put(credKey, "abcd");
            super.start(responseListener, headers);
          }
        };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    ClientCall.Listener<Void> listener = mockListener();
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    // start() on the intercepted call will eventually reach the call created by the real channel.
    interceptedCall.start(listener, new Metadata());
    // The headers passed to the real channel call will contain the information inserted by the
    // interceptor.
    assertSame(listener, call.listener);
    assertEquals("abcd", call.headers.get(credKey));
  }

  @Test
  public void examineInboundHeaders() {
    final List<Metadata> examinedHeaders = new ArrayList<>();
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        ClientCall<ReqT, RespT> nextCall = next.newCall(method, callOptions);
        return new SimpleForwardingClientCall<ReqT, RespT>(nextCall) {
          @Override
          public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
            // Wrap the caller's listener so inbound headers pass through the interceptor.
            super.start(new SimpleForwardingClientCallListener<RespT>(responseListener) {
              @Override
              public void onHeaders(Metadata headers) {
                examinedHeaders.add(headers);
                super.onHeaders(headers);
              }
            }, headers);
          }
        };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    ClientCall.Listener<Void> listener = mockListener();
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    interceptedCall.start(listener, new Metadata());

    // The fake call recorded the listener it was started with, i.e. the interceptor's wrapper.
    // Simulate headers arriving from the transport by invoking that listener directly.
    Metadata inboundHeaders = new Metadata();
    call.listener.onHeaders(inboundHeaders);
    assertThat(examinedHeaders).containsExactly(inboundHeaders);
  }

  @Test
  public void normalCall() {
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        ClientCall<ReqT, RespT> nextCall = next.newCall(method, callOptions);
        return new SimpleForwardingClientCall<ReqT, RespT>(nextCall) { };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    assertNotSame(call, interceptedCall);
    ClientCall.Listener<Void> listener = mockListener();
    Metadata headers = new Metadata();
    // Every ClientCall method on the wrapper must be forwarded unchanged to the fake call.
    interceptedCall.start(listener, headers);
    assertSame(listener, call.listener);
    assertSame(headers, call.headers);
    interceptedCall.sendMessage(null /*request*/);
    assertThat(call.messages).containsExactly((String) null);
    interceptedCall.halfClose();
    assertTrue(call.halfClosed);
    interceptedCall.request(1);
    assertThat(call.requests).containsExactly(1);
  }

  @Test
  public void exceptionInStart() {
    final Exception error = new Exception("emulated error");
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        ClientCall<ReqT, RespT> nextCall = next.newCall(method, callOptions);
        return new CheckedForwardingClientCall<ReqT, RespT>(nextCall) {
          @Override
          protected void checkedStart(ClientCall.Listener<RespT> responseListener, Metadata headers)
              throws Exception {
            // delegate().start will not be called.
            throw error;
          }
        };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    ClientCall.Listener<Void> listener = mockListener();
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    assertNotSame(call, interceptedCall);
    interceptedCall.start(listener, new Metadata());
    // None of these may throw even though start() failed.
    interceptedCall.sendMessage(null /*request*/);
    interceptedCall.halfClose();
    interceptedCall.request(1);
    // From here on any method reaching the fake call throws IllegalStateException, so the calls
    // on the delegate below prove it is no longer the fake call.
    call.done = true;
    ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
    verify(listener).onClose(captor.capture(), any(Metadata.class));
    assertSame(error, captor.getValue().getCause());

    // Make sure nothing bad happens after the exception.
    ClientCall<?, ?> noop = ((CheckedForwardingClientCall<?, ?>) interceptedCall).delegate();
    // Should not throw, even on bad input
    noop.cancel("Cancel for test", null);
    noop.start(null, null);
    noop.request(-1);
    noop.halfClose();
    noop.sendMessage(null);
    assertFalse(noop.isReady());
  }

  @Test
  public void authorityIsDelegated() {
    ClientInterceptor interceptor = new NoopInterceptor();

    when(channel.authority()).thenReturn("auth");
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    assertEquals("auth", intercepted.authority());
  }

  @Test
  public void customOptionAccessible() {
    CallOptions.Key<String> customOption = CallOptions.Key.create("custom");
    CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value");
    ArgumentCaptor<CallOptions> passedOptions = ArgumentCaptor.forClass(CallOptions.class);
    ClientInterceptor interceptor =
        mock(ClientInterceptor.class, delegatesTo(new NoopInterceptor()));

    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);

    assertSame(call, intercepted.newCall(method, callOptions));
    verify(channel).newCall(same(method), same(callOptions));

    verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class));
    // Compare by value: identity of String literals is an interning detail, not the contract.
    assertEquals("value", passedOptions.getValue().getOption(customOption));
  }

  /** Creates a mock listener; the unchecked conversion is confined to this helper. */
  @SuppressWarnings("unchecked") // mock() cannot carry the Listener's type argument.
  private static ClientCall.Listener<Void> mockListener() {
    return mock(ClientCall.Listener.class);
  }

  /**
   * Returns a channel that appends {@code "channel"} to {@code order} on every {@code newCall}
   * and returns {@link #call}. Its authority is {@code null}.
   */
  private Channel recordingChannel(final List<String> order) {
    return new Channel() {
      @SuppressWarnings("unchecked")
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) {
        order.add("channel");
        return (ClientCall<ReqT, RespT>) call;
      }

      @Override
      public String authority() {
        return null;
      }
    };
  }

  /**
   * Returns a pass-through interceptor that appends {@code name} to {@code order} each time it
   * intercepts a call.
   */
  private static ClientInterceptor recordingInterceptor(
      final List<String> order, final String name) {
    return new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        order.add(name);
        return next.newCall(method, callOptions);
      }
    };
  }

  /** Interceptor that forwards every call to {@code next} with unchanged arguments. */
  private static class NoopInterceptor implements ClientInterceptor {
    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
        CallOptions callOptions, Channel next) {
      return next.newCall(method, callOptions);
    }
  }

  /**
   * Fake {@link ClientCall} that records what it receives.
   *
   * <p>Every method throws {@link IllegalStateException} once {@code done} is set.
   * {@code request}, {@code halfClose} and {@code sendMessage} also throw if {@code start} has
   * not been called yet.
   */
  private static class BaseClientCall extends ClientCall<String, Integer> {
    private boolean started;
    /** Set by a test to forbid any further method calls on this fake. */
    private boolean done;
    private ClientCall.Listener<Integer> listener;
    private Metadata headers;
    private final List<Integer> requests = new ArrayList<>();
    private final List<String> messages = new ArrayList<>();
    private boolean halfClosed;

    @Override
    public void start(ClientCall.Listener<Integer> listener, Metadata headers) {
      checkNotDone();
      started = true;
      this.listener = listener;
      this.headers = headers;
    }

    @Override
    public void request(int numMessages) {
      checkNotDone();
      checkStarted();
      requests.add(numMessages);
    }

    @Override
    public void cancel(String message, Throwable cause) {
      checkNotDone();
    }

    @Override
    public void halfClose() {
      checkNotDone();
      checkStarted();
      this.halfClosed = true;
    }

    @Override
    public void sendMessage(String message) {
      checkNotDone();
      checkStarted();
      messages.add(message);
    }

    private void checkNotDone() {
      if (done) {
        throw new IllegalStateException("no more methods should be called");
      }
    }

    private void checkStarted() {
      if (!started) {
        throw new IllegalStateException("should have called start");
      }
    }
  }
}
