(* Raised when compressed input cannot be decoded; the payload is a short
   human-readable description of the problem. uncompress re-raises every
   decoding failure as this exception, and gunzip also raises it for input
   that does not start with a gzip/deflate header. *)
exception DecompressionError of string

(* Add one byte to the output buffer.
   [dst] is the mutable output buffer and [offset] the next free position in
   it. [c] is written at [offset] in place, and the pair (dst, offset + 1) is
   returned.
   Bytes.set replaces the [s.[i] <- c] string-mutation syntax, which was
   deprecated in OCaml 4.02 and removed in 5.0. On unsafe-string compilers
   bytes and string are the same type, so callers passing a string still work.
   Raises Invalid_argument if [offset] is outside the buffer. *)
let add_byte c (dst, offset) =
  Bytes.set dst offset c;
  (dst, offset + 1)

(* Copy the whole of [s] into the output buffer [dst] at [offset], in place.
   Returns the buffer with the offset advanced past the copied bytes.
   Raises Invalid_argument if [s] does not fit in [dst] at [offset]. *)
let add_bytes s (dst, offset) =
  let next = offset + String.length s in
  String.blit s 0 dst offset (next - offset);
  (dst, next)

(* Get the [n] bytes of a back-reference that starts [location] bytes back
   from [offset] in the output buffer [dst] (rfc 1951 3.2.3).
   When [n] exceeds [location], the copy overlaps the bytes it produces, so the
   last [location] bytes repeat cyclically. For example, distance 1 with
   length 5 repeats the last byte five times.
   The repeated case is built in a single pass with String.init rather than by
   repeated concatenation. Concatenation would be quadratic for the common
   distance-1, length-258 run.
   Raises Invalid_argument if [location] is not within 1..[offset], or if [n]
   is negative. *)
let get_bytes location n (dst, offset) =
  if location <= 0 || location > offset then
    invalid_arg "get_bytes: distance out of range";
  let start = offset - location in
  if n <= location then Bytes.sub_string dst start n
  else String.init n (fun i -> Bytes.get dst (start + i mod location))

(* Read one bit from the stream position (s, byte, bit), where [bit] is the
   bit index inside byte [byte] of [s] (0 = least significant).
   Bits are consumed from the least significant upwards. After bit 7 the
   position moves to bit 0 of the next byte.
   Returns the bit (0 or 1) and the advanced position. *)
let get_next_bit =
  (* Bit weights, allocated once rather than on every call. *)
  let weights = [| 1; 2; 4; 8; 16; 32; 64; 128 |] in
  fun (s, byte, bit) ->
    let value = if Char.code s.[byte] land weights.(bit) <> 0 then 1 else 0 in
    let position = if bit < 7 then (s, byte, bit + 1) else (s, byte + 1, 0) in
    (value, position)

(* Align a stream position to a byte boundary.
   A position already at bit 0 is returned unchanged. Otherwise the unread
   bits of the current byte are skipped and bit 0 of the next byte is
   returned. *)
let skip_remaining_bits = function
  | (_, _, 0) as aligned -> aligned
  | (s, byte, _) -> (s, byte + 1, 0)

(* Read [n] bits from the stream and assemble them into an integer, with the
   first bit read as the least significant. This is how rfc 1951 3.1.1 packs
   data elements other than Huffman codes.
   Only meaningful for n < 31 on 32-bit systems or n < 63 on 64-bit systems.
   Returns the value and the advanced stream position. *)
let get_next_bits n src =
  let rec go taken value weight stream =
    if taken = n then (value, stream)
    else begin
      let bit, stream = get_next_bit stream in
      go (taken + 1) (value + bit * weight) (weight lsl 1) stream
    end
  in
  go 0 0 1 src


(* Read [n] whole bytes as a string. If the current byte is partially
   consumed, its unread bits are skipped first.
   Returns the bytes and a position at bit 0 just after them.
   Raises Invalid_argument if fewer than [n] bytes remain. *)
let get_next_bytes n (s, byte, bit) =
  let start = match bit with 0 -> byte | _ -> byte + 1 in
  (String.sub s start n, (s, start + n, 0))

(* Read [n] 3-bit code lengths for the 19-symbol code-length alphabet
   (rfc 1951: 3.2.7).
   The lengths arrive in the fixed order 16, 17, 18, 0, 8, 7, ..., so each is
   stored at the index of its own symbol. Symbols that are not read keep
   length 0.
   Returns the 19-entry length array and the advanced stream. *)
let get_code_lengths n src =
  let order = [| 16; 17; 18; 0; 8; 7; 9; 6; 10; 5; 11; 4; 12; 3; 13; 2; 14; 1; 15 |] in
  let lengths = Array.make (Array.length order) 0 in
  let rec read i stream =
    if i = n then stream
    else begin
      let len, stream = get_next_bits 3 stream in
      lengths.(order.(i)) <- len;
      read (i + 1) stream
    end
  in
  let rest = read 0 src in
  (lengths, rest)

(* Construct canonical Huffman codes from per-symbol code lengths
   (rfc 1951: 3.2.2).
   Returns an array parallel to [code_length] holding (length, code) pairs.
   Symbols of length 0 are unused and get code 0. Codes of the same length
   are assigned consecutively in increasing symbol order. *)
let build_code_table code_length =
  let longest = Array.fold_left max 0 code_length in
  (* bl_count.(b) = number of symbols whose code is b bits long (b >= 1). *)
  let bl_count = Array.make (longest + 1) 0 in
  Array.iter
    (fun len -> if len > 0 then bl_count.(len) <- bl_count.(len) + 1)
    code_length;
  (* next_code.(b) = smallest code of length b, per the rfc's step 2. *)
  let next_code = Array.make (longest + 1) 0 in
  let code = ref 0 in
  for bits = 1 to longest do
    code := (!code + bl_count.(bits - 1)) lsl 1;
    next_code.(bits) <- !code
  done;
  Array.map
    (fun len ->
      let assigned = next_code.(len) in
      if len > 0 then next_code.(len) <- assigned + 1;
      (len, assigned))
    code_length
        
(* Decode the next Huffman-coded symbol from the stream (rfc 1951: 3.1.1, 3.2.2).
   [code_table] holds one (length, code) pair per symbol, as built by
   build_code_table. The returned symbol is the index of the matching pair.
   Huffman codes are packed starting with their most significant bit. So up
   to [max_bits] bits are peeked MSB-first, matched against the table, and
   only the matched code's length is consumed.
   Bits past the end of the input read as 0 during the peek. The last code
   of a stream usually sits in its final few bits, and peeking a full 15-bit
   window there would otherwise run off the end of the string.
   Raises DecompressionError "letter not found" if no code matches. Raises
   DecompressionError "unexpected end of data" if the matched code itself
   extends past the end of the input. *)
let get_next_letter code_table (s, byte_offset, bit_offset) =
  let max_bits = 15 in
  let total_bits = 8 * String.length s in
  let start = 8 * byte_offset + bit_offset in
  let bit_at pos =
    if pos >= total_bits then 0
    else (Char.code s.[pos / 8] lsr (pos mod 8)) land 1
  in
  let rec peek acc pos remaining =
    if remaining = 0 then acc
    else peek ((acc lsl 1) lor bit_at pos) (pos + 1) (remaining - 1)
  in
  let window = peek 0 start max_bits in
  let count = Array.length code_table in
  (* Huffman codes are prefix-free, so at most one entry can match. *)
  let rec find symbol =
    if symbol >= count then raise (DecompressionError "letter not found")
    else
      let length, code = code_table.(symbol) in
      if length > 0 && window lsr (max_bits - length) = code then (symbol, length)
      else find (symbol + 1)
  in
  let symbol, length = find 0 in
  let next = start + length in
  if next > total_bits then raise (DecompressionError "unexpected end of data");
  (symbol, (s, next / 8, next mod 8))

(* Read [h] code lengths, each encoded with the code-length Huffman table
   [code_table], and build a code table over [len] symbols (rfc 1951: 3.2.7).
   Symbols 0-15 are literal lengths. 16 repeats the previous length 3-6 times
   (2 extra bits). 17 writes 3-10 zeros (3 extra bits). 18 writes 11-138
   zeros (7 extra bits). Entries beyond those written stay 0.
   NOTE: repeat codes only see lengths decoded in this call. rfc 1951 3.2.7
   lets runs cross from the literal/length lengths into the distance lengths,
   so both groups must be decoded in one call for such streams.
   Returns (build_code_table lengths, advanced stream).
   Raises DecompressionError "wrong literal length" for a symbol above 18.
   Overruns of the [len]-entry array, or a 16 with no previous length,
   surface as Invalid_argument. *)
let build_table code_table len h src =
  let lengths = Array.make len 0 in
  (* Read [extra] bits, then write [r + base] copies of [value_at ()] at [pos].
     Returns the advanced stream and how many entries were written. *)
  let fill_run pos extra base value_at stream =
    let r, stream' = get_next_bits extra stream in
    let written = r + base in
    Array.fill lengths pos written (value_at ());
    (stream', written)
  in
  let rec read_lengths stream remaining =
    if remaining = 0 then stream
    else begin
      let pos = h - remaining in
      let symbol, stream = get_next_letter code_table stream in
      let stream, written =
        if symbol < 16 then begin
          lengths.(pos) <- symbol;
          (stream, 1)
        end
        else
          match symbol with
          | 16 -> fill_run pos 2 3 (fun () -> lengths.(pos - 1)) stream
          | 17 -> fill_run pos 3 3 (fun () -> 0) stream
          | 18 -> fill_run pos 7 11 (fun () -> 0) stream
          | _ -> raise (DecompressionError "wrong literal length")
      in
      read_lengths stream (remaining - written)
    end
  in
  let rest = read_lengths src h in
  (build_code_table lengths, rest)

(* Convert a length symbol [b] (257..285) into a match length
   (rfc 1951: 3.2.5). The result is the symbol's base length plus an
   extra-bits value read from the stream.
   Returns (length, advanced stream). Raises Invalid_argument if [b] is
   outside 257..285. *)
let get_length_from_code =
  (* Tables indexed by b - 257; built once, never mutated. *)
  let base_length =
    [| 3; 4; 5; 6; 7; 8; 9; 10; 11; 13; 15; 17; 19; 23; 27; 31;
       35; 43; 51; 59; 67; 83; 99; 115; 131; 163; 195; 227; 258 |]
  and extra_bits =
    [| 0; 0; 0; 0; 0; 0; 0; 0; 1; 1; 1; 1; 2; 2; 2; 2;
       3; 3; 3; 3; 4; 4; 4; 4; 5; 5; 5; 5; 0 |]
  in
  fun b src ->
    let i = b - 257 in
    let extra, src' = get_next_bits extra_bits.(i) src in
    (base_length.(i) + extra, src')

(* Convert a distance symbol [b] (0..29) into a back-reference distance
   (rfc 1951: 3.2.5). The result is the symbol's base distance plus an
   extra-bits value read from the stream.
   Returns (distance, advanced stream). Raises Invalid_argument if [b] is
   outside 0..29. *)
let get_distance_from_code =
  (* Tables indexed by b; built once, never mutated. *)
  let base_distance =
    [| 1; 2; 3; 4; 5; 7; 9; 13; 17; 25; 33; 49; 65; 97; 129; 193; 257; 385;
       513; 769; 1025; 1537; 2049; 3073; 4097; 6145; 8193; 12289; 16385; 24577 |]
  and extra_bits =
    [| 0; 0; 0; 0; 1; 1; 2; 2; 3; 3; 4; 4; 5; 5; 6; 6; 7; 7; 8; 8; 9; 9;
       10; 10; 11; 11; 12; 12; 13; 13 |]
  in
  fun b src ->
    let extra, src' = get_next_bits extra_bits.(b) src in
    (base_distance.(b) + extra, src')

(* Decode the symbols of one Huffman-compressed block (rfc 1951: 3.2.5),
   using literal/length table [lit] and distance table [dist].
   [s] is the input stream position. [d] is the (output buffer, offset) pair.
   Symbols below 256 are literal bytes. 256 ends the block. Larger symbols
   start a (length, distance) back-reference copied from earlier output.
   Returns the stream position after the end-of-block symbol and the updated
   output pair. *)
let get_data lit dist s d =
  let rec next_symbol src dst =
    let symbol, after_symbol = get_next_letter lit src in
    if symbol = 256 then (after_symbol, dst)
    else if symbol < 256 then
      next_symbol after_symbol (add_byte (char_of_int symbol) dst)
    else begin
      let length, after_length = get_length_from_code symbol after_symbol in
      let dist_code, after_dist_code = get_next_letter dist after_length in
      let distance, after_distance =
        get_distance_from_code dist_code after_dist_code
      in
      let copy = get_bytes distance length dst in
      next_symbol after_distance (add_bytes copy dst)
    end
  in
  next_symbol s d
        
(* Build the fixed Huffman tables of rfc 1951 3.2.6.
   Literal/length symbols 0-143 use 8 bits, 144-255 use 9, 256-279 use 7 and
   280-287 use 8. Each of the 30 distance symbols uses 5 bits.
   Returns (literal/length table, distance table). *)
let make_fixed_tables () =
  let literal_length_bits symbol =
    if symbol <= 143 then 8
    else if symbol <= 255 then 9
    else if symbol <= 279 then 7
    else 8
  in
  let literal_lengths = Array.init 288 literal_length_bits in
  let distance_lengths = Array.make 30 5 in
  let lit = build_code_table literal_lengths in
  let dist = build_code_table distance_lengths in
  (lit, dist)

(* Uncompress a raw DEFLATE stream (rfc 1951). This is the compressed payload
   of e.g. a gzip member or a zip entry. A zlib (rfc 1950) stream must have
   its 2-byte header and Adler-32 trailer removed first.
   Input:  s - an 8-bit clean string containing only the DEFLATE data
           d - an output buffer (bytes) large enough for the decompressed data
   Output: the decompressed data is written into d from offset 0, and d is
           returned. Bytes after the decompressed data are left as they were.
   Raises DecompressionError if any problem is encountered. The contents of d
   may then be partially or wholly incorrect. *)
let uncompress s d =
  (* The fixed Huffman tables never change; build them at most once per call. *)
  let fixed_tables = lazy (make_fixed_tables ()) in
  let rec decode (src, dst) continue =
    if continue then begin
      let bfinal, s0 = get_next_bit src in
      let continue' = bfinal <> 1 in
      let btype, s1 = get_next_bits 2 s0 in
      match btype with
      | 0 ->
        (* Stored block: byte-aligned LEN, then NLEN (one's complement of
           LEN), then LEN raw bytes (rfc 1951 3.2.4). *)
        let s2 = skip_remaining_bits s1 in
        let len, s3 = get_next_bits 16 s2 in
        let nlen, s4 = get_next_bits 16 s3 in
        if len <> (lnot nlen) land 0xFFFF then
          raise (DecompressionError "Stored block length check failed");
        let str, s5 = get_next_bytes len s4 in
        decode (s5, add_bytes str dst) continue'
      | 1 ->
        let lit_table, dist_table = Lazy.force fixed_tables in
        decode (get_data lit_table dist_table s1 dst) continue'
      | 2 ->
        let hlit', s2 = get_next_bits 5 s1 in
        let hdist', s3 = get_next_bits 5 s2 in
        let hclen', s4 = get_next_bits 4 s3 in
        let hlit, hdist, hclen = hlit' + 257, hdist' + 1, hclen' + 4 in
        let code_lengths, s5 = get_code_lengths hclen s4 in
        let code_table = build_code_table code_lengths in
        (* rfc 1951 3.2.7: the literal/length and distance code lengths form
           a single sequence, and repeat codes 16-18 may cross from one into
           the other. So decode them together, then split the lengths. *)
        let total = hlit + hdist in
        let combined, s6 = build_table code_table total total s5 in
        let lengths = Array.map fst combined in
        let lit_table = build_code_table (Array.sub lengths 0 hlit) in
        let dist_table = build_code_table (Array.sub lengths hlit hdist) in
        decode (get_data lit_table dist_table s6 dst) continue'
      | _ -> raise (DecompressionError "Unknown compression method")
    end
  in
  try
    decode ((s, 0, 0), (d, 0)) true;
    d
  with
  | DecompressionError _ as e -> raise e
  (* Any other failure (index out of bounds, bad table entry, ...) means
     the input is malformed; report it uniformly. *)
  | _ -> raise (DecompressionError "Compressed data corrupted")

(*
                 --------------- Test ------------------ 
                 A simple program to gunzip gzipped files
                 Usage: zlib <file.gz> <file>
*)
open Printf

(* Read the whole contents of channel [ic] into a string. [ic] is expected
   to be opened in binary mode and positioned at its start, because the
   length read is the channel's total length.
   really_input_string replaces String.create plus really_input into a
   string. Both were deprecated in OCaml 4.02 and no longer compile on
   current compilers.
   Raises End_of_file if fewer bytes than the channel length can be read. *)
let binary_channel_to_string ic =
  really_input_string ic (in_channel_length ic)

(* Read the 4 bytes of [a] starting at index [c] as an unsigned little-endian
   ("reverse network byte order") integer.
   Raises Invalid_argument if fewer than 4 bytes are available at [c]. *)
let rmerge_to_int a c =
  let byte k = Char.code a.[c + k] in
  byte 0 lor (byte 1 lsl 8) lor (byte 2 lsl 16) lor (byte 3 lsl 24)

(* gunzip: read the uncompressed data length from the last 4 bytes of a gzip
   string. This is the trailer's ISIZE field (rfc 1952 2.3.1), the original
   size modulo 2^32. *)
let get_uncompressed_data_length s =
  let isize_offset = String.length s - 4 in
  rmerge_to_int s isize_offset

(* gunzip: check that [s] starts with the gzip magic bytes (0x1f 0x8b)
   followed by compression method 8 (deflate).
   Inputs shorter than 3 bytes are rejected instead of raising
   Invalid_argument. *)
let check_gzip s =
  String.length s >= 3
  && s.[0] = '\031' && s.[1] = '\139' && s.[2] = '\008'

(* gunzip: return the index of the first byte of compressed data. The
   optional header fields selected by the FLG byte are skipped, in the order
   rfc 1952 2.3.1 lays them out:
   FEXTRA (2-byte little-endian XLEN + XLEN bytes), FNAME and FCOMMENT
   (zero-terminated), then FHCRC.
   FHCRC is a fixed 2-byte CRC16 of the header, not a length. The previous
   code used its value as a skip count.
   Raises Invalid_argument or Not_found on a truncated header. *)
let get_offset_compressed_data s =
  let fhcrc = 2 and fextra = 4 and fname = 8 and fcomment = 16 in
  let flag = Char.code s.[3] in
  let has bit = flag land bit <> 0 in
  let u16_le i = Char.code s.[i] + 256 * Char.code s.[i + 1] in
  let after_zstring start = String.index_from s start '\000' + 1 in
  let after_fixed = 10 in
  let after_extra =
    if has fextra then after_fixed + 2 + u16_le after_fixed else after_fixed
  in
  let after_name = if has fname then after_zstring after_extra else after_extra in
  let after_comment =
    if has fcomment then after_zstring after_name else after_name
  in
  if has fhcrc then after_comment + 2 else after_comment

(* gunzip: return the compressed (DEFLATE) payload of a gzip string. This is
   everything from the end of the header up to the trailer.
   NOTE(review): the gzip trailer is 8 bytes (CRC32 then ISIZE, rfc 1952
   2.3.1), but only 7 are dropped here, so the first CRC byte stays in the
   returned slice. Decoding stops at the final block, so that byte is not
   decoded, and it gives get_next_letter's 15-bit look-ahead one byte of
   slack at the end of the stream. Confirm before tightening this to 8. *)
let get_compressed_data s =
  let start = get_offset_compressed_data s in
  let stop = String.length s - 7 in
  String.sub s start (stop - start)

(* gunzip: decompress a single-member gzip string and return the data.
   Raises DecompressionError "Not a valid gzip file" if the magic bytes are
   missing, or if the header or trailer is truncated. Those failures are
   Invalid_argument/Not_found from the header helpers, which previously
   escaped unwrapped. Decoding errors come from uncompress.
   uncompress fills its output buffer in place. The buffer is therefore
   created as bytes and turned into an immutable string only after decoding
   has finished. *)
let gunzip s =
  let not_gzip () = raise (DecompressionError "Not a valid gzip file") in
  let data, dst =
    try
      if not (check_gzip s) then not_gzip ();
      let size = get_uncompressed_data_length s in
      let data = get_compressed_data s in
      (data, Bytes.make size '\000')
    with Invalid_argument _ | Not_found -> not_gzip ()
  in
  Bytes.to_string (uncompress data dst)

(* Test program entry point: gunzip the file named by Sys.argv.(1) into the
   file named by Sys.argv.(2), or print usage when arguments are missing.
   Both channels are closed explicitly, including on error, rather than being
   left open for the rest of the process. The output file is only created
   once decompression has succeeded. *)
let main () =
  if Array.length Sys.argv > 2 then begin
    let ic = open_in_bin Sys.argv.(1) in
    let compressed =
      try binary_channel_to_string ic
      with e -> close_in_noerr ic; raise e
    in
    close_in ic;
    let out = gunzip compressed in
    let oc = open_out_bin Sys.argv.(2) in
    (try output_string oc out
     with e -> close_out_noerr oc; raise e);
    close_out oc
  end
  else printf "Usage: %s <file.gz> <file>\n" Sys.argv.(0)

let () = main ()