/* pit - create, open and close LUKS1-encrypted ext4 container files.
 * See LICENSE file for copyright and license details. */
#define _POSIX_C_SOURCE 200809L
#define _DEFAULT_SOURCE
#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <dirent.h>
#include <fcntl.h>
#include <pwd.h>
#include <sys/mman.h>
#include <sys/mount.h>
#include <sys/prctl.h>
#include <sys/resource.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <termios.h>
#include <unistd.h>

#include <libcryptsetup.h>
#include <sodium.h>

/* Size of the buffer written repeatedly when filling a new pit file. */
#define PIT_BLOCK_SIZE 4096
/* Length in bytes of the random master key and of the password-derived
 * wrapping key that encrypts it. */
#define KEY_SIZE 32
/* Length in bytes of the crypto_pwhash salt stored at the start of a key
 * file. */
#define SALT_SIZE 32
/* NOTE(review): ITER_COUNT is not referenced by any code in this file. */
#define ITER_COUNT 500000

#define VERSION "0.1"
/* Device-mapper names are MAPPER_PREFIX followed by the pit file's
 * basename; mount points are MOUNTPOINT_PREFIX followed by the same
 * basename. */
#define MAPPER_PREFIX "pit-"
#define MOUNTPOINT_PREFIX "/mnt/pit-"
/* Filesystem type passed to mount(2). */
#define FS_TYPE "ext4"

/* Capacity of the password buffers, including the terminating NUL. */
#define PASSWORD_MAX 1024
#define WRAP_NONCE_SIZE crypto_secretbox_NONCEBYTES
#define WRAP_MAC_SIZE crypto_secretbox_MACBYTES
/* Key file layout: salt | nonce | secretbox(master key), where the
 * secretbox payload carries its MAC in addition to the KEY_SIZE key. */
#define KEYFILE_PAYLOAD_SIZE (KEY_SIZE + WRAP_MAC_SIZE)
#define KEYFILE_SIZE (SALT_SIZE + WRAP_NONCE_SIZE + KEYFILE_PAYLOAD_SIZE)

/* LUKS1 parameters used when formatting a new pit (HASH is the LUKS header
 * hash). NOTE(review): the volume key is KEY_SIZE (32) bytes; with XTS that
 * key is split in two halves, i.e. AES-128-XTS - confirm this is intended. */
#define CIPHER "aes"
#define CIPHER_MODE "xts-plain64"
#define HASH "sha256"

/* One crypto_pwhash (Argon2) cost setting. */
struct pwhash_level
{
  unsigned long long opslimit;
  size_t memlimit;
};

/* Cost levels tried in order, strongest first. generate_key() uses the first
 * one that succeeds (e.g. when memory is short); read_key_file() tries each
 * until the key decrypts, because the level used is not stored in the key
 * file. */
static const struct pwhash_level pwhash_levels[] = {
  {crypto_pwhash_OPSLIMIT_SENSITIVE, crypto_pwhash_MEMLIMIT_SENSITIVE},
  {crypto_pwhash_OPSLIMIT_MODERATE, crypto_pwhash_MEMLIMIT_MODERATE},
  {crypto_pwhash_OPSLIMIT_MIN, crypto_pwhash_MEMLIMIT_MIN}
};

/* NOTE(review): struct Pit is not referenced by any code in this file. */
typedef struct Pit
{
  char *path;
  size_t size;
  char *key;
  int ismounted;
} Pit;

/* Number of entries in pwhash_levels. */
static const size_t npwhash_levels =
  sizeof (pwhash_levels) / sizeof (pwhash_levels[0]);

/* Forward declarations of all file-local functions, so the definitions
 * below may appear in any order. */
static void die (const char *fmt, ...);
static void usage (void);
static int init_sec_mem (void);
static void *xrealloc (void *p, size_t len);
static void *secure_alloc (size_t size);
static void secure_free (void *ptr, size_t size);
static void term_cleanup (int signo);
static int read_password (char *password, size_t size, const char *prompt);
static int exec_cmd (const char *fmt, ...);
static const char *get_username (void);
static int check_priv_esc (const char *tool);
static int run_privileged (const char *fmt, ...);
static int generate_key (const char *keyfile);
static int read_key_file (const char *path, char **key);
static int create_pit (const char *path, size_t size);
static char *get_mapper_path (const char *path);
static int cleanup_stale_device (const char *name);
static int setup_device_mapper (const char *path, const char *key);
static int teardown_device_mapper (const char *path);
static int check_filesystem (const char *device);
static int create_filesystem (const char *device);
static int ensure_mount_dir (void);
static int create_mount_point (const char *path);
static int mount_filesystem (const char *device, const char *mountpoint);
static int unmount_filesystem (const char *mountpoint);
static int open_pit (const char *path, const char *keyfile);
static int close_pit (const char *path);
static int list_pits (void);
static int find_mounted_pits (char ***paths, int *count);
static int panic_close (void);

/* argv[0]; used to re-run pit under doas/sudo. */
static const char *program_name;
/* Nonzero while read_password() has terminal echo disabled. It is also read
 * and cleared by the SIGINT/SIGTERM handler, so it must be a
 * volatile sig_atomic_t to be accessed safely from signal context. */
static volatile sig_atomic_t term_modified = 0;
/* Terminal attributes to restore after a password prompt. */
static struct termios saved_term;
/* Cached page size; 0 until secure_alloc() first queries it. */
static long pagesize;

/* Write a printf-style diagnostic to stderr and terminate the process with
 * exit status 1. Does not return. */
static void
die (const char *fmt, ...)
{
  va_list args;

  va_start (args, fmt);
  (void) vfprintf (stderr, fmt, args);
  va_end (args);

  exit (EXIT_FAILURE);
}

/* Lock all current and future pages of the process into RAM, so that key
 * material is not written to swap.
 *
 * When running as root and the soft RLIMIT_MEMLOCK is below 8 MiB, the soft
 * limit is raised first (failure only prints a warning). Dies if mlockall()
 * fails; otherwise returns 0. */
static int
init_sec_mem (void)
{
  struct rlimit rlim;
  const rlim_t required_mem = 8 * 1024 * 1024;

  if (getrlimit (RLIMIT_MEMLOCK, &rlim) == 0
      && rlim.rlim_cur < required_mem && geteuid () == 0) {
    rlim.rlim_cur = required_mem;
    /* Only ever raise the hard limit. The old code set it to 8 MiB
     * unconditionally, lowering a larger or unlimited hard limit for the
     * rest of the process. */
    if (rlim.rlim_max != RLIM_INFINITY && rlim.rlim_max < required_mem)
      rlim.rlim_max = required_mem;
    if (setrlimit (RLIMIT_MEMLOCK, &rlim) < 0)
      fprintf (stderr, "pit: warning: couldn't increase memory lock limit\n");
  }

  if (mlockall (MCL_CURRENT | MCL_FUTURE) < 0)
    die ("pit: couldn't lock memory pages: %s\n", strerror (errno));

  return 0;
}

/* realloc() that never fails from the caller's point of view: if the
 * allocator returns NULL the process dies with the errno text. Returns the
 * (possibly moved) block. */
static void *
xrealloc (void *p, size_t len)
{
  void *resized = realloc (p, len);

  if (resized == NULL)
    die ("pit: realloc: %s\n", strerror (errno));
  return resized;
}

/* Allocate SIZE bytes of page-aligned memory for secrets. The pages are
 * locked into RAM (mlock) and excluded from core dumps (MADV_DONTDUMP).
 *
 * Never returns NULL: every failure is fatal. Release the block with
 * secure_free(), passing the same SIZE. */
static void *
secure_alloc (size_t size)
{
  void *ptr = NULL;
  int rc;

  if (pagesize <= 0) {
    errno = 0;
    pagesize = sysconf (_SC_PAGESIZE);
    if (pagesize <= 0)
      die ("pit: could not get system page size: %s\n",
	   errno != 0 ? strerror (errno) : "unknown error");
  }

  /* posix_memalign() returns its error code; it does not set errno. */
  rc = posix_memalign (&ptr, (size_t) pagesize, size);
  if (rc != 0)
    die ("pit: posix_memalign failed: %s\n", strerror (rc));

  if (mlock (ptr, size) < 0) {
    int err = errno;		/* free() may clobber errno */

    free (ptr);
    die ("pit: mlock failed: %s\n", strerror (err));
  }

  if (madvise (ptr, size, MADV_DONTDUMP) < 0) {
    int err = errno;

    munlock (ptr, size);
    free (ptr);
    die ("pit: madvise failed: %s\n", strerror (err));
  }

  return ptr;
}

/* Release a block obtained from secure_alloc(). It is first wiped with
 * sodium_memzero() (which the compiler may not optimise away), then
 * unlocked and freed. SIZE must equal the size given to secure_alloc().
 * A NULL pointer is ignored. */
static void
secure_free (void *ptr, size_t size)
{
  if (ptr == NULL)
    return;

  sodium_memzero (ptr, size);
  munlock (ptr, size);
  free (ptr);
}

/* Print the command summary to stderr and exit with status 1.
 *
 * The "dig" example previously omitted the mandatory size argument, so
 * following it verbatim produced a usage error. */
static void
usage (void)
{
  die ("usage: pit [-v] [-h] command [arguments]\n"
       "Commands:\n"
       "  dig FILE SIZE               - create new pit file of SIZE MB (random-filled)\n"
       "  key KEY.key                 - generate new encrypted key file\n"
       "  open FILE KEY.key           - open an existing pit\n"
       "  close PATH                  - close an opened pit\n"
       "  list                        - list opened pits\n"
       "  panic                       - emergency close all pits (forced)\n"
       "  example: pit dig container.pit 10\n"
       "           pit key container.key\n"
       "           pit open container.pit container.key\n"
       "           pit close container.pit\n");
}

/* Restore the terminal attributes saved by read_password() if echo is
 * still turned off, and move to a fresh line because the input was not
 * echoed.
 *
 * Called with SIGNO == 0 after a password has been read. It is also
 * installed by read_password() as the SIGINT/SIGTERM handler; for a real
 * signal it then terminates the process with status 1.
 *
 * Because it runs in signal context, it uses only async-signal-safe calls
 * (tcsetattr, write, _exit). stdio and exit() are not safe there. */
static void
term_cleanup (int signo)
{
  if (term_modified) {
    ssize_t ignored;

    tcsetattr (STDIN_FILENO, TCSAFLUSH, &saved_term);
    ignored = write (STDERR_FILENO, "\n", 1);
    (void) ignored;		/* nothing useful to do if this fails */
    term_modified = 0;
  }
  if (signo != 0)
    _exit (1);
}

/* Print PROMPT on stderr and read one line from stdin into PASSWORD
 * (capacity SIZE bytes, including the terminating NUL) with terminal echo
 * turned off. The trailing newline is removed.
 *
 * SIGINT and SIGTERM are routed to term_cleanup(), so an interrupted
 * prompt restores the terminal before the process exits.
 *
 * Returns the password length, or -1 in any of these cases:
 *   - SIZE is unusable;
 *   - stdin is not a terminal or cannot be reconfigured;
 *   - nothing could be read;
 *   - the line is longer than SIZE - 1 characters. The excess is consumed
 *     and discarded, so it cannot become the answer to a later prompt. */
static int
read_password (char *password, size_t size, const char *prompt)
{
  struct termios new;
  struct sigaction sa;
  size_t len;
  int too_long = 0;

  /* fgets() takes an int size and needs room for at least one character. */
  if (size < 2 || size > INT_MAX)
    return -1;

  sa.sa_handler = term_cleanup;
  sigemptyset (&sa.sa_mask);
  sa.sa_flags = 0;
  sigaction (SIGINT, &sa, NULL);
  sigaction (SIGTERM, &sa, NULL);

  if (tcgetattr (STDIN_FILENO, &saved_term) != 0)
    return -1;

  new = saved_term;
  new.c_lflag &= ~ECHO;

  /* Mark the terminal as modified before changing it, so that a signal
   * arriving right after tcsetattr() still restores echo. */
  term_modified = 1;
  if (tcsetattr (STDIN_FILENO, TCSAFLUSH, &new) != 0) {
    term_modified = 0;
    return -1;
  }

  fprintf (stderr, "%s", prompt);
  fflush (stderr);

  if (!fgets (password, (int) size, stdin)) {
    term_cleanup (0);
    return -1;
  }

  len = strlen (password);
  if (len > 0 && password[len - 1] == '\n') {
    password[--len] = '\0';
  }
  else if (len == size - 1) {
    /* The buffer filled up before the end of the line. Consume the rest
     * now, while echo is still off and before TCSAFLUSH discards only part
     * of it, so that none of it is left for the next prompt. */
    int c = getchar ();

    while (c != '\n' && c != EOF) {
      too_long = 1;
      c = getchar ();
    }
  }

  term_cleanup (0);

  if (too_long) {
    fprintf (stderr, "pit: password too long (at most %zu characters)\n",
	     size - 1);
    return -1;
  }
  return (int) len;
}

/* Format a shell command line printf-style and run it with system().
 *
 * Returns the command's exit status (0-255) if it exited normally. Returns
 * -1 in any of these cases:
 *   - the formatted command does not fit into the 4096-byte buffer (it is
 *     then not run at all, rather than run truncated);
 *   - the shell cannot be started;
 *   - the command was terminated by a signal.
 *
 * The text is handed to /bin/sh verbatim, so callers must quote any
 * arguments that may contain spaces or shell metacharacters. */
static int
exec_cmd (const char *fmt, ...)
{
  char cmd[4096];
  va_list ap;
  int n;
  int status;

  va_start (ap, fmt);
  n = vsnprintf (cmd, sizeof (cmd), fmt, ap);
  va_end (ap);

  if (n < 0 || (size_t) n >= sizeof (cmd)) {
    fprintf (stderr, "pit: command line too long\n");
    return -1;
  }

  status = system (cmd);
  if (status == -1) {
    fprintf (stderr, "pit: cannot run command: %s\n", strerror (errno));
    return -1;
  }

  /* WEXITSTATUS() is meaningless unless the child exited normally. A
   * command killed by a signal used to be reported as success (0). */
  if (!WIFEXITED (status))
    return -1;

  return WEXITSTATUS (status);
}

/* Look up the login name of the real user ID. Returns NULL if the user has
 * no passwd entry. The string lives in getpwuid()'s static storage and is
 * overwritten by the next passwd lookup. */
static const char *
get_username (void)
{
  const struct passwd *entry = getpwuid (getuid ());

  if (entry == NULL)
    return NULL;
  return entry->pw_name;
}

/* Report whether the privilege-escalation helper TOOL is installed as an
 * executable file in /usr/bin. No other directory is searched.
 * Returns 1 if it is, 0 otherwise. */
static int
check_priv_esc (const char *tool)
{
  char candidate[PATH_MAX];
  int usable;

  snprintf (candidate, sizeof candidate, "/usr/bin/%s", tool);
  usable = (access (candidate, X_OK) == 0);
  return usable;
}

/* Run a printf-style shell command with root privileges.
 *
 * As root the command runs directly. Otherwise it is prefixed with doas or
 * sudo, whichever check_priv_esc() finds; doas is preferred. For sudo, the
 * password prompt names the invoking user.
 *
 * Returns exec_cmd()'s result, or -1 in any of these cases:
 *   - no escalation tool is found;
 *   - sudo is needed but the user name is unknown;
 *   - the formatted command does not fit into 4096 bytes.
 *
 * Arguments are passed to the shell unquoted; callers must quote them. */
static int
run_privileged (const char *fmt, ...)
{
  char cmd[4096];
  va_list ap;
  int n;
  const char *username;

  /* Format once, up front, and never run a silently truncated command. */
  va_start (ap, fmt);
  n = vsnprintf (cmd, sizeof (cmd), fmt, ap);
  va_end (ap);
  if (n < 0 || (size_t) n >= sizeof (cmd)) {
    fprintf (stderr, "pit: command line too long\n");
    return -1;
  }

  if (geteuid () == 0)
    return exec_cmd ("%s", cmd);

  if (check_priv_esc ("doas"))
    return exec_cmd ("doas %s", cmd);

  if (!check_priv_esc ("sudo")) {
    fprintf (stderr, "pit: no privilege escalation tool found\n");
    return -1;
  }

  /* The user name is only needed for sudo's prompt. */
  username = get_username ();
  if (!username) {
    fprintf (stderr, "pit: cannot get username\n");
    return -1;
  }

  return exec_cmd ("sudo -p '[sudo] Enter password for user %s: ' %s",
		   username, cmd);
}

/* Generate a random master key, wrap it with a key derived from a password
 * and write the result to KEYFILE (mode 0600). The file layout is
 * salt | nonce | secretbox(master key).
 *
 * The password is asked for twice and both entries must match. Argon2 cost
 * levels from pwhash_levels are tried strongest first until one succeeds.
 *
 * An existing KEYFILE is never overwritten: losing it would make every pit
 * formatted with that key unrecoverable. If writing fails, the partially
 * written file is removed again. Returns 0 on success, -1 on failure. */
static int
generate_key (const char *keyfile)
{
  unsigned char *master = secure_alloc (KEY_SIZE);
  unsigned char *salt = secure_alloc (SALT_SIZE);
  unsigned char *nonce = secure_alloc (WRAP_NONCE_SIZE);
  unsigned char *wrapping_key = secure_alloc (KEY_SIZE);
  unsigned char *filebuf = secure_alloc (KEYFILE_SIZE);
  char *password = secure_alloc (PASSWORD_MAX);
  char *verify = secure_alloc (PASSWORD_MAX);
  int pwlen;
  int fd = -1;
  int created = 0;
  int ret = -1;
  size_t i;
  size_t total;
  ssize_t written;
  int derived = -1;

  randombytes_buf (master, KEY_SIZE);
  randombytes_buf (salt, SALT_SIZE);
  randombytes_buf (nonce, WRAP_NONCE_SIZE);

  pwlen =
    read_password (password, PASSWORD_MAX,
		   "Enter password for key encryption (no echo): ");
  if (pwlen <= 0) {
    fprintf (stderr, "pit: failed to read password\n");
    goto cleanup;
  }

  if (read_password (verify, PASSWORD_MAX, "Verify password: ") <= 0) {
    fprintf (stderr, "pit: failed to read password verification\n");
    goto cleanup;
  }

  if (strcmp (password, verify) != 0) {
    fprintf (stderr, "pit: passwords do not match\n");
    goto cleanup;
  }

  /* Weaker levels are only used when a stronger one fails (typically for
   * lack of memory). */
  for (i = 0; i < npwhash_levels; i++) {
    if (crypto_pwhash (wrapping_key, KEY_SIZE,
		       password, pwlen,
		       salt,
		       pwhash_levels[i].opslimit,
		       pwhash_levels[i].memlimit,
		       crypto_pwhash_ALG_DEFAULT) == 0) {
      derived = 0;
      break;
    }
  }

  if (derived != 0) {
    fprintf (stderr, "pit: key derivation failed - insufficient memory\n");
    goto cleanup;
  }

  memcpy (filebuf, salt, SALT_SIZE);
  memcpy (filebuf + SALT_SIZE, nonce, WRAP_NONCE_SIZE);

  if (crypto_secretbox_easy (filebuf + SALT_SIZE + WRAP_NONCE_SIZE,
			     master, KEY_SIZE, nonce, wrapping_key) != 0) {
    fprintf (stderr, "pit: failed to encrypt key material\n");
    goto cleanup;
  }

  /* O_EXCL: refuse to clobber an existing key file (the old O_TRUNC did so
   * silently) and refuse to follow a symlink planted at KEYFILE. */
  fd = open (keyfile, O_WRONLY | O_CREAT | O_EXCL, 0600);
  if (fd < 0) {
    fprintf (stderr, "pit: cannot create key file: %s\n", strerror (errno));
    goto cleanup;
  }
  created = 1;

  total = 0;
  while (total < KEYFILE_SIZE) {
    written = write (fd, filebuf + total, KEYFILE_SIZE - total);
    if (written < 0) {
      if (errno == EINTR)
	continue;
      fprintf (stderr, "pit: failed to write key file: %s\n",
	       strerror (errno));
      goto cleanup;
    }
    if (written == 0) {
      fprintf (stderr, "pit: failed to write key file: short write\n");
      goto cleanup;
    }
    total += written;
  }

  /* Only report success once the key is durably on disk. */
  if (fsync (fd) < 0) {
    fprintf (stderr, "pit: failed to sync key file: %s\n", strerror (errno));
    goto cleanup;
  }
  if (close (fd) < 0) {
    fd = -1;
    fprintf (stderr, "pit: failed to close key file: %s\n", strerror (errno));
    goto cleanup;
  }
  fd = -1;

  printf ("pit: key generated successfully\n");
  ret = 0;

cleanup:
  if (fd >= 0)
    close (fd);
  /* Never leave a truncated or unsynced key file behind. */
  if (ret != 0 && created)
    unlink (keyfile);
  secure_free (master, KEY_SIZE);
  secure_free (salt, SALT_SIZE);
  secure_free (nonce, WRAP_NONCE_SIZE);
  secure_free (filebuf, KEYFILE_SIZE);
  secure_free (wrapping_key, KEY_SIZE);
  secure_free (password, PASSWORD_MAX);
  secure_free (verify, PASSWORD_MAX);
  return ret;
}

/* Create a new pit file at PATH, SIZE MiB long, filled with random bytes so
 * that unused regions of the container show no pattern. The file is
 * created with mode 0600 and never replaces an existing file.
 *
 * Dies on any error, removing the partially written file first. Returns 0
 * on success. */
static int
create_pit (const char *path, size_t size)
{
  const size_t mib = (size_t) 1024 * 1024;
  unsigned char *buf;
  size_t remain;
  int fd;

  if (size == 0)
    die ("pit: pit size must be at least 1 MB\n");
  /* Reject sizes whose byte count overflows size_t. This also catches
   * negative numbers that a caller converted to size_t. */
  if (size > (size_t) -1 / mib)
    die ("pit: pit size %zu MB is too large\n", size);

  /* O_EXCL makes the existence check and the creation a single atomic
   * step (the old access()+open() pair was racy). */
  fd = open (path, O_WRONLY | O_CREAT | O_EXCL, 0600);
  if (fd < 0) {
    if (errno == EEXIST)
      die ("pit: %s already exists\n", path);
    die ("pit: cannot create %s: %s\n", path, strerror (errno));
  }

  buf = secure_alloc (PIT_BLOCK_SIZE);

  for (remain = size * mib; remain > 0;) {
    size_t chunk = remain < PIT_BLOCK_SIZE ? remain : PIT_BLOCK_SIZE;
    size_t done = 0;

    /* Fresh random bytes for every block: reusing one block many times
     * leaves a repeating pattern in never-written container regions. */
    randombytes_buf (buf, chunk);

    /* Handle short writes and EINTR. Previously a short write made
     * `remain` underflow and the loop never ended. */
    while (done < chunk) {
      ssize_t n = write (fd, buf + done, chunk - done);

      if (n < 0 && errno == EINTR)
	continue;
      if (n <= 0) {
	const char *why = n < 0 ? strerror (errno) : "short write";

	close (fd);
	unlink (path);
	secure_free (buf, PIT_BLOCK_SIZE);
	die ("pit: write error: %s\n", why);
      }
      done += (size_t) n;
    }
    remain -= chunk;
  }

  secure_free (buf, PIT_BLOCK_SIZE);

  if (fsync (fd) < 0) {
    int err = errno;

    close (fd);
    unlink (path);
    die ("pit: cannot sync %s: %s\n", path, strerror (err));
  }
  if (close (fd) < 0) {
    int err = errno;

    unlink (path);
    die ("pit: cannot close %s: %s\n", path, strerror (err));
  }
  return 0;
}

/* Read the key file at PATH, ask for its password and unwrap the master
 * key.
 *
 * On success returns 0 and stores in *KEY a KEY_SIZE-byte buffer from
 * secure_alloc(). The caller must release it with
 * secure_free (*key, KEY_SIZE).
 *
 * Returns -1 if the password cannot be read or does not decrypt the key
 * under any of the pwhash_levels. Dies if the file cannot be opened or is
 * not a regular file of exactly KEYFILE_SIZE bytes. */
static int
read_key_file (const char *path, char **key)
{
  int fd;
  struct stat st;
  unsigned char *filebuf = secure_alloc (KEYFILE_SIZE);
  unsigned char *decrypted = secure_alloc (KEY_SIZE);
  unsigned char *wrapping_key = secure_alloc (KEY_SIZE);
  char *password = secure_alloc (PASSWORD_MAX);
  int pwlen;
  int ret = -1;
  size_t i;
  size_t total;
  int success = 0;
  ssize_t nread;

  fd = open (path, O_RDONLY);
  if (fd < 0)
    die ("pit: cannot open key file %s: %s\n", path, strerror (errno));

  /* Check the descriptor actually being read, not the path, so the file
   * cannot be swapped between the check and the read. */
  if (fstat (fd, &st) < 0)
    die ("pit: cannot stat key file %s: %s\n", path, strerror (errno));
  if (!S_ISREG (st.st_mode))
    die ("pit: key file %s is not a regular file\n", path);
  if (st.st_size != (off_t) KEYFILE_SIZE)
    die ("pit: invalid key file size\n");

  total = 0;
  while (total < KEYFILE_SIZE) {
    nread = read (fd, filebuf + total, KEYFILE_SIZE - total);
    if (nread < 0 && errno == EINTR)
      continue;
    if (nread <= 0)
      die ("pit: invalid key file size\n");
    total += (size_t) nread;
  }
  close (fd);
  fd = -1;

  fflush (stdout);
  pwlen =
    read_password (password, PASSWORD_MAX,
		   "Enter password for key (no echo): ");
  if (pwlen <= 0) {
    fprintf (stderr, "pit: failed to read password\n");
    goto cleanup;
  }

  /* The cost level used at generation time is not stored, so try each one;
   * a wrong level or a wrong password both show up as a MAC failure. */
  for (i = 0; i < npwhash_levels; i++) {
    if (crypto_pwhash (wrapping_key, KEY_SIZE,
		       password, pwlen,
		       filebuf,
		       pwhash_levels[i].opslimit,
		       pwhash_levels[i].memlimit,
		       crypto_pwhash_ALG_DEFAULT) != 0)
      continue;

    if (crypto_secretbox_open_easy (decrypted,
				    filebuf + SALT_SIZE + WRAP_NONCE_SIZE,
				    KEYFILE_PAYLOAD_SIZE,
				    filebuf + SALT_SIZE, wrapping_key) == 0) {
      success = 1;
      break;
    }
  }

  if (!success) {
    fprintf (stderr,
	     "pit: key derivation failed - insufficient memory or wrong password\n");
    goto cleanup;
  }

  /* Ownership of the decrypted key passes to the caller. */
  *key = (char *) decrypted;
  decrypted = NULL;
  ret = 0;

cleanup:
  if (fd >= 0)
    close (fd);
  secure_free (filebuf, KEYFILE_SIZE);
  secure_free (decrypted, KEY_SIZE);
  secure_free (wrapping_key, KEY_SIZE);
  secure_free (password, PASSWORD_MAX);
  return ret;
}

/* Build "/dev/mapper/" MAPPER_PREFIX followed by the last component of PATH.
 * The result is stored in a static buffer that the next call overwrites.
 * It is never NULL. */
static char *
get_mapper_path (const char *path)
{
  static char mapper[PATH_MAX];
  const char *slash = strrchr (path, '/');
  const char *base = (slash != NULL) ? slash + 1 : path;

  snprintf (mapper, sizeof mapper, "/dev/mapper/%s%s", MAPPER_PREFIX, base);
  return mapper;
}

/* If a device-mapper device for the pit file NAME already exists (left over
 * from an earlier run), deactivate it.
 * Returns 0 if there was nothing to do or the device was removed, and -1
 * if it could not be initialised or deactivated. */
static int
cleanup_stale_device (const char *name)
{
  const char *dev_path = get_mapper_path (name);
  const char *dm_name = strrchr (dev_path, '/');
  struct crypt_device *cd;
  struct stat st;
  int rc;

  if (dm_name == NULL)
    return -1;
  dm_name++;

  /* dev_path already is "/dev/mapper/<dm_name>"; nothing to do if it is
   * absent. */
  if (stat (dev_path, &st) != 0)
    return 0;

  printf ("pit: cleaning up stale device %s\n", dm_name);

  rc = crypt_init_by_name (&cd, dm_name);
  if (rc < 0) {
    fprintf (stderr, "pit: failed to init device %s\n", dm_name);
    return -1;
  }

  rc = crypt_deactivate (cd, dm_name);
  crypt_free (cd);
  if (rc < 0) {
    fprintf (stderr, "pit: failed to deactivate stale device %s\n",
	     dm_name);
    return -1;
  }
  return 0;
}

/* Return 1 if the file at PATH begins with the LUKS header magic
 * "LUKS\xba\xbe" (shared by LUKS1 and LUKS2), 0 if it does not, and -1 if
 * it cannot be opened or read. */
static int
has_luks_magic (const char *path)
{
  static const unsigned char magic[] = { 'L', 'U', 'K', 'S', 0xba, 0xbe };
  unsigned char head[sizeof (magic)];
  ssize_t n;
  int fd;

  fd = open (path, O_RDONLY);
  if (fd < 0)
    return -1;
  do
    n = read (fd, head, sizeof (head));
  while (n < 0 && errno == EINTR);
  close (fd);

  if (n < 0)
    return -1;
  return (size_t) n == sizeof (head)
    && memcmp (head, magic, sizeof (head)) == 0;
}

/* Unlock the pit file PATH as device-mapper device MAPPER_PREFIX<basename>,
 * using KEY (KEY_SIZE bytes) as the LUKS1 passphrase. A stale mapping with
 * the same name is removed first.
 *
 * A file without any LUKS header (a fresh pit from "dig") is formatted as
 * LUKS1, with KEY as both the volume key and the passphrase of keyslot 0.
 * Formatting happens only when crypt_load() reports -EINVAL (no valid
 * LUKS1 header) and the LUKS magic is absent. Previously *any* load
 * failure (I/O error, out of memory, a LUKS2 header, ...) led to
 * reformatting, irrecoverably destroying the container.
 *
 * Returns 0 on success, a negative value on failure. */
static int
setup_device_mapper (const char *path, const char *key)
{
  struct crypt_device *cd;
  int r;
  const char *mapper_name;
  const char *mapper_path;

  mapper_path = get_mapper_path (path);
  if (!mapper_path)
    return -1;
  mapper_name = strrchr (mapper_path, '/') + 1;

  if (cleanup_stale_device (path) < 0) {
    fprintf (stderr, "pit: failed to clean up stale device\n");
    return -1;
  }

  r = crypt_init (&cd, path);
  if (r < 0) {
    fprintf (stderr, "pit: crypt_init() failed for %s\n", path);
    return r;
  }

  r = crypt_load (cd, CRYPT_LUKS1, NULL);
  if (r < 0) {
    struct crypt_params_luks1 params = {
      .hash = HASH,
      .data_alignment = 0,
      .data_device = NULL
    };

    if (r != -EINVAL || has_luks_magic (path) != 0) {
      fprintf (stderr,
	       "pit: cannot load LUKS header of %s, refusing to format it\n",
	       path);
      crypt_free (cd);
      return r;
    }

    r = crypt_format (cd, CRYPT_LUKS1, CIPHER, CIPHER_MODE,
		      NULL, key, KEY_SIZE, &params);
    if (r < 0) {
      fprintf (stderr, "pit: failed to format device %s\n", path);
      crypt_free (cd);
      return r;
    }

    /* A NULL volume key reuses the one just given to crypt_format(). */
    r = crypt_keyslot_add_by_volume_key (cd, 0, NULL, 0, key, KEY_SIZE);
    if (r < 0) {
      fprintf (stderr, "pit: failed to add keyslot\n");
      crypt_free (cd);
      return r;
    }
  }

  r = crypt_activate_by_passphrase (cd, mapper_name,
				    CRYPT_ANY_SLOT, key, KEY_SIZE, 0);
  crypt_free (cd);
  if (r < 0) {
    fprintf (stderr, "pit: failed to activate device %s\n", path);
    return r;
  }

  return 0;
}

/* Deactivate the device-mapper device belonging to the pit file PATH.
 * Returns 0 on success or the negative libcryptsetup result on failure,
 * after printing which step failed. */
static int
teardown_device_mapper (const char *path)
{
  const char *mapper_path = get_mapper_path (path);
  const char *dm_name;
  struct crypt_device *cd;
  int rc;

  if (mapper_path == NULL)
    return -1;
  dm_name = strrchr (mapper_path, '/') + 1;

  rc = crypt_init_by_name (&cd, dm_name);
  if (rc < 0) {
    fprintf (stderr, "pit: crypt_init_by_name() failed for %s\n",
	     mapper_path);
    return rc;
  }

  rc = crypt_deactivate (cd, dm_name);
  crypt_free (cd);
  if (rc < 0) {
    fprintf (stderr, "pit: failed to deactivate device %s\n", mapper_path);
    return rc;
  }

  return 0;
}

/* Probe DEVICE for an ext2/3/4 superblock. The superblock starts 1024
 * bytes into the device and holds the little-endian magic 0xEF53 at
 * offset 0x38.
 *
 * Returns:
 *    1  the magic is present;
 *    0  it is absent, or the device is too short to hold it;
 *   -1  DEVICE cannot be opened or read.
 *
 * Callers must not treat -1 as "no filesystem": creating a filesystem
 * after a read error would destroy existing data. Previously errors
 * returned 0, and a short read inspected uninitialised bytes. */
static int
check_filesystem (const char *device)
{
  enum {
    EXT_SUPERBLOCK_OFFSET = 1024,
    EXT_MAGIC_OFFSET = 0x38
  };
  const off_t magic_pos = EXT_SUPERBLOCK_OFFSET + EXT_MAGIC_OFFSET;
  unsigned char magic[2];
  ssize_t n;
  int fd;

  fd = open (device, O_RDONLY);
  if (fd < 0)
    return -1;

  if (lseek (fd, magic_pos, SEEK_SET) != magic_pos) {
    close (fd);
    return -1;
  }

  do
    n = read (fd, magic, sizeof (magic));
  while (n < 0 && errno == EINTR);
  close (fd);

  if (n < 0)
    return -1;
  if (n < (ssize_t) sizeof (magic))
    return 0;

  return magic[0] == 0x53 && magic[1] == 0xEF;
}

/* Open the pit at PATH using the key file KEYFILE. The steps are:
 *   1. unlock it as a device-mapper device;
 *   2. create an ext4 filesystem if none is detected;
 *   3. mount it on MOUNTPOINT_PREFIX<basename of PATH>.
 *
 * Returns -1 if the key cannot be unlocked and 0 on success. Any other
 * failure is fatal; what was already set up is torn down first. */
static int
open_pit (const char *path, const char *keyfile)
{
  char *key;
  struct stat st;
  int r;
  int has_fs;
  char mount_dir[PATH_MAX];
  const char *name;
  const char *mapper_path;

  if (stat (path, &st) < 0)
    die ("pit: cannot stat %s: %s\n", path, strerror (errno));

  if (!S_ISREG (st.st_mode))
    die ("pit: %s is not a regular file\n", path);

  if (read_key_file (keyfile, &key) < 0)
    return -1;

  r = setup_device_mapper (path, key);
  secure_free (key, KEY_SIZE);

  if (r < 0)
    die ("pit: failed to setup device mapper\n");

  mapper_path = get_mapper_path (path);
  if (!mapper_path)
    die ("pit: failed to get mapper path\n");

  printf ("pit: checking fs on %s\n", mapper_path);
  has_fs = check_filesystem (mapper_path);
  if (has_fs < 0) {
    /* An unreadable device is not an empty one: never run mkfs on it. */
    teardown_device_mapper (path);
    die ("pit: cannot read %s\n", mapper_path);
  }

  if (has_fs == 0) {
    printf ("pit: no filesystem detected, creating new one\n");
    /* Flush so our buffered output is not reordered after mkfs's output. */
    fflush (stdout);
    if (create_filesystem (mapper_path) < 0) {
      teardown_device_mapper (path);
      die ("pit: failed to create filesystem\n");
    }
  }
  else {
    printf ("pit: existing filesystem found\n");
  }

  if (create_mount_point (path) < 0) {
    teardown_device_mapper (path);
    die ("pit: failed to create mount point\n");
  }

  name = strrchr (path, '/');
  name = name ? name + 1 : path;
  snprintf (mount_dir, sizeof (mount_dir), "%s%s", MOUNTPOINT_PREFIX, name);

  if (mount_filesystem (mapper_path, mount_dir) < 0) {
    teardown_device_mapper (path);
    rmdir (mount_dir);
    die ("pit: failed to mount filesystem\n");
  }

  printf ("pit: successfully opened %s on %s\n", path, mount_dir);
  return 0;
}

static int
list_pits (void)
{
  char **mounted_paths = NULL;
  int count = 0;
  int ret = 0;
  int i = 0;

  if (find_mounted_pits (&mounted_paths, &count) < 0) {
    fprintf (stderr, "pit: failed to find mounted pits\n");
    return -1;
  }

  if (count == 0) {
    printf ("pit: no pits currently mounted\n");
  }
  else {
    printf ("pit: %d pit%s currently mounted:\n", count,
	    count == 1 ? "" : "s");

    for (i = 0; i < count; i++) {
      if (mounted_paths[i]) {
	char device_path[PATH_MAX] = { 0 };
	FILE *mtab = fopen ("/proc/mounts", "r");
	if (mtab) {
	  char line[PATH_MAX];
	  while (fgets (line, sizeof (line), mtab)) {
	    char mnt_path[PATH_MAX] = { 0 };
	    char dev_path[PATH_MAX] = { 0 };
	    sscanf (line, "%s %s", dev_path, mnt_path);

	    if (strcmp (mnt_path, mounted_paths[i]) == 0) {
	      strncpy (device_path, dev_path, PATH_MAX - 1);
	      break;
	    }
	  }
	  fclose (mtab);
	}

	if (device_path[0]) {
	  const char *mapper_prefix = "/dev/mapper/pit-";
	  if (strncmp (device_path, mapper_prefix, strlen (mapper_prefix)) ==
	      0) {
	    printf ("  %s -> %s\n", device_path + strlen (mapper_prefix),
		    mounted_paths[i]);
	  }
	  else {
	    printf ("  %s -> %s\n", device_path, mounted_paths[i]);
	  }
	}
	else {
	  printf ("  %s\n", mounted_paths[i]);
	}

	free (mounted_paths[i]);
      }
    }
  }

  free (mounted_paths);
  return ret;
}

/* Make sure /mnt exists, so per-pit mount points can be created below it.
 * Returns 0 on success, including when /mnt already exists. Otherwise
 * prints a diagnostic and returns -1. */
static int
ensure_mount_dir (void)
{
  /* mkdir() and accepting EEXIST avoids the stat()-then-mkdir() race.
   * An existing /mnt that merely cannot be stat()ed is no longer treated
   * as missing. */
  if (mkdir ("/mnt", 0755) < 0 && errno != EEXIST) {
    fprintf (stderr, "pit: cannot create /mnt: %s\n", strerror (errno));
    return -1;
  }
  return 0;
}

/* Create the directory MOUNTPOINT_PREFIX<basename of PATH> with mode 0700,
 * creating /mnt first if needed. An already existing directory is accepted.
 * Returns 0 on success, or -1 after printing a diagnostic. */
static int
create_mount_point (const char *path)
{
  const char *slash = strrchr (path, '/');
  const char *base = slash ? slash + 1 : path;
  char target[PATH_MAX];

  if (ensure_mount_dir () < 0)
    return -1;

  snprintf (target, sizeof target, "%s%s", MOUNTPOINT_PREFIX, base);

  if (mkdir (target, 0700) == 0 || errno == EEXIST)
    return 0;

  fprintf (stderr, "pit: cannot create mountpoint %s: %s\n",
	   target, strerror (errno));
  return -1;
}

/* Run "/sbin/mkfs.ext4 -q -F DEVICE" in a child process and wait for it.
 * Returns 0 if mkfs exited with status 0. Returns -1 otherwise, including
 * when fork or exec fails. */
static int
create_filesystem (const char *device)
{
  pid_t pid;
  int status;

  pid = fork ();
  if (pid < 0) {
    fprintf (stderr, "pit: fork failed: %s\n", strerror (errno));
    return -1;
  }

  if (pid == 0) {
    /* The variadic execl() list must end with a null *pointer*. */
    execl ("/sbin/mkfs.ext4", "mkfs.ext4", "-q", "-F", device, (char *) NULL);
    fprintf (stderr, "pit: exec mkfs.ext4 failed: %s\n", strerror (errno));
    _exit (1);
  }

  /* Retry when a signal interrupts the wait. Returning early would let the
   * caller tear the device down while mkfs is still writing to it. */
  while (waitpid (pid, &status, 0) < 0) {
    if (errno != EINTR) {
      fprintf (stderr, "pit: waitpid failed: %s\n", strerror (errno));
      return -1;
    }
  }

  if (!WIFEXITED (status) || WEXITSTATUS (status) != 0) {
    fprintf (stderr, "pit: mkfs.ext4 failed\n");
    return -1;
  }

  return 0;
}

/* Mount DEVICE on MOUNTPOINT as FS_TYPE. Returns 0 on success, or -1 after
 * printing the error.
 *
 * The filesystem image comes from a container file the invoking user
 * controls, and pit mounts it as root via doas/sudo. Setuid/setgid bits and
 * device nodes on it are therefore ignored (MS_NOSUID | MS_NODEV).
 * Otherwise a crafted image holding a setuid-root binary would give the user
 * full root, even if they may only run pit as root. */
static int
mount_filesystem (const char *device, const char *mountpoint)
{
  const unsigned long flags = MS_NOSUID | MS_NODEV;

  if (mount (device, mountpoint, FS_TYPE, flags, NULL) < 0) {
    fprintf (stderr, "pit: mount failed: %s\n", strerror (errno));
    return -1;
  }
  return 0;
}

/* Unmount MOUNTPOINT, then remove the now-empty directory. Returns 0 on
 * success, or -1 after reporting the step (umount or rmdir) that failed. */
static int
unmount_filesystem (const char *mountpoint)
{
  int rc = umount (mountpoint);

  if (rc == 0) {
    rc = rmdir (mountpoint);
    if (rc == 0)
      return 0;
    fprintf (stderr, "pit: rmdir failed: %s\n", strerror (errno));
    return -1;
  }

  fprintf (stderr, "pit: umount failed: %s\n", strerror (errno));
  return -1;
}

/* Close the pit backed by the file PATH: unmount
 * MOUNTPOINT_PREFIX<basename> and remove its device-mapper device. PATH
 * must still name the regular pit file. Dies on any failure; returns 0. */
static int
close_pit (const char *path)
{
  char target[PATH_MAX];
  struct stat info;
  const char *base;

  if (stat (path, &info) < 0)
    die ("pit: cannot stat %s: %s\n", path, strerror (errno));
  if (!S_ISREG (info.st_mode))
    die ("pit: %s is not a regular file\n", path);

  base = strrchr (path, '/');
  base = (base == NULL) ? path : base + 1;
  snprintf (target, sizeof target, "%s%s", MOUNTPOINT_PREFIX, base);

  if (unmount_filesystem (target) < 0)
    die ("pit: failed to unmount filesystem\n");
  if (teardown_device_mapper (path) < 0)
    die ("pit: failed to teardown device mapper\n");

  printf ("pit: successfully closed %s\n", path);
  return 0;
}

/* Collect the mount points listed in /proc/mounts that start with
 * MOUNTPOINT_PREFIX.
 *
 * On success returns 0 and stores in *PATHS a malloc'ed array of *COUNT
 * malloc'ed strings (NULL when nothing is mounted). The caller frees both
 * the strings and the array. Returns -1 if /proc/mounts cannot be opened;
 * dies if memory runs out.
 *
 * NOTE(review): /proc/mounts escapes special characters such as spaces in
 * paths (as "\040"); such entries are returned in escaped form. */
static int
find_mounted_pits (char ***paths, int *count)
{
  const size_t prefix_len = strlen (MOUNTPOINT_PREFIX);
  FILE *mtab;
  char line[PATH_MAX];
  char **list = NULL;
  int n = 0;

  mtab = fopen ("/proc/mounts", "r");
  if (!mtab) {
    fprintf (stderr, "pit: cannot open /proc/mounts\n");
    return -1;
  }

  while (fgets (line, sizeof (line), mtab)) {
    char *mountpoint;
    char *end;
    int i;

    /* Each line is "<source> <mountpoint> <fstype> <options> ...". */
    mountpoint = strchr (line, ' ');
    if (!mountpoint)
      continue;
    mountpoint++;
    end = strchr (mountpoint, ' ');
    if (!end)
      continue;
    *end = '\0';

    /* Only the mount point field decides. Searching the whole line also
     * matched e.g. /home/u/mnt/pit-x, or an overlay whose options mention
     * /mnt/pit-..., and "panic" would then force-unmount those. */
    if (strncmp (mountpoint, MOUNTPOINT_PREFIX, prefix_len) != 0)
      continue;

    list = xrealloc (list, (n + 1) * sizeof (*list));
    list[n] = strdup (mountpoint);
    if (!list[n]) {
      fclose (mtab);
      for (i = 0; i < n; i++)
	free (list[i]);
      free (list);
      die ("pit: strdup failed: %s\n", strerror (errno));
    }
    n++;
  }

  fclose (mtab);
  *paths = list;
  *count = n;
  return 0;
}

/* Return nonzero if DIR is MOUNTPOINT itself or a path below it. A plain
 * substring test would also match unrelated paths, such as
 * MOUNTPOINT + "-old" or any path that merely contains MOUNTPOINT. */
static int
path_is_within (const char *dir, const char *mountpoint)
{
  size_t len = strlen (mountpoint);

  return strncmp (dir, mountpoint, len) == 0
    && (dir[len] == '\0' || dir[len] == '/');
}

/* SIGKILL every process whose current working directory is inside
 * MOUNTPOINT, except this process, then sleep 100 ms before the caller
 * unmounts. Only each process's cwd link is inspected. */
static void
kill_processes_in (const char *mountpoint)
{
  DIR *proc_dir;
  struct dirent *ent;
  const pid_t self = getpid ();

  proc_dir = opendir ("/proc");
  if (!proc_dir)
    return;

  while ((ent = readdir (proc_dir)) != NULL) {
    char cwd_link[PATH_MAX];
    char cwd[PATH_MAX];
    char *end;
    ssize_t len;
    long pid;

    if (!isdigit ((unsigned char) ent->d_name[0]))
      continue;
    pid = strtol (ent->d_name, &end, 10);
    /* Never kill ourselves: running "pit panic" from inside a pit would
     * otherwise abort the panic before anything is unmounted. */
    if (*end != '\0' || pid <= 0 || (pid_t) pid == self)
      continue;

    snprintf (cwd_link, sizeof (cwd_link), "/proc/%s/cwd", ent->d_name);
    len = readlink (cwd_link, cwd, sizeof (cwd) - 1);
    if (len <= 0)
      continue;
    cwd[len] = '\0';

    if (path_is_within (cwd, mountpoint)) {
      printf ("pit: killing process %ld using %s\n", pid, mountpoint);
      kill ((pid_t) pid, SIGKILL);
    }
  }
  closedir (proc_dir);
  usleep (100000);
}

/* Emergency shutdown of all pits:
 *   1. kill processes working inside pit mount points;
 *   2. force- and lazily unmount those mount points and remove them;
 *   3. force-deactivate every /dev/mapper device whose name starts with
 *      MAPPER_PREFIX.
 * When not root, pit re-runs itself through run_privileged().
 *
 * Returns 0 if every step succeeded. Returns -1 if /dev/mapper cannot be
 * read, or if any unmount or deactivation failed; the remaining pits are
 * still processed. */
static int
panic_close (void)
{
  DIR *dir;
  struct dirent *dp;
  int ret = 0;
  char **mounted = NULL;
  int count = 0;
  int i;

  if (geteuid () != 0)
    return run_privileged ("%s panic", program_name);

  /* Killing processes inside the pits can take down the sudo/doas parent
   * or the shell owning our terminal. The resulting SIGHUP must not stop
   * the panic half-way. */
  signal (SIGHUP, SIG_IGN);

  if (find_mounted_pits (&mounted, &count) == 0 && count > 0) {
    printf ("pit: force closing %d containers...\n", count);

    for (i = 0; i < count; i++) {
      if (!mounted[i])
	continue;

      kill_processes_in (mounted[i]);

      printf ("pit: force unmounting %s\n", mounted[i]);
      if (umount2 (mounted[i], MNT_FORCE | MNT_DETACH) < 0) {
	fprintf (stderr, "pit: cannot unmount %s: %s\n",
		 mounted[i], strerror (errno));
	ret = -1;
      }
      rmdir (mounted[i]);
      free (mounted[i]);
    }
  }
  free (mounted);

  dir = opendir ("/dev/mapper");
  if (!dir) {
    fprintf (stderr, "pit: cannot open /dev/mapper: %s\n", strerror (errno));
    return -1;
  }

  while ((dp = readdir (dir)) != NULL) {
    struct crypt_device *cd;

    if (strncmp (dp->d_name, MAPPER_PREFIX, strlen (MAPPER_PREFIX)) != 0)
      continue;

    printf ("pit: force closing device %s\n", dp->d_name);
    if (crypt_init_by_name (&cd, dp->d_name) < 0) {
      fprintf (stderr, "pit: cannot open device %s\n", dp->d_name);
      ret = -1;
      continue;
    }
    if (crypt_deactivate_by_name (cd, dp->d_name,
				  CRYPT_DEACTIVATE_FORCE) < 0) {
      fprintf (stderr, "pit: cannot deactivate device %s\n", dp->d_name);
      ret = -1;
    }
    crypt_free (cd);
  }
  closedir (dir);

  return ret;
}

/* Append one character to the NUL-terminated string BUF (capacity SIZE,
 * current length *LEN). Returns -1, changing nothing, if it does not fit. */
static int
append_char (char *buf, size_t size, size_t *len, char c)
{
  if (*len + 1 >= size)
    return -1;
  buf[(*len)++] = c;
  buf[*len] = '\0';
  return 0;
}

/* Append WORD to the NUL-terminated string BUF (capacity SIZE) as one
 * single-quoted POSIX shell word, preceded by a space if BUF is not empty.
 * Embedded single quotes are written as '\''. Returns 0, or -1 if the
 * result does not fit (BUF is then left as it was). */
static int
append_shell_word (char *buf, size_t size, const char *word)
{
  size_t len = strlen (buf);
  const size_t start = len;
  const char *p;
  int r = 0;

  if (len > 0)
    r |= append_char (buf, size, &len, ' ');
  r |= append_char (buf, size, &len, '\'');
  for (p = word; *p != '\0' && r == 0; p++) {
    if (*p == '\'') {
      /* close quote, escaped quote, reopen quote */
      r |= append_char (buf, size, &len, '\'');
      r |= append_char (buf, size, &len, '\\');
      r |= append_char (buf, size, &len, '\'');
      r |= append_char (buf, size, &len, '\'');
    }
    else {
      r |= append_char (buf, size, &len, *p);
    }
  }
  r |= append_char (buf, size, &len, '\'');

  if (r != 0) {
    buf[start] = '\0';
    return -1;
  }
  return 0;
}

/* Parse the "dig" size argument, a positive decimal number of MB. Dies on
 * anything else (sign, leading blanks, trailing junk, zero, overflow), where
 * atoi() used to let e.g. "-1" become a huge size. */
static size_t
parse_size_mb (const char *arg)
{
  unsigned long value;
  char *end;

  errno = 0;
  value = strtoul (arg, &end, 10);
  if (!isdigit ((unsigned char) arg[0]) || *end != '\0' || errno == ERANGE
      || value == 0)
    die ("pit: invalid size '%s': expected a positive number of MB\n", arg);
  return (size_t) value;
}

/* Command-line entry point; see usage() for the commands.
 *
 * The "dig", "open", "close" and "panic" commands need root. When run
 * unprivileged, pit re-executes itself through run_privileged(), with
 * every argument shell-quoted. */
int
main (int argc, char *argv[])
{
  const char *command;
  size_t dig_size = 0;

  program_name = argv[0];

  if (argc < 2)
    usage ();

  init_sec_mem ();

  if (!strcmp (argv[1], "-v")) {
    printf ("pit-%s\n", VERSION);
    return 0;
  }
  if (!strcmp (argv[1], "-h"))
    usage ();

  if (sodium_init () < 0)
    die ("pit: cannot initialize sodium\n");

  command = argv[1];

  /* Validate all arguments before escalating, so a usage error is reported
   * without first asking for a doas/sudo password. */
  if (!strcmp (command, "key") || !strcmp (command, "close")) {
    if (argc != 3)
      usage ();
  }
  else if (!strcmp (command, "dig") || !strcmp (command, "open")) {
    if (argc != 4)
      usage ();
  }
  else if (!strcmp (command, "list") || !strcmp (command, "panic")) {
    if (argc != 2)
      usage ();
  }
  else {
    usage ();
  }

  if (!strcmp (command, "dig"))
    dig_size = parse_size_mb (argv[3]);

  if (!strcmp (command, "key")) {
    if (generate_key (argv[2]) < 0)
      die ("pit: cannot generate key\n");
    return 0;
  }

  if (strcmp (command, "list") != 0 && geteuid () != 0) {
    /* Quote every word: the old strcat() version overflowed on long
     * arguments and broke on paths containing spaces or shell
     * metacharacters. */
    char cmd[4096] = "";
    int i;

    if (append_shell_word (cmd, sizeof (cmd), program_name) < 0)
      die ("pit: command line too long\n");
    for (i = 1; i < argc; i++)
      if (append_shell_word (cmd, sizeof (cmd), argv[i]) < 0)
	die ("pit: command line too long\n");

    return run_privileged ("%s", cmd);
  }

  if (!strcmp (command, "dig"))
    return create_pit (argv[2], dig_size);
  if (!strcmp (command, "open"))
    return open_pit (argv[2], argv[3]);
  if (!strcmp (command, "close"))
    return close_pit (argv[2]);
  if (!strcmp (command, "list"))
    return list_pits ();

  /* Only "panic" is left after the validation above. */
  return panic_close ();
}
