/* Hilbert space filling curve computation
 * Copyright 2007, Geoffrey Irving
 * 25aug2005 */

#include <stdlib.h>
#include <stdio.h>
#include <png.h>

/*** types ***/

typedef unsigned int uint;

/* Two unsigned coordinates; used both for curve points and for a curve
 * index split into its odd/even bits. */
typedef struct
{
    uint x;
    uint y;
} pair;

/* Construct a pair from its two components. */
static inline pair make_pair(uint x,uint y)
{
    const pair result={x,y};
    return result;
}

/*** helper functions ***/

static inline uint spread(uint t)
{
    /* Spread the low 16 bits of t apart: bit i moves to bit 2*i and every
     * odd bit of the result is zero. Each step doubles the gap between bit
     * groups (8-, 4-, 2-, then 1-bit groups). */
    uint v=t&0xffffu;
    v=(v|(v<<8))&0x00ff00ffu;
    v=(v|(v<<4))&0x0f0f0f0fu;
    v=(v|(v<<2))&0x33333333u;
    v=(v|(v<<1))&0x55555555u;
    return v;
}

static inline uint interleave(uint x,uint y)
{
    /* Inverse of uninterleave: bit i of x lands at bit 2*i+1 and bit i of y
     * at bit 2*i (low 16 bits of each are used). The two halves occupy
     * disjoint bits, so OR-ing them is the same as adding. */
    const uint odd=spread(x)<<1;
    const uint even=spread(y);
    return odd|even;
}

static inline uint unspread(uint t)
{
    /* Inverse of spread: gather the even bits of t (bit 2*i -> bit i) into
     * the low 16 bits of the result; odd bits of t are ignored. */
    uint v=t&0x55555555u;
    v=(v|(v>>1))&0x33333333u;
    v=(v|(v>>2))&0x0f0f0f0fu;
    v=(v|(v>>4))&0x00ff00ffu;
    v=(v|(v>>8))&0x0000ffffu;
    return v;
}

static inline pair uninterleave(uint t)
{
    /* Split t into its odd bits (x) and even bits (y), each packed into the
     * low 16 bits; the inverse of interleave. */
    pair halves;
    halves.x=unspread(t>>1);
    halves.y=unspread(t);
    return halves;
}

/* Advance a de-interleaved index in place: if x,y hold the odd/even bits of
 * index t (as uninterleave(t) would give), afterwards they hold those of
 * t+1, without re-interleaving.
 *   iu_c  : mask of the trailing levels where (x_i,y_i)==(1,1); incrementing
 *           carries through them and clears them.
 *   iu_nc : the first level that is not (1,1); it steps (0,0)->(0,1) or
 *           (0,1)->(1,0), so x flips there iff y had the bit.
 * x and y must be uint lvalues and are evaluated more than once. The
 * temporaries carry an iu_ prefix so they cannot capture caller variables
 * passed as arguments (e.g. a caller variable named c). */
#define increment_uninterleave(x,y)        \
do{                                        \
    uint iu_c=(x)&(y);                     \
    iu_c^=iu_c+1;iu_c>>=1;                 \
    uint iu_nc=iu_c+1;                     \
    (x)^=iu_c^((y)&iu_nc);                 \
    (y)^=iu_c^iu_nc;                       \
}while(0)

static inline uint prefix_parity(uint t)
{
    /* Suffix XOR scan in doubling steps (shifts 1, 2, 4, 8): for t < 2^16,
     * bit i of the result is the parity of bits i..15 of t. */
    for(uint shift=1;shift<16;shift*=2)
        t^=t>>shift;
    return t;
}

/*** Hilbert curve computation ***/

/* Hilbert curve point for a curve index already split into its odd bits (tx)
 * and even bits (ty), as uninterleave produces; tx and ty are expected to be
 * < 2^16. Returns the (x,y) point with 16-bit coordinates.
 *   p: bit i is the parity of the bits above i where tx and ty are both 1;
 *      it is XOR-ed into both coordinates.
 *   q: bit i is the parity of the bits above i where tx == ty; where it is
 *      set, the x and y bits at position i are exchanged.
 * Declared static like the other helpers: a plain C99 'inline' definition
 * emits no external symbol, so a call the compiler chooses not to inline
 * (e.g. at -O0) would fail to link. */
static inline pair hilbert_uninterleaved(uint tx,uint ty)
{
    uint fx=tx^ty,fy=tx;
    uint p=prefix_parity(tx&ty)>>1;     /* >>1: parity of bits strictly above i */
    uint q=prefix_parity(~fx&0xffff)>>1;
    uint gx=fx^p,gy=fy^p;
    uint hx=(gx&~q)|(gy&q),hy=(gy&~q)|(gx&q);
    return make_pair(hx,hy);
}

static inline pair hilbert(uint t_input)
{
    /* Point (16-bit coordinates) on the Hilbert curve for the 32-bit curve
     * index t_input: split the index into its odd/even bits, then evaluate. */
    const pair split=uninterleave(t_input);
    const uint odd_bits=split.x,even_bits=split.y;
    return hilbert_uninterleaved(odd_bits,even_bits);
}

/* Initialize incremental Hilbert state for curve index t (any uint
 * expression, evaluated once). On exit:
 *   tx,ty  the odd / even bits of t (uninterleave(t)),
 *   p,q    the masks hilbert_uninterleaved computes: p is XOR-ed into both
 *          coordinates, q selects bit positions whose x/y bits are swapped,
 *   hx,hy  the curve point, the same value hilbert(t) returns.
 * hx,hy,p,q,tx,ty must be uint lvalues; they are assigned and some are read
 * back, so pass plain variables. Temporaries carry an sh_ prefix so they
 * cannot capture identically named caller variables. */
#define start_hilbert(hx,hy,p,q,tx,ty,t)      \
do{                                           \
    pair sh_t=uninterleave(t);                \
    (tx)=sh_t.x;(ty)=sh_t.y;                  \
    uint sh_fx=(tx)^(ty),sh_fy=(tx);          \
    (p)=prefix_parity((tx)&(ty))>>1;          \
    (q)=prefix_parity(~sh_fx&0xffff)>>1;      \
    uint sh_gx=sh_fx^(p),sh_gy=sh_fy^(p);     \
    (hx)=(sh_gx&~(q))|(sh_gy&(q));            \
    (hy)=(sh_gy&~(q))|(sh_gx&(q));            \
}while(0)

/* Advance incremental Hilbert state (as set up by start_hilbert) from curve
 * index t to t+1 in constant time. It updates the de-interleaved index tx,ty,
 * the masks p,q, and the point hx,hy, which then equals hilbert(t+1).
 * Only the bits that change are patched:
 *   - ih_cl marks the trailing levels where (tx,ty) was (1,1): they carry and
 *     become (0,0). ih_ch marks the first level that was not (1,1); it steps
 *     (0,0)->(0,1), (0,1)->(1,0) or (1,0)->(1,1).
 *   - tx&ty gains bit ih_ch iff tx had it (ih_xch) and loses the ih_cl run.
 *     p is the parity of tx&ty above each bit, so p flips below ih_ch in the
 *     first case. Clearing the all-ones run flips every other bit of p
 *     below ih_ch, hence the alternating 0x5555/0xaaaa mask.
 *   - tx^ty is unchanged on the ih_cl run and flips at ih_ch iff ty lacked
 *     that bit (!ih_ych). q tracks the parity of ~(tx^ty), so it flips
 *     below ih_ch then.
 * Arguments must be uint lvalues and are evaluated more than once. tx,ty
 * are 16-bit, as in hilbert_uninterleaved, so this is valid while t+1 < 2^32.
 * Temporaries carry an ih_ prefix so they cannot capture caller variables
 * passed as arguments. */
#define increment_hilbert(hx,hy,p,q,tx,ty)              \
do{                                                     \
    uint ih_c=(tx)&(ty);ih_c^=ih_c+1;                   \
    uint ih_cl=ih_c>>1,ih_ch=ih_cl+1;                   \
    uint ih_xch=(tx)&ih_ch,ih_ych=(ty)&ih_ch;           \
    (tx)^=ih_cl^ih_ych;                                 \
    (ty)^=ih_c;                                         \
    if(ih_xch) (p)^=ih_cl;                              \
    (p)^=ih_cl&(ih_ch&0x5555?0x5555:0xaaaa);            \
    if(!ih_ych) (q)^=ih_cl;                             \
    uint ih_gx=((tx)^(ty))^(p),ih_gy=(tx)^(p);          \
    (hx)=(ih_gx&~(q))|(ih_gy&(q));                      \
    (hy)=(ih_gy&~(q))|(ih_gx&(q));                      \
}while(0)

/*** testing functions ***/

/* Check the incremental evaluator against the direct one. Starting from
 * t=0, advance with increment_hilbert and compare with hilbert(t) for every
 * t in [1, 2^32). Progress is printed every 2^24 steps; the function exits(1)
 * on the first mismatch. */
static void test_incremental(void)
{
    uint tx,ty,p,q,hx,hy;
    start_hilbert(hx,hy,p,q,tx,ty,0);
    for(uint t=1;t;t++){    /* t wraps to 0 after 2^32-1, ending the loop */
        pair h=hilbert(t);
        increment_hilbert(hx,hy,p,q,tx,ty);
        if(h.x!=hx || h.y!=hy){printf("f: t %u, tx %u, ty %u, p %u, q 0x%x, hx %u, hy %u, FAIL: h.x %u, h.y %u\n",t,tx,ty,p,q,hx,hy,h.x,h.y);exit(1);}
        if(!(t&0x00ffffff)) printf("%u / 256 done\n",t>>24);}
    printf("full incremental test passed\n");
}

/* Benchmark loop: step increment_hilbert through all 2^32 indices (no timer
 * is read here). The wrapping sum of the points is printed so the computed
 * values are actually used. */
static void test_speed_incremental(void)
{
    uint total=0;
    uint tx,ty,p,q,hx,hy;
    start_hilbert(hx,hy,p,q,tx,ty,0);
    for(uint t=1;t;t++){    /* t wraps to 0 after 2^32-1, ending the loop */
        increment_hilbert(hx,hy,p,q,tx,ty);
        total+=hx+hy;
        if(!(t&0x0fffffff)) printf("%u / 16 done\n",t>>28);}
    /* total is unsigned and wraps; %u keeps it from printing as negative */
    printf("total = %u\n",total);
    printf("incremental speed test complete\n");
}

static void test_continuity(int n)
{
    pair prev=hilbert(((1<<(2*n))-1)<<(32-2*n));
    prev.x>>=16-n;prev.y>>=16-n;
    for(uint t=0;t<(1<<(2*n));t++){
        pair h=hilbert(t<<(32-2*n));
        h.x>>=16-n;h.y>>=16-n;
        //printf("h(%d) = (%d,%d)\n",t,h.x,h.y);
        int dx=h.x-prev.x,dy=h.y-prev.y;
        int d=(dx*dx+dy*dy)&((1<<n)-1);
        if(d!=1){printf("FAIL: d = %d\n",d);exit(1);}
        if(!(t&(t-1))) printf("checked through t = %d\n",t);
        prev=h;}
    printf("continuity test passed with n = %d\n",n);
}

/* Check that the full-resolution curve moves to a neighbouring point at
 * every step t = 1 .. 2^32-1; exits(1) on failure. Differences are taken in
 * long long so a large (failing) jump is reported exactly instead of
 * overflowing int. */
static void test_continuity_full(void)
{
    pair prev=hilbert(0);
    for(uint t=1;t;t++){    /* t wraps to 0 after 2^32-1, ending the loop */
        pair h=hilbert(t);
        long long dx=(long long)h.x-prev.x,dy=(long long)h.y-prev.y;
        long long d=dx*dx+dy*dy;
        if(d!=1){printf("FAIL: d = %lld\n",d);exit(1);}
        if(!(t&0x0fffffff)) printf("%u / 16 done\n",t>>28);
        prev=h;}
    printf("full continuity test passed\n");
}

static void test_bijection(int n)
{
    char *counts=calloc(sizeof(char),1<<(2*n));
    for(uint t=0;t<(1<<(2*n));t++){
        pair h=hilbert(t<<(32-2*n));
        h.x>>=16-n;h.y>>=16-n;
        //printf("h(%d) = (%d,%d)\n",t,h.x,h.y);
        counts[(h.x<<n)+h.y]++;}
    for(uint i=0;i<(1<<(2*n));i++)
        if(counts[i]!=1){printf("bijection error\n");exit(1);}
    printf("bijection test passed with n = %d\n",n); 
    free(counts);
}

/* One RGB pixel with 8 bits per channel. */
typedef struct
{
    unsigned char r;
    unsigned char g;
    unsigned char b;
} color;

/* Fill color for pixels the curve does not touch (white). */
static color background={.r=255,.g=255,.b=255};

/* Color for step t of the level-n curve (assumes 0 <= t < 4^n). The color
 * fades green -> blue over the first half of the curve and blue -> red over
 * the second half. */
static inline color color_map(int n,uint t)
{
    /* s = floor(t*512/4^n), in [0,512). Computed in 64 bits because t*512
     * overflows a 32-bit uint once n >= 12, and 1<<(2*n) overflows at n=16. */
    int s=(int)(((unsigned long long)t<<9)>>(2*n));
    color c;

    // green to blue to red
    if(s<256){
        c.r=0;c.g=255-s;c.b=s;}
    else{
        s-=256;
        c.r=s;c.g=0;c.b=255-s;}

    return c;
}

static void render(int n)
{
    // setup
    int image_size=2*(1<<n)+1;
    color* flat_image=calloc(sizeof(color),image_size*image_size);
    color** image=calloc(sizeof(color*),image_size);
    for(int j=0;j<image_size;j++) image[image_size-1-j]=flat_image+image_size*j;
    for(int j=0;j<image_size*image_size;j++) flat_image[j]=background;

    // actually render
    pair prev=hilbert(((1<<(2*n))-1)<<(32-2*n));
    prev.x>>=16-n;prev.y>>=16-n;
    for(uint t=0;t<(1<<(2*n));t++){
        pair h=hilbert(t<<(32-2*n));
        h.x>>=16-n;h.y>>=16-n;
        image[2*prev.x][2*prev.y]=image[2*h.x][2*h.y]=image[prev.x+h.x][prev.y+h.y]=color_map(n,t);
        prev=h;}

    // write image
    const char *filename="hilbert.png";
    FILE* file=fopen(filename,"wb");
    if(!file){fprintf(stderr,"Failed to open %s for writing\n",filename);exit(1);}

    png_structp png_ptr=png_create_write_struct(PNG_LIBPNG_VER_STRING,0,0,0);
    if(!png_ptr){fprintf(stderr,"Error writing png file %s\n",filename);exit(1);}
    png_infop info_ptr=png_create_info_struct(png_ptr);
    if(!info_ptr){fprintf(stderr,"Error writing png file %s\n",filename);exit(1);}
    if(setjmp(png_jmpbuf(png_ptr))){fprintf(stderr,"Error writing png file %s\n",filename);exit(1);}
    png_init_io(png_ptr,file);
    png_set_IHDR(png_ptr,info_ptr,image_size,image_size,8,PNG_COLOR_TYPE_RGB,PNG_INTERLACE_NONE,PNG_COMPRESSION_TYPE_DEFAULT,PNG_FILTER_TYPE_DEFAULT);

    png_set_rows(png_ptr,info_ptr,(png_byte**)image);
    png_write_png(png_ptr,info_ptr,PNG_TRANSFORM_IDENTITY,0);
    free(image);free(flat_image);

    png_destroy_write_struct(&png_ptr,&info_ptr);fclose(file);
}

/* Entry point: renders the n = 8 curve to hilbert.png. The self-tests are
 * switched off; flip their flags to run them. */
int main(void)
{
    const int want_render=1;
    const int want_bijection=0;
    const int want_continuity=0;
    const int want_continuity_full=0;
    const int want_incremental=0;
    const int want_speed=0;

    if(want_render) render(8);
    if(want_bijection) test_bijection(6);
    if(want_continuity) test_continuity(15);
    if(want_continuity_full) test_continuity_full();
    if(want_incremental) test_incremental();
    if(want_speed) test_speed_incremental();
    return 0;
}
