/*
 * MIT License
 *
 * Copyright (c) 2026
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "avl_heap.h"

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>

// On Debian bookworm, gcc -std=c11 doesn't define MAP_ANONYMOUS.  -std=gnu17 does define it.
// NOTE(review): 0x20 is the Linux value of MAP_ANONYMOUS; other platforms may
// use a different value, so this fallback is only known-good on Linux --
// confirm before porting.
#ifndef MAP_ANONYMOUS
#define MAP_ANONYMOUS 0x20
#endif

/* Bytes made read/write immediately when an arena is created. */
#define INITIAL_COMMIT (128 * 1024)
/* Granularity used to round commit sizes in arena_grow/arena_alloc.
 * NOTE(review): assumes 4 KiB pages; mprotect requires a page-aligned start
 * address, so confirm this on platforms with larger pages. */
#define PAGE_SIZE 4096

/* Values of the 2-bit balance field.  EMPTY is what an all-zero element
 * (as written by element_clear) reports and marks an unused slot; BALANCED
 * is the only value this file writes for an occupied slot. */
#define EMPTY        0
#define BALANCED     3

/* Balance-bit layouts.  avlh_init selects one pair per heap: the balance
 * mask extracts the 2 balance bits from the key word, and the value mask
 * extracts the key bits. */
#define TOP_BALANCE_MASK      0xC000000000000000ULL  /* bits 62-63 */
#define TOP_VALUE_MASK        0x3FFFFFFFFFFFFFFFULL  /* bits 0-61 */
#define BOTTOM_BALANCE_MASK   0x0000000000000003ULL  /* bits 0-1 */
#define BOTTOM_VALUE_MASK     0xFFFFFFFFFFFFFFFCULL  /* bits 2-63 */

/* Implicit tree layout: the children of slot i live at slots 2i+1 and 2i+2. */
#define LEFT_CHILD(i)   (2 * (i) + 1)
#define RIGHT_CHILD(i)  (2 * (i) + 2)

/* A single mmap reservation used as a bump allocator.  Bytes in
 * [base, base + committed) are read/write; the rest of the reservation is
 * mapped PROT_NONE until arena_grow() commits it. */
typedef struct {
    void *base;          /* start of the mmap reservation */
    size_t total_size;   /* bytes reserved (the mmap length) */
    size_t committed;    /* bytes from base made PROT_READ | PROT_WRITE */
    size_t used;         /* bump offset: bytes already handed out by arena_alloc */
} MemoryArena;

/* Array-backed binary search tree.  The children of slot i are stored at
 * slots 2i+1 and 2i+2; a slot whose balance bits read as EMPTY is unused. */
struct AVLHeap {
    uint64_t *heap;               /* capacity * element_words words, allocated from arena */
    size_t capacity;              /* number of element slots in heap */
    size_t size;                  /* number of occupied slots (stored keys) */
    size_t element_words;         /* 64-bit words per element: key word + payload */
    size_t balance_word_idx;      /* Which word contains balance bits */
    BalanceBitLocation balance_location;  /* TOP or BOTTOM bits */
    uint64_t balance_mask;        /* Mask for extracting balance bits */
    uint64_t value_mask;          /* Mask for extracting value */
    int balance_shift;            /* Shift amount for balance bits */
    MemoryArena *arena;           /* private arena that owns heap */
    const char *name;             /* caller-provided label; stored, not copied */
};

/* -----------------------------------
 * arena_create
 *
 * Reserve max_size bytes of address space (PROT_NONE) and commit the first
 * min(INITIAL_COMMIT, max_size) bytes as read/write.  Returns a malloc'd
 * arena descriptor, or NULL if the mapping, the initial mprotect, or the
 * descriptor allocation fails (the mapping is released on every failure).
 *
 * The initial commit is clamped to max_size: committing INITIAL_COMMIT bytes
 * of a smaller reservation asks mprotect to cover unmapped pages, which
 * fails, so reservations below 128 KiB used to be unusable.
 */
static MemoryArena* arena_create(size_t max_size)
{
    size_t initial_commit = (max_size < INITIAL_COMMIT) ? max_size : INITIAL_COMMIT;

    void *base = mmap(NULL, max_size, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (base == MAP_FAILED) return NULL;

    if (mprotect(base, initial_commit, PROT_READ | PROT_WRITE) != 0) {
        munmap(base, max_size);
        return NULL;
    }

    MemoryArena *arena = malloc(sizeof *arena);
    if (!arena) {
        munmap(base, max_size);
        return NULL;
    }
    arena->base = base;
    arena->total_size = max_size;
    arena->committed = initial_commit;
    arena->used = 0;
    return arena;
}

/* -----------------------------------
 * arena_grow
 *
 * Commit (make read/write) the reservation up to at least new_committed
 * bytes, rounded up to PAGE_SIZE.  Returns 0 on success or when that much is
 * already committed; -1 if the target exceeds the reservation or mprotect
 * fails (committed is left unchanged on failure).
 */
static int arena_grow(MemoryArena *arena, size_t new_committed)
{
    /* Nothing to do: the requested prefix is already read/write. */
    if (new_committed <= arena->committed)
        return 0;

    /* Cannot commit past the end of the reservation. */
    if (new_committed > arena->total_size)
        return -1;

    size_t target = (new_committed + PAGE_SIZE - 1) & ~(PAGE_SIZE - 1);
    char *first_new_byte = (char*)arena->base + arena->committed;

    if (mprotect(first_new_byte, target - arena->committed, PROT_READ | PROT_WRITE) != 0)
        return -1;

    arena->committed = target;
    return 0;
}

/* -----------------------------------
 * arena_alloc
 *
 * Bump-allocate `size` bytes (rounded up to a multiple of 8) from the arena,
 * committing more of the reservation when needed, and zero the block.
 * Returns NULL when the request does not fit in the reservation or the
 * commit fails.  Blocks are never freed individually; they live until
 * arena_destroy().
 */
static void* arena_alloc(MemoryArena *arena, size_t size)
{
    /* This function never lets used exceed total_size, so this cannot
     * underflow.  Checking before rounding also keeps size + 7 from
     * wrapping around. */
    size_t remaining = arena->total_size - arena->used;
    if (size > remaining) return NULL;

    size = (size + 7) & ~(size_t)7;
    if (size > remaining) return NULL;

    size_t new_used = arena->used + size;

    if (new_used > arena->committed) {
        /* Grow geometrically, but at least enough for this request... */
        size_t new_committed = arena->committed * 2;
        if (new_committed < new_used)
            new_committed = (new_used + PAGE_SIZE - 1) & ~(size_t)(PAGE_SIZE - 1);
        /* ...and never past the reservation: arena_grow rejects targets
         * beyond total_size, so an unclamped doubling would fail a request
         * that still fits. */
        if (new_committed > arena->total_size)
            new_committed = arena->total_size;
        if (arena_grow(arena, new_committed) != 0)
            return NULL;
    }

    void *ptr = (char*)arena->base + arena->used;
    arena->used = new_used;
    memset(ptr, 0, size);
    return ptr;
}

/* -----------------------------------
 * arena_destroy
 *
 * Release the entire reservation (committed or not) and then the descriptor.
 * A NULL arena is ignored.
 */
static void arena_destroy(MemoryArena *arena)
{
    if (arena == NULL)
        return;

    munmap(arena->base, arena->total_size);
    free(arena);
}

/* -----------------------------------
 * element_ptr
 *
 * Address of the first word of slot idx (no bounds check).
 */
static inline uint64_t* element_ptr(AVLHeap *avlh, size_t idx)
{
    size_t first_word = idx * avlh->element_words;
    return &avlh->heap[first_word];
}

/* -----------------------------------
 * balance_word_ptr
 *
 * Address of the word in slot idx that holds the key and the balance bits
 * (no bounds check).
 */
static inline uint64_t* balance_word_ptr(AVLHeap *avlh, size_t idx)
{
    uint64_t *elem = element_ptr(avlh, idx);
    return &elem[avlh->balance_word_idx];
}

/* -----------------------------------
 * element_get_balance
 *
 * The 2-bit balance field of slot idx.  Slots at or beyond capacity report
 * EMPTY, so callers may probe child indices without a separate bounds check.
 */
static inline uint8_t element_get_balance(AVLHeap *avlh, size_t idx)
{
    if (idx < avlh->capacity) {
        uint64_t bits = *balance_word_ptr(avlh, idx) & avlh->balance_mask;
        return (uint8_t)(bits >> avlh->balance_shift);
    }
    return EMPTY;
}

/* -----------------------------------
 * element_get_key
 *
 * Key stored in slot idx: the balance word with the balance bits masked off
 * (no bounds check).
 */
static inline uint64_t element_get_key(AVLHeap *avlh, size_t idx)
{
    const uint64_t word = *balance_word_ptr(avlh, idx);
    return word & avlh->value_mask;
}

/* -----------------------------------
 * element_set
 *
 * Write key + balance into the balance word of slot idx and, if data is
 * non-NULL and elements have payload words, copy element_words - 1 payload
 * words from data into the remaining words in order, stepping over the
 * balance word:
 *   data[0 .. b-1]          -> elem[0 .. b-1]
 *   data[b .. words-2]      -> elem[b+1 .. words-1]      (b = balance_word_idx)
 * When data is NULL the payload words are left untouched.
 */
static inline void element_set(AVLHeap *avlh, size_t idx, uint64_t key, uint8_t balance, const uint64_t *data)
{
    uint64_t *elem = element_ptr(avlh, idx);
    const size_t bw = avlh->balance_word_idx;
    const size_t nwords = avlh->element_words;

    uint64_t packed_balance = ((uint64_t)balance << avlh->balance_shift) & avlh->balance_mask;
    elem[bw] = (key & avlh->value_mask) | packed_balance;

    if (data == NULL || nwords <= 1)
        return;

    size_t words_before = bw;
    size_t words_after = nwords - 1 - bw;

    if (words_before > 0)
        memcpy(elem, data, words_before * sizeof(uint64_t));
    if (words_after > 0)
        memcpy(elem + bw + 1, data + bw, words_after * sizeof(uint64_t));
}

/* -----------------------------------
 * element_clear
 *
 * Zero every word of slot idx; the slot then reads as EMPTY.
 */
static inline void element_clear(AVLHeap *avlh, size_t idx)
{
    uint64_t *slot = element_ptr(avlh, idx);
    memset(slot, 0, sizeof *slot * avlh->element_words);
}

/* -----------------------------------
 * element_get_data
 *
 * Pointer to the payload of slot idx, or NULL for single-word elements
 * (which hold only the key).  If the balance word is first, this is word 1;
 * otherwise it is word 0.
 *
 * NOTE(review): element_set stores the payload around the balance word, so
 * the payload is contiguous from this pointer only when balance_word_idx is
 * 0 or element_words - 1.  With the balance word in the middle, offset
 * balance_word_idx of the returned range is the key/balance word, not
 * payload.
 */
static inline uint64_t* element_get_data(AVLHeap *avlh, size_t idx)
{
    if (avlh->element_words < 2)
        return NULL;

    uint64_t *elem = element_ptr(avlh, idx);
    return (avlh->balance_word_idx == 0) ? elem + 1 : elem;
}

/* -----------------------------------
 * avlh_init
 *
 * Create a heap whose elements are element_words 64-bit words.  Word
 * balance_word_idx of each element holds the key plus two balance bits,
 * placed in bits 62-63 for BALANCE_TOP_BITS or bits 0-1 for any other
 * balance_location value; the other words hold the payload.
 *
 * Storage comes from a private arena reserving max_memory bytes (GiB when
 * 0).  initial_capacity is the number of element slots allocated up front;
 * 0 is treated as 1, because the capacity grows by doubling and would never
 * leave 0.
 *
 * name is stored as-is (not copied) and is printed later by the print/stats
 * functions, so it must stay valid for the heap's lifetime.
 *
 * Returns NULL on invalid arguments, size overflow, or allocation failure.
 */
AVLHeap* avlh_init(const char *name, size_t initial_capacity, size_t element_words,
                   size_t max_memory, size_t balance_word_idx, BalanceBitLocation balance_location)
{
    if (element_words < 1) return NULL;
    if (balance_word_idx >= element_words) return NULL;

    if (initial_capacity == 0) initial_capacity = 1;
    /* Reject capacities whose byte size would wrap around size_t. */
    if (initial_capacity > SIZE_MAX / sizeof(uint64_t) / element_words) return NULL;
    size_t initial_bytes = initial_capacity * element_words * sizeof(uint64_t);

    if (max_memory == 0) max_memory = GiB;

    AVLHeap *avlh = malloc(sizeof *avlh);
    if (!avlh) return NULL;

    avlh->arena = arena_create(max_memory);
    if (!avlh->arena) {
        free(avlh);
        return NULL;
    }

    avlh->element_words = element_words;
    avlh->balance_word_idx = balance_word_idx;
    avlh->balance_location = balance_location;

    /* Configure balance bit masks and shifts based on location */
    if (balance_location == BALANCE_TOP_BITS) {
        avlh->balance_mask = TOP_BALANCE_MASK;
        avlh->value_mask = TOP_VALUE_MASK;
        avlh->balance_shift = 62;
    } else {  /* BALANCE_BOTTOM_BITS */
        avlh->balance_mask = BOTTOM_BALANCE_MASK;
        avlh->value_mask = BOTTOM_VALUE_MASK;
        avlh->balance_shift = 0;
    }

    avlh->heap = arena_alloc(avlh->arena, initial_bytes);
    if (!avlh->heap) {
        arena_destroy(avlh->arena);
        free(avlh);
        return NULL;
    }

    avlh->capacity = initial_capacity;
    avlh->size = 0;
    avlh->name = name;

    return avlh;
}

/* -----------------------------------
 * avlh_free
 *
 * Destroy the heap: unmap its arena (which owns the element array) and free
 * the handle.  A NULL heap is ignored.  The name string is not freed.
 */
void avlh_free(AVLHeap *avlh)
{
    if (avlh == NULL)
        return;

    arena_destroy(avlh->arena);
    free(avlh);
}

/* -----------------------------------
 * avlh_ensure_capacity
 *
 * Make sure slot idx exists, doubling the capacity (starting from 1 when it
 * is 0) until idx < capacity.  New slots are zero-filled, i.e. EMPTY.
 *
 * The element array is normally the most recent allocation in this heap's
 * private arena, so it is extended in place.  Only when something else was
 * allocated after it is a new array allocated and the old contents copied.
 * The old array's arena space cannot be reclaimed, so in-place growth roughly
 * halves the arena footprint of a growing heap.
 *
 * Failure cannot be reported to the caller (avlh_add ignores any result), so
 * running out of arena space, or a size overflow, prints a diagnostic to
 * stderr and terminates the process.
 */
static void avlh_ensure_capacity(AVLHeap *avlh, size_t idx)
{
    if (idx < avlh->capacity) return;

    MemoryArena *arena = avlh->arena;
    const size_t elem_bytes = avlh->element_words * sizeof(uint64_t);

    /* Start from 1 so a zero capacity doubles instead of looping forever. */
    size_t new_capacity = avlh->capacity ? avlh->capacity : 1;
    while (idx >= new_capacity) {
        if (new_capacity > SIZE_MAX / 2 / elem_bytes)
            goto out_of_memory;  /* 2 * new_capacity * elem_bytes would overflow */
        new_capacity *= 2;
    }

    size_t old_bytes = avlh->capacity * elem_bytes;
    size_t new_bytes = new_capacity * elem_bytes;

    if ((char *)avlh->heap + old_bytes == (char *)arena->base + arena->used) {
        /* The array ends exactly at the bump pointer.  arena_alloc returns the
         * (zeroed) block starting there, and the increment is a multiple of 8,
         * so the array simply becomes longer. */
        if (!arena_alloc(arena, new_bytes - old_bytes))
            goto out_of_memory;
    } else {
        uint64_t *new_heap = arena_alloc(arena, new_bytes);
        if (!new_heap)
            goto out_of_memory;
        memcpy(new_heap, avlh->heap, old_bytes);
        avlh->heap = new_heap;
    }
    avlh->capacity = new_capacity;
    return;

out_of_memory:
    fprintf(stderr, "avlh '%s': cannot grow storage to index %zu (out of arena memory)\n",
            avlh->name ? avlh->name : "(unnamed)", idx);
    exit(1);
}

/* -----------------------------------
 * avlh_add
 *
 * Insert key, with optional payload data, by plain binary-search-tree descent
 * over the implicit array layout (children of slot i at 2i+1 and 2i+2).
 * No rebalancing is done; inserted slots are marked BALANCED, which is also
 * what distinguishes them from EMPTY slots.
 *
 * Returns 1 if inserted.  Returns 0 if avlh is NULL, the key is already
 * present, or the key overlaps the balance bits.  For BALANCE_TOP_BITS the
 * key must be < 2^62; for BALANCE_BOTTOM_BITS its low two bits must be 0.
 *
 * data, when non-NULL, is read as element_words - 1 payload words.
 * Storage grows as needed (see avlh_ensure_capacity, which exits on OOM).
 */
int avlh_add(AVLHeap *avlh, uint64_t key, const uint64_t *data)
{
    if (!avlh) return 0;

    /* Any key bit outside value_mask would be clobbered by the balance bits.
     * value_mask is set per heap in avlh_init, so one test covers both the
     * "< 2^62" (top bits) and "4-byte aligned" (bottom bits) rules. */
    if (key & ~avlh->value_mask) return 0;

    size_t idx = 0;
    for (;;) {
        avlh_ensure_capacity(avlh, idx);

        if (element_get_balance(avlh, idx) == EMPTY) {
            element_set(avlh, idx, key, BALANCED, data);
            avlh->size++;
            return 1;
        }

        uint64_t current_key = element_get_key(avlh, idx);
        if (key == current_key) return 0;   /* duplicates are rejected */

        idx = (key < current_key) ? LEFT_CHILD(idx) : RIGHT_CHILD(idx);
    }
}

/* -----------------------------------
 * avlh_find
 *
 * Return 1 if key is stored in the heap, 0 otherwise (including a NULL or
 * empty heap).  Walks down from the root until it reaches an EMPTY or
 * out-of-range slot.
 */
int avlh_find(AVLHeap *avlh, uint64_t key)
{
    if (avlh == NULL || avlh->size == 0)
        return 0;

    size_t node = 0;
    while (node < avlh->capacity && element_get_balance(avlh, node) != EMPTY) {
        const uint64_t node_key = element_get_key(avlh, node);
        if (key < node_key)
            node = LEFT_CHILD(node);
        else if (key > node_key)
            node = RIGHT_CHILD(node);
        else
            return 1;
    }
    return 0;
}

/* -----------------------------------
 * avlh_find_element
 *
 * Look up key and return element_get_data() for its slot, which points into
 * the heap's storage.  Returns NULL if the key is absent or the heap is
 * NULL/empty.  It also returns NULL for single-word elements, so NULL alone
 * does not mean "not found" there; use avlh_find for membership.
 *
 * NOTE(review): the returned pointer refers to the current element array.
 * That array can be replaced when the heap grows, and slots are rewritten
 * by deletes, so presumably it should not be held across avlh_add or
 * avlh_delete.
 */
uint64_t* avlh_find_element(AVLHeap *avlh, uint64_t key)
{
    if (avlh == NULL || avlh->size == 0)
        return NULL;

    size_t node = 0;
    while (node < avlh->capacity && element_get_balance(avlh, node) != EMPTY) {
        const uint64_t node_key = element_get_key(avlh, node);
        if (key == node_key)
            return element_get_data(avlh, node);
        node = (key > node_key) ? RIGHT_CHILD(node) : LEFT_CHILD(node);
    }
    return NULL;
}

/* -----------------------------------
 * find_min
 *
 * Slot of the smallest key in the subtree rooted at idx: follow left
 * children while they exist.  idx itself is assumed to be occupied.
 */
static size_t find_min(AVLHeap *avlh, size_t idx)
{
    size_t node = idx;
    for (size_t next = LEFT_CHILD(node);
         next < avlh->capacity && element_get_balance(avlh, next) != EMPTY;
         next = LEFT_CHILD(node)) {
        node = next;
    }
    return node;
}

/* -----------------------------------
 * relocate_subtree
 *
 * Move the subtree rooted at slot src so that it becomes rooted at slot dst.
 * dst must be the parent of src, and dst's other child subtree must be
 * empty.  The move goes one tree level at a time.  Level k of a subtree
 * rooted at r is the contiguous slot range starting at (r+1)*2^k - 1 with
 * 2^k slots, and level k of src is copied onto level k of dst.  Slots with
 * no source element are cleared.
 *
 * dst level k+1 contains src level k, and src level k was already read
 * during pass k, so no source is overwritten before it is copied.  The walk
 * stops after the first level with no occupied source slot (that pass clears
 * the stale copies left at the previous level), or at the end of the array.
 */
static void relocate_subtree(AVLHeap *avlh, size_t src, size_t dst)
{
    const size_t elem_bytes = avlh->element_words * sizeof(uint64_t);
    size_t src_first = src;
    size_t dst_first = dst;
    size_t width = 1;

    while (dst_first < avlh->capacity) {
        int level_occupied = 0;

        for (size_t i = 0; i < width && dst_first + i < avlh->capacity; i++) {
            size_t s = src_first + i;
            size_t d = dst_first + i;
            /* element_get_balance reports s >= capacity as EMPTY. */
            if (element_get_balance(avlh, s) != EMPTY) {
                /* s and d are on different tree levels: never the same slot. */
                memcpy(element_ptr(avlh, d), element_ptr(avlh, s), elem_bytes);
                level_occupied = 1;
            } else {
                element_clear(avlh, d);
            }
        }
        if (!level_occupied)
            break;

        src_first = LEFT_CHILD(src_first);
        dst_first = LEFT_CHILD(dst_first);
        width *= 2;
    }
}

/* -----------------------------------
 * avlh_delete
 *
 * Remove key using standard BST deletion on the implicit array layout.
 * Returns 1 if the key was found and removed, 0 otherwise (including a NULL
 * or empty heap).  No rebalancing is performed.
 *
 * In an array-backed tree a node's descendants are addressed relative to its
 * slot.  Replacing a node with a child therefore requires moving the child's
 * whole subtree up one level.  Copying only the child element would strand
 * the grandchildren below a now-EMPTY slot.  They would become unreachable,
 * size would no longer match the stored elements, and a later insert into
 * that slot could re-adopt them in violation of BST order.
 */
int avlh_delete(AVLHeap *avlh, uint64_t key)
{
    if (!avlh || avlh->size == 0) return 0;

    const size_t elem_bytes = avlh->element_words * sizeof(uint64_t);
    size_t idx = 0;

    while (idx < avlh->capacity && element_get_balance(avlh, idx) != EMPTY) {
        uint64_t current_key = element_get_key(avlh, idx);

        if (key != current_key) {
            idx = (key < current_key) ? LEFT_CHILD(idx) : RIGHT_CHILD(idx);
            continue;
        }

        size_t left = LEFT_CHILD(idx);
        size_t right = RIGHT_CHILD(idx);
        /* element_get_balance reports out-of-range slots as EMPTY. */
        int has_left = element_get_balance(avlh, left) != EMPTY;
        int has_right = element_get_balance(avlh, right) != EMPTY;

        if (!has_left && !has_right) {
            element_clear(avlh, idx);
        } else if (!has_left) {
            relocate_subtree(avlh, right, idx);
        } else if (!has_right) {
            relocate_subtree(avlh, left, idx);
        } else {
            /* Two children: copy in the in-order successor (leftmost node of
             * the right subtree, which has no left child).  Then splice the
             * successor out by lifting its right subtree into its slot. */
            size_t succ = find_min(avlh, right);
            memcpy(element_ptr(avlh, idx), element_ptr(avlh, succ), elem_bytes);
            relocate_subtree(avlh, RIGHT_CHILD(succ), succ);
        }
        avlh->size--;
        return 1;
    }
    return 0;
}

/* -----------------------------------
 * print_inorder_helper
 *
 * Recursively print the subtree rooted at slot idx in ascending key order,
 * numbering lines with *count (incremented once per printed element).
 *
 * For multi-word elements, every word except the balance/key word is printed,
 * in storage order.  That is the order the payload was passed to
 * element_set().  Printing element_get_data()[0 .. element_words-2] is only
 * right when the balance word is first or last.  With the balance word in the
 * middle, that range would print the key word and drop the last payload word.
 */
static void print_inorder_helper(AVLHeap *avlh, size_t idx, size_t *count)
{
    if (idx >= avlh->capacity || element_get_balance(avlh, idx) == EMPTY)
        return;

    print_inorder_helper(avlh, LEFT_CHILD(idx), count);

    uint64_t key = element_get_key(avlh, idx);
    printf("  [%zu] Key=%llu", *count, (unsigned long long)key);

    if (avlh->element_words > 1) {
        const uint64_t *elem = element_ptr(avlh, idx);
        const char *sep = "";
        printf(", Data=(");
        for (size_t w = 0; w < avlh->element_words; w++) {
            if (w == avlh->balance_word_idx)
                continue;  /* key + balance bits, already printed as Key */
            printf("%s%llu", sep, (unsigned long long)elem[w]);
            sep = ", ";
        }
        printf(")");
    }
    printf("\n");
    (*count)++;

    print_inorder_helper(avlh, RIGHT_CHILD(idx), count);
}

/* -----------------------------------
 * avlh_print_inorder
 *
 * Print a header line followed by every element in ascending key order.
 * A NULL heap prints nothing.
 */
void avlh_print_inorder(AVLHeap *avlh)
{
    size_t printed = 0;

    if (avlh == NULL)
        return;

    printf("In-order traversal of '%s':\n", avlh->name);
    print_inorder_helper(avlh, 0, &printed);
}

/* -----------------------------------
 * print_preorder_helper
 *
 * Recursively print the keys of the subtree rooted at slot idx: node first,
 * then its left subtree, then its right subtree.  *count numbers the lines.
 */
static void print_preorder_helper(AVLHeap *avlh, size_t idx, size_t *count)
{
    if (idx < avlh->capacity && element_get_balance(avlh, idx) != EMPTY) {
        size_t ordinal = (*count)++;
        printf("  [%zu] Key=%llu\n", ordinal,
               (unsigned long long)element_get_key(avlh, idx));

        print_preorder_helper(avlh, LEFT_CHILD(idx), count);
        print_preorder_helper(avlh, RIGHT_CHILD(idx), count);
    }
}

/* -----------------------------------
 * avlh_print_preorder
 *
 * Print a header line followed by every key in pre-order (node, left, right).
 * A NULL heap prints nothing.
 */
void avlh_print_preorder(AVLHeap *avlh)
{
    size_t printed = 0;

    if (avlh == NULL)
        return;

    printf("Pre-order traversal of '%s':\n", avlh->name);
    print_preorder_helper(avlh, 0, &printed);
}

/* -----------------------------------
 * print_postorder_helper
 *
 * Recursively print the keys of the subtree rooted at slot idx: left
 * subtree, right subtree, then the node itself.  *count numbers the lines.
 */
static void print_postorder_helper(AVLHeap *avlh, size_t idx, size_t *count)
{
    if (idx < avlh->capacity && element_get_balance(avlh, idx) != EMPTY) {
        print_postorder_helper(avlh, LEFT_CHILD(idx), count);
        print_postorder_helper(avlh, RIGHT_CHILD(idx), count);

        size_t ordinal = (*count)++;
        printf("  [%zu] Key=%llu\n", ordinal,
               (unsigned long long)element_get_key(avlh, idx));
    }
}

/* -----------------------------------
 * avlh_print_postorder
 *
 * Print a header line followed by every key in post-order (left, right,
 * node).  A NULL heap prints nothing.
 */
void avlh_print_postorder(AVLHeap *avlh)
{
    size_t printed = 0;

    if (avlh == NULL)
        return;

    printf("Post-order traversal of '%s':\n", avlh->name);
    print_postorder_helper(avlh, 0, &printed);
}

/* -----------------------------------
 * avlh_stats
 *
 * Print arena and element-array statistics for the heap to stdout: base
 * address, element size, reserved/committed/used arena bytes, element
 * count, slot capacity and the bytes spanned by the current element array.
 * A NULL heap prints nothing.
 */
void avlh_stats(AVLHeap *avlh)
{
    if (avlh == NULL)
        return;

    const MemoryArena *arena = avlh->arena;
    const size_t elem_bytes = avlh->element_words * 8;
    const double kib = 1024.0;
    const double mib = 1024.0 * 1024.0;

    printf("\nHeap '%s' (%zu-word elements):\n", avlh->name, avlh->element_words);
    printf("  Base address:  %p\n", arena->base);
    printf("  Element size:  %zu bytes\n", elem_bytes);
    printf("  Reserved:      %.2f MiB\n", arena->total_size / mib);
    printf("  Committed:     %.2f KiB\n", arena->committed / kib);
    printf("  Used:          %.2f KiB\n", arena->used / kib);
    printf("  Elements:      %zu\n", avlh->size);
    printf("  Capacity:      %zu\n", avlh->capacity);
    printf("  Storage:       %.2f KiB\n", (avlh->capacity * elem_bytes) / kib);
}

/* -----------------------------------
 * avlh_get_size
 *
 * Number of keys currently stored; 0 for a NULL heap.
 */
size_t avlh_get_size(AVLHeap *avlh)
{
    if (avlh == NULL)
        return 0;
    return avlh->size;
}
