#include "sptree.h"
#include "utility.h"

#include <string.h>


/********************** #Structures ***********************/
struct StructSplayTree;
struct StructSplayTreeNode;

/* Tree header: the object behind the opaque handle that every
 * ctl_sptree_*X function receives through its ppvoid/ppcvoid argument. */
struct StructSplayTree{
    uint                        key_size;   /* bytes copied for every key   */
    uint                        val_size;   /* bytes copied for every value */
    struct StructSplayTreeNode* root;       /* NULL while the tree is empty */
    ctl_spt_cmp                 func;       /* called as func(search_key, node_key);
                                               <0 descends left, >0 right, 0 = match */
};
/* One tree node. It is a single allocation: the key and value bytes are
 * stored after this struct (see getKey/getValue), and key/val point there. */
struct StructSplayTreeNode{
    struct StructSplayTreeNode* lchd;   /* left child: keys that compare smaller */
    struct StructSplayTreeNode* rchd;   /* right child: keys that compare larger */
    struct StructSplayTreeNode* fath;   /* parent node; NULL for the root        */
    pcvoid                      key;    /* points at this node's key bytes       */
    pvoid                       val;    /* points at this node's value bytes     */
};
typedef struct StructSplayTree     SplayTree;
typedef struct StructSplayTreeNode SplayTreeNode;

/******************* #Private functions *******************/
/* Link helpers: set fath's child pointer and the child's parent pointer.
 * Either argument may be NULL; only the non-NULL side is written. */
ctl_priv void connectL  (SplayTreeNode* fath, SplayTreeNode* lchd);
ctl_priv void connectR  (SplayTreeNode* fath, SplayTreeNode* righ);
/* Free every node of the subtree rooted at nd. */
ctl_priv void delete_dfs(SplayTreeNode* nd);

/* acces: walk from nd (non-NULL) towards key k, splay the last visited node
 *        to the root and return it; *t gets the last comparison (0 = found).
 * split: split the tree at k into keys <= k (*l) and keys > k (*r); returns
 *        the node whose key equals k (it is the root of *l), or NULL.
 * merge: join l and r, assuming every key in l is smaller than every key in r.
 * splay: rotate nd up until it is the root, and return it. */
ctl_priv SplayTreeNode* acces(ctl_spt_cmp f,SplayTreeNode*nd,pcvoid k,int*t);
ctl_priv SplayTreeNode* split(ctl_spt_cmp f,SplayTreeNode*nd,pcvoid k,
                              SplayTreeNode** l, SplayTreeNode** r);
ctl_priv SplayTreeNode* merge(SplayTreeNode* l, SplayTreeNode* r);
ctl_priv SplayTreeNode* splay(SplayTreeNode* nd);

/* Single splay steps used by splay(): zig when the parent is the root,
 * otherwise zig-zig / zig-zag depending on which side m and f hang from. */
ctl_priv void splay_root(SplayTreeNode* m, SplayTreeNode* f);
ctl_priv void splay_l   (SplayTreeNode* m, SplayTreeNode* f, SplayTreeNode* g);
ctl_priv void splay_r   (SplayTreeNode* m, SplayTreeNode* f, SplayTreeNode* g);

//ctl_priv void dump_node(SplayTreeNode* now);

/************************* #Macros ************************/
/* Node memory layout (one allocation per node):
 *
 *     [SplayTreeNode][pad][key bytes][pad][value bytes]
 *
 * The key and value regions both start on a multiple of SPT_ALIGN. Callers
 * may therefore dereference the key/value pointers as any of the scalar types
 * listed in SplayTreeMaxAlign. Without the padding, a value after an
 * odd-sized key (e.g. an int key before a double value) would be misaligned.
 * This assumes ctl_malloc returns memory aligned like malloc; TODO confirm.
 * sizeof(union) is a multiple of the union's alignment, so rounding to it
 * also works when it is not a power of two. */
typedef union{
    long double ld;
    long long   ll;
    double      d;
    void*       p;
    void      (*fp)(void);
} SplayTreeMaxAlign;
#define SPT_ALIGN       (sizeof(SplayTreeMaxAlign))
#define sptAlignUp(X)   ((((X) + SPT_ALIGN - 1) / SPT_ALIGN) * SPT_ALIGN)
#define SPT_KEY_OFFSET  (sptAlignUp(sizeof(SplayTreeNode)))

/* Reinterpret an opaque handle as the tree header. */
#define getHeader(A) ((SplayTree*)(A))
/* Bytes needed for one node of tree A: header, padded key, then value. */
#define getSize(A)   (SPT_KEY_OFFSET + sptAlignUp((A)->key_size) + \
                      (A)->val_size)
/* Start of the key storage inside node A. */
#define getKey(A)     pVoid(pChar(A) + SPT_KEY_OFFSET)
/* Start of the value storage inside node A whose key occupies X bytes. */
#define getValue(A,X) pVoid(pChar(A) + SPT_KEY_OFFSET + sptAlignUp(X))
/* Detach A's left subtree: clear the child's parent link and A->lchd. */
#define disconnL(A) \
    do{\
        if((A)->lchd != NULL) (A)->lchd->fath = NULL;\
        (A)->lchd = NULL;\
    }while(0)
/* Detach A's right subtree: clear the child's parent link and A->rchd. */
#define disconnR(A) \
    do{\
        if((A)->rchd != NULL) (A)->rchd->fath = NULL;\
        (A)->rchd = NULL;\
    }while(0)

/**********************************************************/
/*                #constructor / destructor               */
/**********************************************************/
/* Create an empty splay tree that stores k_size-byte keys and v_size-byte
 * values, ordered by the comparator f.
 *   sp     : where to store the new handle; may be NULL. When it is not
 *            NULL, *sp is overwritten whatever it held before.
 *   return : the new handle (the same value that is written to *sp). */
pvoid ctl_sptree_initX(ppvoid sp, uint k_size, uint v_size, ctl_spt_cmp f){
    SplayTree* head = (SplayTree*)ctl_malloc(sizeof(SplayTree));
    head->key_size = k_size;
    head->val_size = v_size;
    head->root = NULL;
    head->func = f;
    /* Test the out-pointer itself, not its current contents. Testing *sp
     * would skip the (usual) NULL-initialised handle and dereference a NULL
     * sp. */
    if(sp != NULL)
        *sp = pVoid(head);
    return pVoid(head);
}
/* Destroy the tree behind *sp. This frees every node and the header, then
 * sets *sp to NULL. It always returns NULL.
 * A NULL sp or NULL *sp is a no-op. That makes the call idempotent: freeX
 * itself, splitX and mergeX leave NULL handles behind, and freeing such a
 * handle again must not crash. */
pvoid ctl_sptree_freeX(ppvoid sp){
    SplayTree* head;
    if(sp == NULL || *sp == NULL)
        return NULL;
    head = getHeader(*sp);
    if(head->root != NULL)
        delete_dfs(head->root);
    ctl_free(pVoid(head));
    *sp = NULL;
    return NULL;
}

/**********************************************************/
/*                   #get/is accessor methods             */
/**********************************************************/
/* Return 1 when the tree behind *sp holds no nodes, otherwise 0. */
int ctl_sptree_isEmptyX(ppcvoid sp){
    const SplayTree* tree = getHeader(*sp);
    return tree->root == NULL;
}
/* Number of bytes copied for each key of the tree behind *sp. */
uint ctl_sptree_getKeySizeX(ppcvoid sp){
    const SplayTree* tree = getHeader(*sp);
    return tree->key_size;
}
/* Number of bytes copied for each value of the tree behind *sp. */
uint ctl_sptree_getValSizeX(ppcvoid sp){
    const SplayTree* tree = getHeader(*sp);
    return tree->val_size;
}

/**********************************************************/
/*           #splay tree's method  --  FIND               */
/**********************************************************/
/* Look up key in the tree behind *sp.
 * Returns a pointer to the stored value bytes, or NULL when the key is
 * absent or the tree is empty. A non-empty tree is splayed even on a miss:
 * the last node visited becomes the new root. */
pvoid ctl_sptree_findX(ppvoid sp, pcvoid key){
    SplayTree* tree = getHeader(*sp);
    int cmp;

    if(tree->root == NULL)
        return NULL;

    tree->root = acces(tree->func, tree->root, key, &cmp);
    return (cmp == 0) ? tree->root->val : NULL;
}

/**********************************************************/
/*        #splay tree's method  --  ADD (key, value)      */
/**********************************************************/
/* Insert (key, val) into the tree behind *sp, or overwrite the value when
 * key is already present (the stored key bytes are then left as they are).
 * key_size bytes are copied from key and val_size bytes from val.
 * Returns a pointer to the stored value bytes inside the tree. */
pvoid ctl_sptree_addX(ppvoid sp, pcvoid key, pvoid val){
    SplayTree*     tree = getHeader(*sp);
    SplayTreeNode* less = NULL;   /* keys <= key after the split */
    SplayTreeNode* more = NULL;   /* keys >  key after the split */
    SplayTreeNode* node = NULL;   /* existing node for key, if any */

    if(tree->root != NULL)
        node = split(tree->func, tree->root, key, &less, &more);

    if(node == NULL){
        /* New key: one allocation holds node header, key and value. */
        node = (SplayTreeNode*)ctl_malloc(getSize(tree));
        node->lchd = NULL;
        node->rchd = NULL;
        node->fath = NULL;
        node->key  = pcVoid(getKey(node));
        node->val  = getValue(node, tree->key_size);
        memcpy(getKey(node), key, tree->key_size);
        memcpy(node->val,    val, tree->val_size);
        /* less < node < more, so they can be glued back in order. */
        tree->root = merge(merge(less, node), more);
    }else{
        /* Existing key: overwrite the value and rejoin the two halves. */
        memcpy(node->val, val, tree->val_size);
        tree->root = merge(less, more);
    }
    return node->val;
}

/**********************************************************/
/*        #splay tree's method  --  Delete by key         */
/**********************************************************/
/* Remove key (and its value) from the tree behind *sp, if present.
 * An empty tree is left untouched. Otherwise the tree is splayed around
 * key; when the key is found, its node is unlinked and freed, and its two
 * subtrees are merged back together. */
void ctl_sptree_delX(ppvoid sp, pcvoid key){
    SplayTree* head = getHeader(*sp);
    SplayTreeNode* left;
    SplayTreeNode* righ;
    int cmp;

    /* acces() dereferences its start node, so an empty tree must stop here. */
    if(head->root == NULL)
        return;

    head->root = acces(head->func, head->root, key, &cmp);
    if(cmp != 0)
        return;                 /* key not present */

    left = head->root->lchd;
    righ = head->root->rchd;
    disconnL(head->root);
    disconnR(head->root);
    ctl_free(pVoid(head->root));
    head->root = merge(left, righ);
}

/**********************************************************/
/*         #splay tree's method  --  Split by key         */
/**********************************************************/
/* Split the tree behind *sp at key.
 * On return, *l holds a new tree with every key that compares <= key, and
 * *r a new tree with every key that compares > key. Both inherit the
 * key/value sizes and the comparator. The nodes are moved, not copied; the
 * source header is freed and *sp is set to NULL.
 * The new handles come from ctl_sptree_initX's return value and are stored
 * into *l / *r explicitly, so the result does not depend on what *l / *r
 * held before. The source header is released directly, and *sp is cleared
 * before *l / *r are written, so passing the same handle as sp and l (or r)
 * is also safe. */
void ctl_sptree_splitX(ppvoid sp , pcvoid key, ppvoid l, ppvoid r){
    SplayTree* head  = getHeader(*sp);
    SplayTree* lhead = getHeader(ctl_sptree_initX(l, head->key_size,
                                                  head->val_size, head->func));
    SplayTree* rhead = getHeader(ctl_sptree_initX(r, head->key_size,
                                                  head->val_size, head->func));

    if(head->root != NULL){
        SplayTreeNode* left = NULL;
        SplayTreeNode* righ = NULL;
        split(head->func, head->root, key, &left, &righ);
        lhead->root = left;
        rhead->root = righ;
        head->root  = NULL;
    }

    /* The source header is now empty: its nodes belong to lhead / rhead. */
    ctl_free(pVoid(head));
    *sp = NULL;
    *l  = pVoid(lhead);
    *r  = pVoid(rhead);
}
/**********************************************************/
/*  #splay tree's method  --  Merge (keys of 1st < keys of 2nd) */
/**********************************************************/
/* Concatenate the tree behind *sp2 onto the tree behind *sp1.
 * Precondition (not checked): every key in *sp1 compares smaller than every
 * key in *sp2. merge() hangs the second tree under the maximum of the first.
 * All nodes end up in the first header, which is returned. The second
 * header is freed, and both *sp1 and *sp2 are set to NULL, so callers must
 * keep the returned handle. */
pvoid ctl_sptree_mergeX(ppvoid sp1, ppvoid sp2){
    SplayTree* low  = getHeader(*sp1);
    SplayTree* high = getHeader(*sp2);
    SplayTreeNode* joined = merge(low->root, high->root);

    low->root  = joined;
    high->root = NULL;          /* nodes now belong to low; free only the header */
    ctl_sptree_freeX(sp2);

    *sp1 = NULL;
    return pVoid(low);
}
    
/**********************************************************/
/*           #splay tree's method  --  clean up           */
/**********************************************************/
/* Free every node of the tree behind *sp. The header is kept, so the tree
 * is empty but still usable afterwards. */
void ctl_sptree_clearX(ppvoid sp){
    SplayTree* tree = getHeader(*sp);
    if(tree->root == NULL)
        return;
    delete_dfs(tree->root);
    tree->root = NULL;
}


/***************** # safe connect two nodes ***************/
/* Make lchd the left child of fath. Either side may be NULL; only the
 * non-NULL side's link is written. */
ctl_priv void connectL(SplayTreeNode* fath, SplayTreeNode* lchd){
    if(lchd != NULL)
        lchd->fath = fath;
    if(fath != NULL)
        fath->lchd = lchd;
}
/* Make righ the right child of fath. Either side may be NULL; only the
 * non-NULL side's link is written. */
ctl_priv void connectR(SplayTreeNode* fath, SplayTreeNode* righ){
    if(righ != NULL)
        righ->fath = fath;
    if(fath != NULL)
        fath->rchd = righ;
}

/************* # for destruct the whole tree **************/
/* Free every node of the subtree rooted at nd (nd may be NULL).
 * Iterative on purpose: a splay tree can degenerate into a path. For
 * example, inserting keys in increasing order through ctl_sptree_addX makes
 * every new node the root, with the previous root as its left child. A
 * recursive walk would then need one stack frame per node and can overflow
 * the stack on large trees.
 * Each step either frees a node that has no left child and moves on to its
 * right subtree, or rotates the left child up (a right rotation), which
 * shortens the left spine. Every node is freed exactly once, in O(n) time
 * and O(1) extra memory. Parent pointers are not maintained because every
 * node is about to be freed. */
ctl_priv void delete_dfs(SplayTreeNode* nd){
    while(nd != NULL){
        SplayTreeNode* l = nd->lchd;
        if(l != NULL){
            nd->lchd = l->rchd;     /* l's right subtree moves under nd */
            l->rchd  = nd;          /* nd becomes l's right child       */
            nd = l;
        }else{
            SplayTreeNode* next = nd->rchd;
            ctl_free(pVoid(nd));    /* key/value share the node's allocation */
            nd = next;
        }
    }
}

/************* # access the node by some key **************/
/* Search from nd (must be non-NULL) for key k using comparator f.
 * The walk stops at the node whose key matches, or at the last node on the
 * search path when the needed child is missing. That node is splayed to the
 * root and returned. *t receives the final f(k, node->key) result:
 * 0 = found, <0 = k is smaller, >0 = k is larger than the returned root. */
ctl_priv SplayTreeNode* acces(ctl_spt_cmp f,SplayTreeNode*nd,pcvoid k,int*t){
    SplayTreeNode* cur = nd;
    for(;;){
        SplayTreeNode* next;
        *t = f(k, cur->key);
        if(*t < 0)
            next = cur->lchd;
        else if(*t > 0)
            next = cur->rchd;
        else
            next = NULL;            /* exact match: stop here */
        if(next == NULL)
            break;
        cur = next;
    }
    return (cur->fath != NULL) ? splay(cur) : cur;
}

/******************* # split & merge **********************/
/* Split the tree rooted at nd around key k.
 * *l receives the keys that compare <= k and *r those that compare > k.
 * Both come back as detached roots (or NULL).
 * Returns the node whose key equals k (it is then the root of *l), or NULL. */
ctl_priv SplayTreeNode* split(ctl_spt_cmp f, SplayTreeNode* nd, pcvoid k,
                              SplayTreeNode** l, SplayTreeNode** r){
    int cmp;
    SplayTreeNode* top = acces(f, nd, k, &cmp);

    if(cmp < 0){
        /* top > k: top and its right side form the upper half */
        *l = top->lchd;
        *r = top;
        disconnL(top);
    }else{
        /* top <= k: top and its left side form the lower half */
        *l = top;
        *r = top->rchd;
        disconnR(top);
    }

    if(cmp == 0)
        return top;
    return NULL;
}
/* Join two detached trees where every key in l is smaller than every key
 * in r. The maximum of l is splayed to the root, and r is hung off its
 * (empty) right side. Returns the new root; if l is NULL, r is returned
 * unchanged. */
ctl_priv SplayTreeNode* merge(SplayTreeNode* l, SplayTreeNode* r){
    SplayTreeNode* top;
    if(l == NULL)
        return r;
    for(top = l; top->rchd != NULL; top = top->rchd)
        ;                           /* walk to the maximum of l */
    if(top->fath != NULL)
        top = splay(top);
    connectR(top, r);
    return top;
}

/************* #'Splay' the node to the root **************/
/* Rotate nd upward until it has no parent, then return it as the new root.
 * When nd's parent is the root, one zig step (splay_root) is done.
 * Otherwise nd is first linked to its great-grandparent in the grandparent's
 * old slot (or becomes parentless), and then a zig-zig / zig-zag step
 * (splay_l / splay_r) rearranges nd, its parent and its grandparent. */
ctl_priv SplayTreeNode* splay(SplayTreeNode* nd){
    SplayTreeNode* f;
    SplayTreeNode* g;
    SplayTreeNode* a;

    while((f = nd->fath) != NULL){
        g = f->fath;
        if(g == NULL){
            splay_root(nd, f);
            nd->fath = NULL;        /* splay_root leaves nd->fath stale */
            continue;
        }

        a = g->fath;
        if(a == NULL){
            nd->fath = NULL;
        }else if(a->lchd == g){
            connectL(a, nd);
        }else{
            connectR(a, nd);
        }

        if(f->lchd == nd)
            splay_l(nd, f, g);
        else
            splay_r(nd, f, g);
    }
    return nd;
}
/********* #Zig step: the parent f is the tree root **********/
/* Rotate m above its parent f: f becomes m's child and m's inner subtree
 * moves across to f. m->fath is not updated here; splay() clears it because
 * m becomes the new root. */
ctl_priv void splay_root(SplayTreeNode* m, SplayTreeNode* f){
    if(f->lchd != m){
        /* m is the right child: m's left subtree becomes f's right subtree */
        connectR(f, m->lchd);
        connectL(m, f);
        return;
    }
    /* m is the left child: m's right subtree becomes f's left subtree */
    connectL(f, m->rchd);
    connectR(m, f);
}

/******* #Double step: m is the left child of its parent f *******/
/* Splay step for m (left child of f), whose grandparent is g.
 *   f is g's left child  -> zig-zig: g ends up under f, f under m.
 *   f is g's right child -> zig-zag: g and f both become m's children.
 * m->fath is not touched: splay() has already linked m to g's old parent. */
ctl_priv void splay_l(SplayTreeNode* m, SplayTreeNode* f, SplayTreeNode* g){
    if(g->lchd != f){
        /* zig-zag */
        connectR(g, m->lchd);
        connectL(f, m->rchd);
        connectL(m, g);
        connectR(m, f);
        return;
    }
    /* zig-zig */
    connectL(g, f->rchd);
    connectL(f, m->rchd);
    connectR(f, g);
    connectR(m, f);
}

/****** #Double step: m is the right child of its parent f ******/
/* Splay step for m (right child of f), whose grandparent is g.
 *   f is g's right child -> zig-zig: g ends up under f, f under m.
 *   f is g's left child  -> zig-zag: f and g both become m's children.
 * m->fath is not touched: splay() has already linked m to g's old parent. */
ctl_priv void splay_r(SplayTreeNode* m, SplayTreeNode* f, SplayTreeNode* g){
    if(g->rchd != f){
        /* zig-zag */
        connectL(g, m->rchd);
        connectR(f, m->lchd);
        connectR(m, g);
        connectL(m, f);
        return;
    }
    /* zig-zig */
    connectR(g, f->lchd);
    connectR(f, m->lchd);
    connectL(f, g);
    connectL(m, f);
}

/********************* # for debug ************************/
/*
#include <stdio.h>
#include <unistd.h>
ctl_priv void dump_node(SplayTreeNode* now){
    if(now == NULL) return ;
    printf("now = (%5.1f, %2d) %lld ", *(double*)now->key, *(int*)now->val, now);
    if(now->lchd == NULL) printf("lchd = (     NULL) ");
    else                  printf("lchd = (%5.1f, %2d) ", *(double*)now->lchd->key, *(int*)now->lchd->val);
    if(now->rchd == NULL) printf("rchd = (     NULL) ");
    else                  printf("rchd = (%5.1f, %2d) ", *(double*)now->rchd->key, *(int*)now->rchd->val);
    if(now->fath == NULL) printf("fath = (     NULL) ");
    else                  printf("fath = (%5.1f, %2d) ", *(double*)now->fath->key, *(int*)now->fath->val);
    printf("\n");
    dump_node(now->lchd);
    dump_node(now->rchd);
    fflush(stdout);
}
// */