pit

Owner: IIIlllIIIllI URL: git@github.com:nyangkosense/pit.git

pit.c

/* pit 
 * See LICENSE file for copyright and license details. */
#define _POSIX_C_SOURCE 200809L
#define _DEFAULT_SOURCE  
#include <ctype.h>
#include <dirent.h>
#include <errno.h>
#include <fcntl.h>
#include <libcryptsetup.h>
#include <limits.h>
#include <pwd.h>
#include <signal.h>
#include <sodium.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/mount.h>
#include <sys/prctl.h>
#include <sys/resource.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <termios.h>
#include <unistd.h>

/* arbitrary sizes */
#define PIT_BLOCK_SIZE 4096 /* chunk size used by create_pit() when filling a new container */
#define KEY_SIZE 32    /* 256 bit; length of the derived key handed to libcryptsetup */
#define SALT_SIZE 32   /* salt bytes stored at the start of a key file */
/* NOTE(review): crypto_pwhash() reads crypto_pwhash_SALTBYTES bytes of salt;
 * confirm SALT_SIZE is at least that large for the libsodium in use. */
#define ITER_COUNT 500000 /* not referenced anywhere in the visible source */

#define VERSION "0.1"
#define MAPPER_PREFIX "pit-"            /* prepended to the container's base name under /dev/mapper */
#define MOUNTPOINT_PREFIX "/mnt/pit-"   /* prepended to the container's base name to form its mountpoint */
#define FS_TYPE "ext4"                  /* filesystem type passed to mount(2) */

/* cipher conf (passed to crypt_format() when a container is first formatted) */
#define CIPHER "aes"
#define CIPHER_MODE "xts-plain64"
#define HASH "sha256"

/* types */
/*
 * Descriptor for a container. No function in the visible source creates or
 * reads a Pit; the field meanings below are taken from their names only.
 * NOTE(review): confirm whether this type is still needed.
 */
typedef struct Pit {
    char *path;     /* presumably the container file path */
    size_t size;    /* presumably the container size; unit not established here */
    char *key;      /* presumably the key material or key file path */
    int ismounted;  /* presumably non-zero while mounted */
} Pit;

/* function declarations, kept the names more verbose here.
 * All helpers are file-local (static); main() is the only external symbol. */
static void die(const char *fmt, ...);
static void usage(void);
static int init_sec_mem(void);
static void *secure_alloc(size_t size);
static void secure_free(void *ptr, size_t size);
static void term_cleanup(int signo);
static int read_password(char *password, size_t size, const char *prompt);
static int exec_cmd(const char *fmt, ...);
static const char *get_username(void);
static int check_sudo_tool(const char *tool);
static int run_privileged(const char *fmt, ...);
static int generate_key(const char *keyfile);
static int read_key_file(const char *path, char **key);
static int create_pit(const char *path, size_t size);
static char *get_mapper_path(const char *path);
static int cleanup_stale_device(const char *name);
static int setup_device_mapper(const char *path, const char *key);
static int teardown_device_mapper(const char *path);
static int check_filesystem(const char *device);
static int debug_fs_info(const char *device); /* only referenced from a commented-out call in open_pit(); can be removed if not needed */
static int create_filesystem(const char *device);
static int ensure_mount_dir(void);
static int create_mount_point(const char *path);
static int mount_filesystem(const char *device, const char *mountpoint);
static int unmount_filesystem(const char *mountpoint);
static int open_pit(const char *path, const char *keyfile);
static int close_pit(const char *path);
static int list_pits(void); /* backs the "list" command dispatched from main() */
static int find_mounted_pits(char ***paths, int *count);
static int panic_close(void);

/* globals */
static const char *program_name;    /* argv[0], set at the top of main(); used to re-invoke pit with privileges */
/* Set to 1 by read_password() while terminal echo is disabled and cleared by
 * term_cleanup(). NOTE(review): term_cleanup() is installed as a signal
 * handler, so `volatile sig_atomic_t` would be the safer type here. */
static int term_modified = 0;
static struct termios saved_term;   /* terminal attributes captured by read_password() before echo is turned off */
static long pagesize;               /* refreshed from sysconf(_SC_PAGESIZE) on every secure_alloc() call */

/* function implementations */
/*
 * Write a printf-style message to stderr and terminate the process with
 * exit status 1. The caller supplies any trailing newline. Never returns.
 */
static void
die(const char *fmt, ...)
{
    va_list args;

    va_start(args, fmt);
    (void)vfprintf(stderr, fmt, args);
    va_end(args);

    exit(1);
}

/*
 * Try to keep the whole process out of swap.
 *
 * When running as root and the soft RLIMIT_MEMLOCK is below 8 MiB, the limit
 * is raised to 8 MiB first. The hard limit is only ever raised, never
 * lowered: an already larger (or unlimited) hard limit is left untouched.
 * Afterwards all current and future pages are locked with mlockall().
 *
 * Returns 0 when the pages were locked, -1 (after printing a warning) when
 * mlockall() failed. Failure to adjust the limit only produces a warning.
 */
static int
init_sec_mem(void)
{
    const rlim_t required_mem = 8 * 1024 * 1024; /* 8mb minimum */
    struct rlimit rlim;

    if (getrlimit(RLIMIT_MEMLOCK, &rlim) == 0 &&
        rlim.rlim_cur < required_mem &&
        geteuid() == 0) {
        rlim.rlim_cur = required_mem;
        /* keep a higher hard limit as it is; RLIM_INFINITY compares greater */
        if (rlim.rlim_max < required_mem)
            rlim.rlim_max = required_mem;

        if (setrlimit(RLIMIT_MEMLOCK, &rlim) < 0)
            fprintf(stderr, "pit: warning: couldn't increase memory lock limit\n");
    }

    if (mlockall(MCL_CURRENT | MCL_FUTURE) < 0) {
        fprintf(stderr, "pit: warning: couldn't lock memory pages: %s\n",
                strerror(errno));
        fprintf(stderr, "pit: sensitive data might be swapped to disk\n");
        return -1;
    }

    return 0;
}

/*
 * Allocate `size` bytes of page-aligned memory that is locked into RAM
 * (mlock) and excluded from core dumps (MADV_DONTDUMP).
 *
 * The global `pagesize` is refreshed from sysconf() on every call.
 *
 * Returns the buffer, or NULL when the page size cannot be determined, the
 * allocation fails, or the pages cannot be locked or marked. Release the
 * buffer with secure_free() using the same size.
 */
static void *
secure_alloc(size_t size)
{
    void *ptr = NULL;

    pagesize = sysconf(_SC_PAGESIZE);
    if (pagesize <= 0) {
        fprintf(stderr, "pit: could not get system page size: %s\n",
            strerror(errno));
        /* a non-positive value must not be used as an alignment */
        return NULL;
    }

    if (posix_memalign(&ptr, (size_t)pagesize, size) != 0)
        return NULL;

    if (mlock(ptr, size) < 0) {
        free(ptr);
        return NULL;
    }

    if (madvise(ptr, size, MADV_DONTDUMP) < 0) {
        munlock(ptr, size);
        free(ptr);
        return NULL;
    }

    return ptr;
}

/*
 * Wipe, unlock and free a buffer of `size` bytes. A NULL pointer is ignored.
 * The contents are zeroed with sodium_memzero() before the memory is
 * unlocked and handed back to the allocator.
 */
static void
secure_free(void *ptr, size_t size)
{
    if (ptr == NULL)
        return;

    sodium_memzero(ptr, size);
    munlock(ptr, size);
    free(ptr);
}

/*
 * Print the command summary to stderr and exit with status 1 (via die()).
 * Never returns. The "dig" example includes the size argument because
 * main() rejects "dig" without it.
 */
static void
usage(void)
{
    die("usage: pit [-v] [-h] command [arguments]\n"
        "Commands:\n"
        "  dig FILE 10                 - create new empty pit file of 10 MB size\n"
        "  key KEY.key                 - generate new encrypted key file\n"
        "  open FILE KEY.key           - open an existing pit\n"
        "  close PATH                  - close an opened pit\n"
        "  list                        - list opened pits\n"
        "  panic                       - emergency close all pits (forced)\n"
        "  example: pit dig container.pit 10\n"
        "           pit key container.key\n"
        "           pit open container.pit container.key\n");
}

/*
 * Restore the terminal attributes saved by read_password().
 *
 * Called directly with signo == 0 after a password has been read, and
 * installed as the SIGINT/SIGTERM handler while echo is disabled. Because it
 * can run in signal context it only uses async-signal-safe calls:
 * tcsetattr(), write() and _exit().
 *
 * signo: 0 for a normal call (returns to the caller); any other value
 *        terminates the process with status 1 after the terminal is restored.
 */
static void
term_cleanup(int signo)
{
    if (term_modified) {
        tcsetattr(STDIN_FILENO, TCSAFLUSH, &saved_term);
        /* the newline the user typed was not echoed; emit one ourselves */
        if (write(STDERR_FILENO, "\n", 1) < 0) {
            /* nothing useful can be done if stderr is gone */
        }
        term_modified = 0;
    }
    if (signo != 0) {
        /* _exit(): stdio buffers and atexit handlers must not run here */
        _exit(1);
    }
}

/*
 * Prompt on stderr and read one line from stdin with terminal echo disabled.
 *
 * password: destination buffer, receives the line without its newline
 * size:     capacity of `password` in bytes
 * prompt:   text shown before reading
 *
 * SIGINT and SIGTERM are routed to term_cleanup() so an interrupted prompt
 * does not leave the terminal without echo. The handlers stay installed
 * after the call.
 *
 * Returns the length of the password, or -1 when the terminal attributes
 * cannot be read or changed or when no input could be read.
 */
static int
read_password(char *password, size_t size, const char *prompt)
{
    struct sigaction action;
    struct termios silent;
    char *line;
    int length;

    action.sa_handler = term_cleanup;
    sigemptyset(&action.sa_mask);
    action.sa_flags = 0;
    sigaction(SIGINT, &action, NULL);
    sigaction(SIGTERM, &action, NULL);

    if (tcgetattr(STDIN_FILENO, &saved_term) != 0)
        return -1;

    /* same settings as before, minus echo */
    silent = saved_term;
    silent.c_lflag &= ~ECHO;
    if (tcsetattr(STDIN_FILENO, TCSAFLUSH, &silent) != 0)
        return -1;
    term_modified = 1;

    fputs(prompt, stderr);
    fflush(stderr);

    line = fgets(password, size, stdin);

    /* echo is restored on both the success and the failure path */
    term_cleanup(0);
    if (line == NULL)
        return -1;

    /* fgets() keeps at most one newline, and only at the very end */
    password[strcspn(password, "\n")] = '\0';
    length = strlen(password);

    return length;
}

/*
 * Format a command line and run it through the shell with system().
 *
 * Returns the command's exit status (0-255) when it exited normally.
 * Returns -1 when the formatted command does not fit the internal buffer
 * (it is NOT run truncated), when system() itself fails, or when the command
 * was terminated by a signal. A signal-killed command must not be reported
 * as exit status 0.
 */
static int
exec_cmd(const char *fmt, ...)
{
    char cmd[4096];
    va_list ap;
    int written;
    int status;

    va_start(ap, fmt);
    written = vsnprintf(cmd, sizeof(cmd), fmt, ap);
    va_end(ap);

    if (written < 0 || (size_t)written >= sizeof(cmd)) {
        fprintf(stderr, "pit: command too long\n");
        return -1;
    }

    status = system(cmd);
    if (status == -1) {
        fprintf(stderr, "pit: cannot run command: %s\n", strerror(errno));
        return -1;
    }

    /* WEXITSTATUS is only meaningful for a normal exit */
    if (!WIFEXITED(status))
        return -1;

    return WEXITSTATUS(status);
}

/*
 * Look up the login name of the real user id in the password database.
 * Returns a pointer into getpwuid()'s static storage, or NULL when the uid
 * has no entry.
 */
static const char *
get_username(void)
{
    struct passwd *entry = getpwuid(getuid());

    if (entry == NULL)
        return NULL;

    return entry->pw_name;
}

/*
 * Report whether /usr/bin/<tool> exists and is executable by the caller.
 * Returns 1 if it is, 0 otherwise. Only /usr/bin is consulted.
 */
static int
check_sudo_tool(const char *tool)
{
    char candidate[PATH_MAX];

    snprintf(candidate, sizeof(candidate), "/usr/bin/%s", tool);
    if (access(candidate, X_OK) != 0)
        return 0;

    return 1;
}

/*
 * Format a command line and run it as root.
 *
 * When already root the command is executed directly. Otherwise it is
 * prefixed with doas (preferred) or sudo, whichever check_sudo_tool() finds.
 * The user name is looked up only for the sudo prompt, so a missing passwd
 * entry does not block doas.
 *
 * The complete command line is assembled here and rejected if it does not
 * fit: a privileged command is never run truncated.
 *
 * Returns what exec_cmd() returns for the command, or -1 when the command
 * is too long, no escalation tool is available, or the user name is needed
 * but cannot be determined.
 */
static int
run_privileged(const char *fmt, ...)
{
    char cmd[4096];
    char full[4096];
    va_list ap;
    const char *username;
    int written;

    va_start(ap, fmt);
    written = vsnprintf(cmd, sizeof(cmd), fmt, ap);
    va_end(ap);
    if (written < 0 || (size_t)written >= sizeof(cmd)) {
        fprintf(stderr, "pit: command too long\n");
        return -1;
    }

    if (geteuid() == 0)
        return exec_cmd("%s", cmd);

    if (check_sudo_tool("doas")) {
        written = snprintf(full, sizeof(full), "doas %s", cmd);
    } else if (check_sudo_tool("sudo")) {
        username = get_username();
        if (!username) {
            fprintf(stderr, "pit: cannot get username\n");
            return -1;
        }
        written = snprintf(full, sizeof(full),
                           "sudo -p '[sudo] Enter password for user %s: ' %s",
                           username, cmd);
    } else {
        fprintf(stderr, "pit: no privilege escalation tool found\n");
        return -1;
    }

    /* the prefix may push an otherwise fitting command over the limit */
    if (written < 0 || (size_t)written >= sizeof(full)) {
        fprintf(stderr, "pit: command too long\n");
        return -1;
    }

    return exec_cmd("%s", full);
}

/*
 * Create a key file protected by a password entered twice on the terminal.
 *
 * The file holds SALT_SIZE bytes of random salt followed by KEY_SIZE bytes
 * of crypto_pwhash() output for that password and salt. Derivation is tried
 * with the SENSITIVE limits first, then MODERATE, then MIN, moving on only
 * when crypto_pwhash() reports failure.
 *
 * The file is opened with mode 0600 and its mode is forced to 0600 even if
 * it already existed, so the material is not left readable through a
 * permissive umask. An existing file is overwritten.
 *
 * keyfile: path of the key file to write.
 * Returns 0 on success, -1 on any failure (a message is printed to stderr).
 *
 * NOTE(review): read_key_file() derives its key from the password and the
 * stored salt only and starts with the MIN limits; it never compares its
 * result with the hash stored here.
 */
static int
generate_key(const char *keyfile)
{
    /* derivation limits in the order they are attempted; `fallback_msg` is
     * printed when that level fails and a weaker one follows */
    const struct {
        unsigned long long opslimit;
        size_t memlimit;
        const char *fallback_msg;
    } levels[] = {
        { crypto_pwhash_OPSLIMIT_SENSITIVE, crypto_pwhash_MEMLIMIT_SENSITIVE,
          "pit: key derivation failed with sensitive memory settings - trying with moderate...\n" },
        { crypto_pwhash_OPSLIMIT_MODERATE, crypto_pwhash_MEMLIMIT_MODERATE,
          "pit: key derivation with moderate memory settings failed - trying minimal...\n" },
        { crypto_pwhash_OPSLIMIT_MIN, crypto_pwhash_MEMLIMIT_MIN, NULL },
    };
    const size_t nlevels = sizeof(levels) / sizeof(levels[0]);
    const size_t blob_len = KEY_SIZE + SALT_SIZE;

    unsigned char *salt = secure_alloc(SALT_SIZE);
    unsigned char *encrypted = secure_alloc(KEY_SIZE + SALT_SIZE);
    char *password = secure_alloc(1024);
    char *verify = secure_alloc(1024);
    FILE *f = NULL;
    int fd = -1;
    int pwlen;
    int pwhash_result = -1;
    int ret = -1; /* def to error */
    size_t i;

    if (!salt || !encrypted || !password || !verify) {
        fprintf(stderr, "pit: failed to allocate secure memory\n");
        goto cleanup;
    }

    if (sodium_init() < 0) {
        fprintf(stderr, "pit: failed to initialize sodium\n");
        goto cleanup;
    }

    randombytes_buf(salt, SALT_SIZE);

    pwlen = read_password(password, 1024, "Enter password for key encryption (no echo): ");
    if (pwlen <= 0) {
        fprintf(stderr, "pit: failed to read password\n");
        goto cleanup;
    }

    if (read_password(verify, 1024, "Verify password: ") <= 0) {
        fprintf(stderr, "pit: failed to read password verification\n");
        goto cleanup;
    }

    if (strcmp(password, verify) != 0) {
        fprintf(stderr, "pit: passwords do not match\n");
        goto cleanup;
    }

    /* strongest settings first, weaker ones only if the stronger fail */
    for (i = 0; i < nlevels; i++) {
        pwhash_result = crypto_pwhash(
            encrypted + SALT_SIZE, KEY_SIZE,
            password, pwlen,
            salt,
            levels[i].opslimit,
            levels[i].memlimit,
            crypto_pwhash_ALG_DEFAULT);
        if (pwhash_result == 0)
            break;
        if (levels[i].fallback_msg)
            fputs(levels[i].fallback_msg, stderr);
    }

    if (pwhash_result != 0) {
        fprintf(stderr, "pit: key derivation failed - system has insufficient memory\n");
        goto cleanup;
    }

    /* file layout: salt first, derived bytes after it */
    memcpy(encrypted, salt, SALT_SIZE);

    fd = open(keyfile, O_WRONLY | O_CREAT | O_TRUNC, 0600);
    if (fd < 0) {
        fprintf(stderr, "pit: cannot create key file: %s\n", strerror(errno));
        goto cleanup;
    }

    /* open()'s mode only applies to new files; tighten a pre-existing one */
    if (fchmod(fd, 0600) < 0) {
        fprintf(stderr, "pit: warning: cannot restrict key file permissions: %s\n",
                strerror(errno));
    }

    f = fdopen(fd, "wb");
    if (!f) {
        fprintf(stderr, "pit: cannot create key file: %s\n", strerror(errno));
        close(fd);
        goto cleanup;
    }

    if (fwrite(encrypted, 1, blob_len, f) != blob_len) {
        fprintf(stderr, "pit: failed to write key file: %s\n", strerror(errno));
        fclose(f);
        goto cleanup;
    }

    /* buffered data reaches the file at fclose(); a failure here is a failed write */
    if (fclose(f) != 0) {
        fprintf(stderr, "pit: failed to write key file: %s\n", strerror(errno));
        goto cleanup;
    }

    printf("pit: key generated successfully\n");
    ret = 0; /* success */

cleanup:
    secure_free(salt, SALT_SIZE);
    secure_free(encrypted, KEY_SIZE + SALT_SIZE);
    secure_free(password, 1024);
    secure_free(verify, 1024);
    return ret;
}

/*
 * Create a new container file filled with random data.
 *
 * path: file to create; it must not exist yet (created with O_EXCL, mode 0600)
 * size: container size in MiB; must be non-zero and small enough that the
 *       byte count fits in size_t
 *
 * The fill buffer is refreshed with new random bytes whenever the remaining
 * byte count is a multiple of 50 blocks. Every write is limited to the bytes
 * still missing, so a short write cannot make the counter wrap, and EINTR is
 * retried. A partially written file is removed on write errors.
 *
 * Returns 0 on success. All failures terminate the process via die().
 */
static int
create_pit(const char *path, size_t size)
{
    const size_t mib = 1024 * 1024;
    char *buf;
    size_t remain;
    int fd;
    int err;

    if (size == 0 || size > (size_t)-1 / mib)
        die("pit: invalid size\n");

    buf = secure_alloc(PIT_BLOCK_SIZE);
    if (!buf) {
        die("pit: out of memory\n");
    }

    /* O_EXCL makes "must not exist" atomic with the creation */
    fd = open(path, O_WRONLY | O_CREAT | O_EXCL, 0600);
    if (fd < 0) {
        err = errno;
        secure_free(buf, PIT_BLOCK_SIZE);
        if (err == EEXIST)
            die("pit: %s already exists\n", path);
        die("pit: cannot create %s: %s\n", path, strerror(err));
    }

    randombytes_buf(buf, PIT_BLOCK_SIZE);

    remain = size * mib;
    while (remain > 0) {
        size_t want = remain < PIT_BLOCK_SIZE ? remain : PIT_BLOCK_SIZE;
        ssize_t nwrite = write(fd, buf, want);

        if (nwrite < 0 && errno == EINTR)
            continue;
        if (nwrite <= 0) {
            /* a zero-byte write would otherwise loop forever */
            err = nwrite < 0 ? errno : EIO;
            close(fd);
            unlink(path);
            secure_free(buf, PIT_BLOCK_SIZE);
            die("pit: write error: %s\n", strerror(err));
        }
        remain -= (size_t)nwrite;

        /* refresh random data every 50 blocks */
        if (remain % (50 * PIT_BLOCK_SIZE) == 0)
            randombytes_buf(buf, PIT_BLOCK_SIZE);
    }

    if (close(fd) < 0) {
        err = errno;
        unlink(path);
        secure_free(buf, PIT_BLOCK_SIZE);
        die("pit: write error: %s\n", strerror(err));
    }

    secure_free(buf, PIT_BLOCK_SIZE);
    return 0;
}

/*
 * Read a key file, ask for its password and derive the container key.
 *
 * path: key file holding SALT_SIZE + KEY_SIZE bytes
 * key:  on success receives a KEY_SIZE-byte buffer obtained from
 *       secure_alloc(); it is left untouched on failure
 *
 * The key is crypto_pwhash() of the password with the bytes at the start of
 * the file as salt. The MIN limits are tried first; MODERATE and SENSITIVE
 * are attempted only when the previous call reports failure. The KEY_SIZE
 * bytes stored after the salt are read but not compared with the result.
 *
 * Returns 0 on success and -1 on allocation, password or derivation failure.
 * An unreadable or short key file terminates the process via die().
 */
static int
read_key_file(const char *path, char **key)
{
    /* derivation limits in the order they are attempted; `notice` is printed
     * before every attempt except the first */
    const struct {
        unsigned long long opslimit;
        size_t memlimit;
        const char *notice;
    } attempts[] = {
        { crypto_pwhash_OPSLIMIT_MIN, crypto_pwhash_MEMLIMIT_MIN, NULL },
        { crypto_pwhash_OPSLIMIT_MODERATE, crypto_pwhash_MEMLIMIT_MODERATE,
          "pit: trying with moderate memory settings ... \n" },
        { crypto_pwhash_OPSLIMIT_SENSITIVE, crypto_pwhash_MEMLIMIT_SENSITIVE,
          "pit: trying with sensitive memory settings ...\n" },
    };
    const size_t blob_len = KEY_SIZE + SALT_SIZE;

    unsigned char *blob = secure_alloc(KEY_SIZE + SALT_SIZE);
    unsigned char *derived = secure_alloc(KEY_SIZE);
    char *password = secure_alloc(1024);
    FILE *fp;
    size_t nread;
    size_t step;
    int pwlen;
    int rc = -1;
    int ret = -1;

    if (blob == NULL || derived == NULL || password == NULL) {
        fprintf(stderr, "pit: failed to allocate secure memory\n");
        goto cleanup;
    }

    fp = fopen(path, "r");
    if (fp == NULL) {
        die("pit: cannot open key file %s: %s\n", path, strerror(errno));
    }

    nread = fread(blob, 1, KEY_SIZE + SALT_SIZE, fp);
    fclose(fp);
    if (nread != blob_len) {
        die("pit: invalid key file size\n");
    }

    fflush(stdout);
    if (read_password(password, 1024, "Enter password for key (no echo): ") <= 0) {
        fprintf(stderr, "pit: failed to read password\n");
        goto cleanup;
    }

    /* read_password() already drops the newline; kept as a second guard */
    pwlen = strlen(password);
    if (pwlen > 0 && password[pwlen - 1] == '\n') {
        pwlen--;
        password[pwlen] = 0;
    }

    for (step = 0; step < sizeof(attempts) / sizeof(attempts[0]); step++) {
        if (attempts[step].notice != NULL)
            fprintf(stderr, "%s", attempts[step].notice);

        /* the salt is read from the start of the key file contents */
        rc = crypto_pwhash(
            derived, KEY_SIZE,
            password, pwlen,
            blob,
            attempts[step].opslimit,
            attempts[step].memlimit,
            crypto_pwhash_ALG_DEFAULT);
        if (rc == 0)
            break;
    }

    if (rc != 0) {
        fprintf(stderr, "pit: key derivation failed - insufficient memory or wrong password\n");
        goto cleanup;
    }

    /* hand the buffer to the caller and keep cleanup from wiping it */
    *key = (char *)derived;
    derived = NULL;
    ret = 0;

cleanup:
    secure_free(blob, KEY_SIZE + SALT_SIZE);
    if (derived != NULL)
        secure_free(derived, KEY_SIZE);
    secure_free(password, 1024);
    return ret;
}

/*
 * Build the /dev/mapper node path for a container: "/dev/mapper/" followed
 * by MAPPER_PREFIX and the base name of `path`.
 *
 * The result lives in a static buffer that is overwritten by the next call,
 * so it must be consumed or copied before calling again. Never returns NULL.
 */
static char *
get_mapper_path(const char *path)
{
    static char mapper[PATH_MAX];
    const char *slash = strrchr(path, '/');
    const char *base = (slash != NULL) ? slash + 1 : path;

    snprintf(mapper, sizeof(mapper), "/dev/mapper/%s%s", MAPPER_PREFIX, base);
    return mapper;
}

/*
 * Deactivate a device-mapper node left over for this container, if any.
 *
 * name: container path; the mapper name is derived through get_mapper_path().
 *
 * Returns 0 when no node exists or it was deactivated, -1 when the node
 * exists but cannot be initialised or deactivated.
 */
static int
cleanup_stale_device(const char *name)
{
    struct crypt_device *cd;
    struct stat st;
    char node[PATH_MAX];
    const char *slash;
    const char *dm_name;
    int rc;

    slash = strrchr(get_mapper_path(name), '/');
    if (slash == NULL)
        return -1;
    dm_name = slash + 1;

    snprintf(node, sizeof(node), "/dev/mapper/%s", dm_name);

    /* nothing under /dev/mapper means there is nothing stale to remove */
    if (stat(node, &st) != 0)
        return 0;

    printf("pit: cleaning up stale device %s\n", dm_name);

    if (crypt_init_by_name(&cd, dm_name) < 0) {
        fprintf(stderr, "pit: failed to init device %s\n", dm_name);
        return -1;
    }

    rc = crypt_deactivate(cd, dm_name);
    crypt_free(cd);

    if (rc < 0) {
        fprintf(stderr, "pit: failed to deactivate stale device %s\n", dm_name);
        return -1;
    }

    return 0;
}

/*
 * Activate the container at `path` as a dm-crypt device, using `key`
 * (KEY_SIZE bytes) as the passphrase.
 *
 * A stale mapping for the container is removed first. If a LUKS1 header
 * loads, the device is simply activated. If loading fails, the container is
 * formatted as LUKS1 (CIPHER / CIPHER_MODE / HASH), `key` is added to key
 * slot 0 and the device is then activated.
 *
 * Returns 0 on success; on failure the negative libcryptsetup result, or -1
 * when stale-device cleanup fails.
 */
static int
setup_device_mapper(const char *path, const char *key)
{
    struct crypt_device *cd;
    const char *mapper_path;
    const char *dm_name;
    int r;

    mapper_path = get_mapper_path(path);
    if (mapper_path == NULL)
        return -1;
    dm_name = strrchr(mapper_path, '/') + 1;

    if (cleanup_stale_device(path) < 0) {
        fprintf(stderr, "pit: failed to clean up stale device\n");
        return -1;
    }

    r = crypt_init(&cd, path);
    if (r < 0) {
        fprintf(stderr, "pit: crypt_init() failed for %s\n", path);
        return r;
    }

    if (crypt_load(cd, CRYPT_LUKS1, NULL) != 0) {
        /* no loadable LUKS1 header: treat the container as new and format it */
        struct crypt_params_luks1 params = {
            .hash = HASH,
            .data_alignment = 0,
            .data_device = NULL
        };

        r = crypt_format(cd, CRYPT_LUKS1, CIPHER, CIPHER_MODE,
                         NULL, key, KEY_SIZE, &params);
        if (r < 0) {
            fprintf(stderr, "pit: failed to format device %s\n", path);
            goto out;
        }

        r = crypt_keyslot_add_by_volume_key(cd, 0, NULL, 0,
                                            key, KEY_SIZE);
        if (r < 0) {
            fprintf(stderr, "pit: failed to add keyslot\n");
            goto out;
        }
    }

    /* both the existing and the freshly formatted case end with activation */
    r = crypt_activate_by_passphrase(cd, dm_name, CRYPT_ANY_SLOT,
                                     key, KEY_SIZE, 0);
    if (r < 0) {
        fprintf(stderr, "pit: failed to activate device %s\n", path);
        goto out;
    }

    r = 0;

out:
    crypt_free(cd);
    return r;
}

/*
 * Deactivate the dm-crypt mapping that belongs to the container at `path`.
 * Returns 0 on success or the negative libcryptsetup result on failure.
 */
static int
teardown_device_mapper(const char *path)
{
    struct crypt_device *cd;
    const char *mapper_path = get_mapper_path(path);
    const char *dm_name;
    int r;

    if (mapper_path == NULL)
        return -1;
    dm_name = strrchr(mapper_path, '/') + 1;

    r = crypt_init_by_name(&cd, dm_name);
    if (r < 0) {
        fprintf(stderr, "pit: crypt_init_by_name() failed for %s\n", mapper_path);
        return r;
    }

    r = crypt_deactivate(cd, dm_name);
    if (r < 0)
        fprintf(stderr, "pit: failed to deactivate device %s\n", mapper_path);

    /* the handle is released whether or not deactivation worked */
    crypt_free(cd);
    return (r < 0) ? r : 0;
}

/*
 * Report whether `device` carries the ext magic number (bytes 0x53 0xEF at
 * offset 0x38 of the superblock, which starts 1024 bytes into the device).
 *
 * Returns 1 when the magic is present. Returns 0 when it is absent, the
 * device cannot be opened or positioned, or fewer bytes than needed to cover
 * the magic could be read; a short read is never compared against
 * uninitialised memory.
 */
static int
check_filesystem(const char *device)
{
    enum {
        SUPERBLOCK_OFFSET = 1024,
        MAGIC_OFFSET = 0x38,
        NEEDED = MAGIC_OFFSET + 2
    };
    unsigned char buf[NEEDED];
    size_t have = 0;
    int fd;

    fd = open(device, O_RDONLY);
    if (fd < 0)
        return 0;

    if (lseek(fd, SUPERBLOCK_OFFSET, SEEK_SET) != SUPERBLOCK_OFFSET) {
        close(fd);
        return 0;
    }

    /* read() may return less than requested; collect until the magic is covered */
    while (have < sizeof(buf)) {
        ssize_t n = read(fd, buf + have, sizeof(buf) - have);

        if (n < 0 && errno == EINTR)
            continue;
        if (n <= 0)
            break;
        have += (size_t)n;
    }
    close(fd);

    if (have < sizeof(buf))
        return 0;

    return (buf[MAGIC_OFFSET] == 0x53 && buf[MAGIC_OFFSET + 1] == 0xEF);
}

/*
 * Run "/sbin/tune2fs -l <device>" in a child process so its listing appears
 * on the terminal. Debug aid only.
 *
 * Returns the child's exit status when it exited normally (1 when tune2fs
 * could not be executed), or -1 when fork() or waitpid() fails or the child
 * was terminated by a signal.
 */
static int
debug_fs_info(const char *device)
{
    pid_t pid;
    int status = 0;

    pid = fork();
    if (pid < 0) {
        return -1;
    }

    if (pid == 0) {
        /* variadic sentinel must be a null *pointer*, hence the cast */
        execl("/sbin/tune2fs", "tune2fs", "-l", device, (char *)NULL);
        _exit(1);
    }

    /* status is only valid when waitpid() succeeded; retry if interrupted */
    while (waitpid(pid, &status, 0) < 0) {
        if (errno != EINTR)
            return -1;
    }

    return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
}

static int
open_pit(const char *path, const char *keyfile)
{
    char *key;
    struct stat st;
    int r;
    char mount_dir[PATH_MAX];
    const char *name;
    const char *mapper_path;

    if (stat(path, &st) < 0)
        die("pit: cannot stat %s: %s\n", path, strerror(errno));

    if (!S_ISREG(st.st_mode))
        die("pit: %s is not a regular file\n", path);

    if (read_key_file(keyfile, &key) < 0)
        return -1;

    r = setup_device_mapper(path, key);
    free(key);

    if (r < 0)
        die("pit: failed to setup device mapper\n");

    mapper_path = get_mapper_path(path);
    if (!mapper_path)
        die("pit: failed to get mapper path\n");

    /* only create filesystem if one doesn't exist, otherwise it erases the data */
    printf("pit: checking fs on %s\n", mapper_path);
    if (!check_filesystem(mapper_path)) {
        printf("pit: no filesystem detected, creating new one\n");
        if (create_filesystem(mapper_path) < 0) {
            teardown_device_mapper(path);
            die("pit: failed to create filesystem\n");
        }
    } else {
        printf("pit: existing filesystem found\n");
        //debug_fs_info(mapper_path);
    }

    /* Create mount point */
    if (create_mount_point(path) < 0) {
        teardown_device_mapper(path);
        die("pit: failed to create mount point\n");
    }

    name = strrchr(path, '/');
    name = name ? name + 1 : path;
    snprintf(mount_dir, sizeof(mount_dir), "%s%s", MOUNTPOINT_PREFIX, name);

    if (mount_filesystem(mapper_path, mount_dir) < 0) {
        teardown_device_mapper(path);
        rmdir(mount_dir);
        die("pit: failed to mount filesystem\n");
    }

    printf("pit: successfully opened %s on %s\n", path, mount_dir);
    return 0;
}

/*
 * Print every mountpoint reported by find_mounted_pits(), together with the
 * device mounted there when /proc/mounts names one. Devices under
 * /dev/mapper/pit- are shown without that prefix.
 *
 * Returns 0, or -1 when the mounted pits cannot be determined.
 */
static int
list_pits(void)
{
    static const char mapper_prefix[] = "/dev/mapper/pit-";
    const size_t prefix_len = sizeof(mapper_prefix) - 1;
    char **mounts = NULL;
    int total = 0;

    if (find_mounted_pits(&mounts, &total) < 0) {
        fprintf(stderr, "pit: failed to find mounted pits\n");
        return -1;
    }

    if (total == 0) {
        printf("pit: no pits currently mounted\n");
        free(mounts);
        return 0;
    }

    printf("pit: %d pit%s currently mounted:\n", total, total == 1 ? "" : "s");

    for (int idx = 0; idx < total; idx++) {
        char source[PATH_MAX] = {0};
        FILE *mtab;

        if (mounts[idx] == NULL)
            continue;

        /* find the first mount table entry whose mountpoint matches */
        mtab = fopen("/proc/mounts", "r");
        if (mtab != NULL) {
            char line[PATH_MAX];

            while (fgets(line, sizeof(line), mtab) != NULL) {
                char mnt[PATH_MAX] = {0};
                char dev[PATH_MAX] = {0};

                sscanf(line, "%s %s", dev, mnt);
                if (strcmp(mnt, mounts[idx]) != 0)
                    continue;

                snprintf(source, sizeof(source), "%s", dev);
                break;
            }
            fclose(mtab);
        }

        if (source[0] == '\0')
            printf("  %s\n", mounts[idx]);
        else if (strncmp(source, mapper_prefix, prefix_len) == 0)
            printf("  %s -> %s\n", source + prefix_len, mounts[idx]);
        else
            printf("  %s -> %s\n", source, mounts[idx]);

        free(mounts[idx]);
    }

    free(mounts);
    return 0;
}

/*
 * Make sure /mnt exists, creating it with mode 0755 when stat() fails on it.
 * Returns 0 when /mnt is present afterwards, -1 when it cannot be created.
 */
static int
ensure_mount_dir(void)
{
    struct stat info;

    if (stat("/mnt", &info) == 0)
        return 0;

    if (mkdir("/mnt", 0755) == 0)
        return 0;

    fprintf(stderr, "pit: cannot create /mnt: %s\n", strerror(errno));
    return -1;
}

/*
 * Create the mount directory for a container: MOUNTPOINT_PREFIX followed by
 * the base name of `path`, mode 0700. An already existing entry of that
 * name is accepted.
 *
 * Returns 0 on success, -1 when /mnt or the directory cannot be created.
 */
static int
create_mount_point(const char *path)
{
    char mount_dir[PATH_MAX];
    const char *slash;
    const char *base;

    if (ensure_mount_dir() < 0)
        return -1;

    slash = strrchr(path, '/');
    base = (slash != NULL) ? slash + 1 : path;
    snprintf(mount_dir, sizeof(mount_dir), "%s%s", MOUNTPOINT_PREFIX, base);

    if (mkdir(mount_dir, 0700) == 0 || errno == EEXIST)
        return 0;

    fprintf(stderr, "pit: cannot create mountpoint %s: %s\n",
            mount_dir, strerror(errno));
    return -1;
}

/*
 * Create an ext4 filesystem on `device` by running
 * "/sbin/mkfs.ext4 -q -F <device>" in a child process and waiting for it.
 *
 * Returns 0 when mkfs.ext4 exits with status 0; -1 when fork() or waitpid()
 * fails, the program cannot be executed, or it exits unsuccessfully.
 */
static int
create_filesystem(const char *device)
{
    int wstatus;
    pid_t child = fork();

    switch (child) {
    case -1:
        fprintf(stderr, "pit: fork failed: %s\n", strerror(errno));
        return -1;
    case 0:
        /* child: only reaches the lines after execl() if the exec failed */
        execl("/sbin/mkfs.ext4", "mkfs.ext4", "-q", "-F", device, NULL);
        fprintf(stderr, "pit: exec mkfs.ext4 failed: %s\n", strerror(errno));
        _exit(1);
    default:
        break;
    }

    if (waitpid(child, &wstatus, 0) < 0) {
        fprintf(stderr, "pit: waitpid failed: %s\n", strerror(errno));
        return -1;
    }

    if (WIFEXITED(wstatus) && WEXITSTATUS(wstatus) == 0)
        return 0;

    fprintf(stderr, "pit: mkfs.ext4 failed\n");
    return -1;
}

/*
 * Mount `device` on `mountpoint` as FS_TYPE with no mount flags and no
 * filesystem-specific options.
 * Returns 0 on success, -1 (after printing the reason) on failure.
 */
static int
mount_filesystem(const char *device, const char *mountpoint)
{
    const unsigned long no_flags = 0;

    if (mount(device, mountpoint, FS_TYPE, no_flags, NULL) == 0)
        return 0;

    fprintf(stderr, "pit: mount failed: %s\n", strerror(errno));
    return -1;
}

/*
 * Unmount `mountpoint` and remove the then-empty directory.
 * Returns 0 when both steps succeed. Returns -1 after printing the reason if
 * either fails; the directory is not touched when the unmount fails.
 */
static int
unmount_filesystem(const char *mountpoint)
{
    if (umount(mountpoint) != 0) {
        fprintf(stderr, "pit: umount failed: %s\n", strerror(errno));
        return -1;
    }

    if (rmdir(mountpoint) != 0) {
        fprintf(stderr, "pit: rmdir failed: %s\n", strerror(errno));
        return -1;
    }

    return 0;
}

/*
 * Close a container: unmount its filesystem, remove the mount directory and
 * deactivate the dm-crypt mapping.
 *
 * path: container file (must exist and be a regular file)
 *
 * Returns 0 on success. Every failure terminates the process via die().
 */
static int
close_pit(const char *path)
{
    char mount_dir[PATH_MAX];
    struct stat st;
    const char *slash;
    const char *base;

    if (stat(path, &st) < 0)
        die("pit: cannot stat %s: %s\n", path, strerror(errno));

    if (!S_ISREG(st.st_mode))
        die("pit: %s is not a regular file\n", path);

    /* the mountpoint name is derived from the container's base name */
    slash = strrchr(path, '/');
    base = (slash != NULL) ? slash + 1 : path;
    snprintf(mount_dir, sizeof(mount_dir), "%s%s", MOUNTPOINT_PREFIX, base);

    if (unmount_filesystem(mount_dir) < 0)
        die("pit: failed to unmount filesystem\n");

    if (teardown_device_mapper(path) < 0)
        die("pit: failed to teardown device mapper\n");

    printf("pit: successfully closed %s\n", path);
    return 0;
}

/*
 * Collect the mountpoints of all mounted pits from /proc/mounts.
 *
 * An entry is accepted only when its mountpoint field (the second
 * space-separated field) starts with MOUNTPOINT_PREFIX. The device and
 * option fields are not considered, so an unrelated mount that merely
 * mentions the prefix elsewhere on its line is not reported.
 *
 * paths: receives a malloc'ed array of strdup'ed mountpoints (NULL when none
 *        were found); the caller frees each string and the array
 * count: receives the number of entries
 *
 * Returns 0 on success. Returns -1 when /proc/mounts cannot be opened or
 * memory runs out; the outputs are then left untouched.
 */
static int
find_mounted_pits(char ***paths, int *count)
{
    const size_t prefix_len = strlen(MOUNTPOINT_PREFIX);
    FILE *mtab;
    char line[PATH_MAX];
    char **list = NULL;
    int n = 0;

    mtab = fopen("/proc/mounts", "r");
    if (!mtab) {
        fprintf(stderr, "pit: cannot open /proc/mounts\n");
        return -1;
    }

    while (fgets(line, sizeof(line), mtab)) {
        char *mountpoint;
        char *end;
        char *copy;
        char **grown;

        /* field 1 is the device, field 2 the mountpoint */
        mountpoint = strchr(line, ' ');
        if (!mountpoint)
            continue;
        mountpoint++;

        end = strchr(mountpoint, ' ');
        if (!end)
            continue;
        *end = '\0';

        if (strncmp(mountpoint, MOUNTPOINT_PREFIX, prefix_len) != 0)
            continue;

        copy = strdup(mountpoint);
        grown = copy ? realloc(list, (n + 1) * sizeof(*list)) : NULL;
        if (!grown) {
            fprintf(stderr, "pit: out of memory\n");
            free(copy);
            fclose(mtab);
            for (int i = 0; i < n; i++)
                free(list[i]);
            free(list);
            return -1;
        }

        list = grown;
        list[n++] = copy;
    }

    fclose(mtab);
    *paths = list;
    *count = n;
    return 0;
}

/* Return 1 when `path` is `dir` itself or lies beneath it, 0 otherwise. */
static int
path_is_within(const char *path, const char *dir)
{
    size_t dir_len = strlen(dir);

    if (dir_len == 0 || strncmp(path, dir, dir_len) != 0)
        return 0;

    /* "/mnt/pit-a" must not match "/mnt/pit-ab" */
    return path[dir_len] == '\0' || path[dir_len] == '/';
}

/*
 * panic_close forcefully, without any save attempt, closes the mounts by pit.
 *
 * Without root privileges the command re-runs itself through
 * run_privileged(). As root it:
 *   1. SIGKILLs every process whose working directory is a pit mountpoint
 *      or lies beneath one (never this process or its parent, so the panic
 *      cannot abort itself);
 *   2. force/lazy-unmounts each mountpoint and removes its directory;
 *   3. force-deactivates every /dev/mapper entry that starts with
 *      MAPPER_PREFIX.
 *
 * Every step is attempted even when an earlier one fails.
 * Returns 0 when everything succeeded, -1 when any step failed.
 */
static int
panic_close(void)
{
    DIR *dir;
    struct dirent *dp;
    int ret = 0;
    char **mounted = NULL;
    int count = 0;
    const pid_t self = getpid();
    const pid_t parent = getppid();

    if (geteuid() != 0) {
        return run_privileged("%s panic", program_name);
    }

    /* first find our mounted pits */
    if (find_mounted_pits(&mounted, &count) < 0) {
        ret = -1;
    } else if (count > 0) {
        printf("pit: force closing %d containers...\n", count);

        for (int i = 0; i < count; i++) {
            if (!mounted[i]) continue;

            /* check /proc for processes using this mount */
            DIR *proc_dir = opendir("/proc");
            if (proc_dir) {
                struct dirent *pid_dir;
                while ((pid_dir = readdir(proc_dir)) != NULL) {
                    char path[PATH_MAX], link[PATH_MAX];
                    ssize_t len;
                    pid_t pid;

                    /* <ctype.h> functions need an unsigned char value */
                    if (!isdigit((unsigned char)pid_dir->d_name[0]))
                        continue;

                    pid = atoi(pid_dir->d_name);
                    if (pid <= 0 || pid == self || pid == parent)
                        continue;

                    snprintf(path, sizeof(path), "/proc/%s/cwd", pid_dir->d_name);
                    len = readlink(path, link, sizeof(link) - 1);
                    if (len <= 0)
                        continue;
                    link[len] = '\0';

                    /* whole path components only, not a substring anywhere */
                    if (!path_is_within(link, mounted[i]))
                        continue;

                    printf("pit: killing process %d using %s\n", (int)pid, mounted[i]);
                    kill(pid, SIGKILL);
                }
                closedir(proc_dir);
                usleep(100000); /* small delay for processes to die */
            }

            /* Force unmount */
            printf("pit: force unmounting %s\n", mounted[i]);
            if (umount2(mounted[i], MNT_FORCE | MNT_DETACH) < 0) {
                fprintf(stderr, "pit: cannot unmount %s: %s\n",
                        mounted[i], strerror(errno));
                ret = -1;
            }
            rmdir(mounted[i]);
            free(mounted[i]);
        }
        free(mounted);
    }

    /* force close all pit mappers */
    dir = opendir("/dev/mapper");
    if (!dir) {
        fprintf(stderr, "pit: cannot open /dev/mapper: %s\n", strerror(errno));
        return -1;
    }

    while ((dp = readdir(dir)) != NULL) {
        struct crypt_device *cd;

        if (strncmp(dp->d_name, MAPPER_PREFIX, strlen(MAPPER_PREFIX)) != 0)
            continue;

        printf("pit: force closing device %s\n", dp->d_name);
        if (crypt_init_by_name(&cd, dp->d_name) != 0) {
            ret = -1;
            continue;
        }
        if (crypt_deactivate_by_name(cd, dp->d_name, CRYPT_DEACTIVATE_FORCE) < 0)
            ret = -1;
        crypt_free(cd);
    }
    closedir(dir);

    return ret;
}

/*
 * Append `arg` to the command line in `buf` as one single-quoted shell word,
 * preceded by a space unless the buffer is still empty. Embedded single
 * quotes are written as '\'' so the shell reproduces the argument verbatim.
 *
 * buf/cap: destination buffer and its capacity
 * len:     current string length; updated on success
 *
 * Returns 0 on success, -1 when the result (with its terminator) would not
 * fit; the buffer must not be used as a command after a failure.
 */
static int
append_shell_arg(char *buf, size_t cap, size_t *len, const char *arg)
{
    size_t pos = *len;
    const char *p;

    if (pos > 0) {
        if (pos + 1 >= cap)
            return -1;
        buf[pos++] = ' ';
    }

    if (pos + 1 >= cap)
        return -1;
    buf[pos++] = '\'';

    for (p = arg; *p != '\0'; p++) {
        if (*p == '\'') {
            /* close the quote, emit an escaped quote, reopen the quote */
            if (pos + 4 >= cap)
                return -1;
            memcpy(buf + pos, "'\\''", 4);
            pos += 4;
        } else {
            if (pos + 1 >= cap)
                return -1;
            buf[pos++] = *p;
        }
    }

    if (pos + 1 >= cap)
        return -1;
    buf[pos++] = '\'';
    buf[pos] = '\0';

    *len = pos;
    return 0;
}

/*
 * Entry point: parse the command and dispatch to its implementation.
 *
 * "dig", "open", "close" and "panic" need root. When started without it,
 * the whole invocation is re-run through run_privileged(). The command line
 * for that is built with bounds checking and every argument is shell-quoted,
 * so long arguments cannot overflow the buffer and file names containing
 * spaces or shell metacharacters reach the privileged process unchanged.
 *
 * Returns the result of the selected command; usage errors exit through
 * usage().
 */
int
main(int argc, char *argv[])
{
    program_name = argv[0];

    if (argc < 2)
        usage();

    init_sec_mem();

    if (!strcmp(argv[1], "-v")) {
        printf("pit-%s\n", VERSION);
        return 0;
    }
    if (!strcmp(argv[1], "-h"))
        usage();

    if (sodium_init() < 0)
        die("pit: cannot initialize sodium\n");

    if (!strcmp(argv[1], "key")) {
        if (argc != 3)
            usage();
        if (generate_key(argv[2]) < 0)
            die("pit: cannot generate key\n");
        return 0;
    }

    if (!strcmp(argv[1], "dig") ||
        !strcmp(argv[1], "open") ||
        !strcmp(argv[1], "close") ||
        !strcmp(argv[1], "panic")) {

        if (geteuid() != 0) {
            char cmd[4096] = {0};
            /* leave headroom for the doas/sudo prefix added further down */
            const size_t cap = sizeof(cmd) - 256;
            size_t len = 0;
            int i;

            for (i = 0; i < argc; i++) {
                if (append_shell_arg(cmd, cap, &len, argv[i]) < 0)
                    die("pit: command line too long\n");
            }

            return run_privileged("%s", cmd);
        }
    }

    if (!strcmp(argv[1], "dig")) {
        char *end;
        long megabytes;

        if (argc != 4)
            usage();

        /* strtol() instead of atoi(): rejects garbage, overflow and sizes <= 0 */
        errno = 0;
        megabytes = strtol(argv[3], &end, 10);
        if (errno != 0 || end == argv[3] || *end != '\0' || megabytes <= 0)
            die("pit: invalid size '%s' (expected a positive number of MB)\n",
                argv[3]);

        return create_pit(argv[2], (size_t)megabytes);
    }

    if (!strcmp(argv[1], "open")) {
        if (argc != 4)
            usage();
        return open_pit(argv[2], argv[3]);
    }
    if (!strcmp(argv[1], "close")) {
        if (argc != 3)
            usage();
        return close_pit(argv[2]);
    }
    if (!strcmp(argv[1], "list")) {
        if (argc != 2)
            usage();
        return list_pits();
    }
    if (!strcmp(argv[1], "panic")) {
        if (argc != 2)
            usage();
        return panic_close();
    }

    usage();
    return 1;
}