(** State of a body moving along a single axis: its position, velocity and
    acceleration. *)
type state = { stPosition: float; stVelocity: float; stAcceleration: float}

(** Time derivative of a [state]: [dstPosition] is the rate of change of the
    position (a velocity) and [dstVelocity] the rate of change of the velocity
    (an acceleration). *)
type derivative = { dstPosition: float; dstVelocity: float }

(** [euler st dt] advances [st] by one explicit (forward) Euler step of size
    [dt]. Position and velocity are both updated from the values held at the
    start of the step; the acceleration is carried over unchanged. *)
let euler st dt =
        let { stPosition = pos; stVelocity = vel; stAcceleration = acc } = st in
        let next_pos = pos +. vel *. dt in
        let next_vel = vel +. acc *. dt in
        { stPosition = next_pos; stVelocity = next_vel; stAcceleration = acc }

(** [newton_stormer_verlet st dt] advances [st] by one step of size [dt].
    The velocity is updated first from the stored acceleration, and the
    position is then advanced with that {e new} velocity. The acceleration is
    carried over unchanged. *)
let newton_stormer_verlet st dt =
        let next_vel = st.stVelocity +. st.stAcceleration *. dt in
        let next_pos = st.stPosition +. next_vel *. dt in
        { st with stPosition = next_pos; stVelocity = next_vel }

(** [runge_kutta_4th_order st dt] advances [st] by one step of size [dt],
    combining four derivative samples (start, two midpoints, end) with the
    weights 1/6, 2/6, 2/6 and 1/6.

    The acceleration stored in [st] is used for every sample, so the velocity
    slope is identical at all four points; only the position slope varies.
    The acceleration is carried over unchanged. *)
let runge_kutta_4th_order st dt =
        (* Derivative of the state reached by following [slope] for [h] from
           [st]. Only the velocity of that intermediate state matters: its
           position is never read and its acceleration equals [st]'s. *)
        let sample h slope =
                {
                        dstPosition = st.stVelocity +. slope.dstVelocity *. h;
                        dstVelocity = st.stAcceleration
                }
        in
        let half_step = dt *. 0.5 in
        let slope1 =
                sample 0.0 { dstPosition = st.stVelocity; dstVelocity = st.stAcceleration }
        in
        let slope2 = sample half_step slope1 in
        let slope3 = sample half_step slope2 in
        let slope4 = sample dt slope3 in
        (* Weighted average of one component across the four samples. *)
        let blend component =
                1.0 /. 6.0
                *. (component slope1
                    +. 2.0 *. (component slope2 +. component slope3)
                    +. component slope4)
        in
        let dxdt = blend (fun s -> s.dstPosition) in
        let dvdt = blend (fun s -> s.dstVelocity) in
        {
                st with
                stPosition = st.stPosition +. dxdt *. dt;
                stVelocity = st.stVelocity +. dvdt *. dt
        }

(** [explicit_verlet st dt] advances [st] by one step of size [dt] using the
    position form of the Verlet update: the previous position is
    reconstructed as [position - velocity * dt], the displacement for the new
    step is [(current - previous) + acceleration * dt * dt], and the new
    velocity is that displacement divided by [dt]. The acceleration is carried
    over unchanged.

    A zero-length step returns [st] unchanged. Without this guard the velocity
    would be computed as [0. /. 0.], i.e. [nan]. *)
let explicit_verlet st dt =
        if dt = 0. then st
        else
                let prev_pos = st.stPosition -. st.stVelocity *. dt
                and cur_pos = st.stPosition in
                (* Displacement covered during this step. *)
                let v = cur_pos -. prev_pos +. st.stAcceleration *. dt *. dt in
                {
                        stPosition = cur_pos +. v;
                        stVelocity = v /. dt;
                        stAcceleration = st.stAcceleration
                }

(** [evaluate integrator initial time dt] repeatedly applies
    [integrator state dt], starting from [initial], while the accumulated
    time (starting at [0.] and increased by [dt] per step) is below [time].

    Because the check happens before each step, the final state may lie up to
    one step past [time]; floating-point accumulation of [dt] can likewise add
    one extra step. If [time <= 0.], [initial] is returned untouched.

    @raise Invalid_argument if [time] is NaN, or if [time > 0.] and [dt] is
    not strictly positive (zero, negative or NaN). These inputs would
    otherwise make the loop run forever. *)
let evaluate integrator initial time dt =
        let rec eval st t =
                if t >= time then st else
                eval (integrator st dt) (t +. dt)
        in
        (* [time <> time] is true only for NaN. *)
        if time <> time then invalid_arg "evaluate: time must not be NaN"
        else if 0. >= time then initial
        else if not (dt > 0.) then
                invalid_arg "evaluate: dt must be strictly positive"
        else eval initial 0.
        
(* Test: interactive demo that runs every integrator on the same 1-D spring *)

(* Window dimensions, used to build the geometry string for
   [Graphics.open_graph] and to lay out the drawing. *)
let width = 800
let height = 300
(* NOTE(review): [topx] and [topy] are not referenced anywhere in the visible
   code; confirm whether they are still needed. *)
let topx = 0
let topy = 0
                
(** Opens a graphics window of [width] x [height] and sets its title. *)
let init_graphics () =
        (* The geometry string deliberately starts with a space. *)
        Graphics.open_graph (Printf.sprintf " %dx%d" width height);
        Graphics.set_window_title "Simple 1d spring demo"
        
(** [draw_spring st col h] draws a cyan horizontal guide line across the
    window at height [h], then a circle of radius 20 in colour [col] on that
    line. The circle's x coordinate is the truncated position of [st],
    measured from the horizontal centre of the window. *)
let draw_spring st col h =
        let centre = width / 2 in
        let ball_x = centre + truncate st.stPosition in
        Graphics.set_color Graphics.cyan;
        Graphics.draw_poly_line [| (0, h); (width, h) |];
        Graphics.set_color col;
        Graphics.draw_circle ball_x h 20

(* Initial position of the ball: 40% of the window width away from the
   point where [spring_equation] yields zero. *)
let x0 = 0.4 *. (float width)
(* Initial velocity of the ball. *)
let v0 = 0.
(* Spring constant used by [spring_equation]. *)
let k = 1.0
(* NOTE(review): [k2] is not referenced anywhere in the visible code; confirm
   whether it is still needed. *)
let k2 = 0.1
(** [spring_equation x] is the restoring term [-k * x] for a displacement
    [x]; the demo uses the result directly as an acceleration. *)
let spring_equation displacement =
        let pull = k *. displacement in
        -. pull
(** Starting state of the ball: at [x0] with velocity [v0], and the spring
    acceleration that corresponds to [x0]. *)
let ball =
        let start_acc = spring_equation x0 in
        { stPosition = x0; stVelocity = v0; stAcceleration = start_acc }
(** [force st] is [st] with its acceleration recomputed from its current
    position via [spring_equation]; position and velocity are kept. *)
let force st =
        { st with stAcceleration = spring_equation st.stPosition }
(* Integrators compared by the demo, one per drawn row. *)
let evaluators =
        [|
                euler;
                newton_stormer_verlet;
                runge_kutta_4th_order;
                explicit_verlet
        |]

(* Drawing colour for each integrator, in the same order as [evaluators]. *)
let colours =
        [|
                Graphics.red;
                Graphics.cyan;
                Graphics.green;
                Graphics.blue
        |]

(* Every integrator starts from the same [ball] state. *)
let initial_states = Array.make (Array.length evaluators) ball
        
(** [main t0 st0] renders one frame and recurses until a key is pressed.

    [t0] is the wall-clock time of the previous frame and [st0] holds one
    state per entry of [evaluators]. Each frame:
    - measures the time elapsed since [t0];
    - advances every state over that interval in sub-steps of one tenth of
      it, then recomputes its acceleration with [force];
    - draws, for each integrator, a black reference circle at the initial
      [ball] position and the current state in the integrator's colour, on
      rows 60 units apart;
    - polls the keyboard and stops if a key is pressed.

    NOTE(review): [force] is applied only once per frame, after all
    sub-steps, so the sub-steps within a frame all use the acceleration
    computed at the end of the previous frame. Confirm this is intended. *)
let rec main t0 st0 =
        (* Draw off-screen; the frame is shown when synchronisation is re-enabled. *)
        Graphics.auto_synchronize false;
        Graphics.clear_graph ();
        let now = Unix.gettimeofday () in
        let elapsed = now -. t0 in
        let sub_step = elapsed /. 10. in
        let advance i integrator =
                force (evaluate integrator st0.(i) elapsed sub_step)
        in
        let st1 = Array.mapi advance evaluators in
        let draw_row i current =
                let h = (i + 1) * 60 in
                draw_spring ball Graphics.black h;
                draw_spring current colours.(i) h
        in
        Array.iteri draw_row st1;
        Graphics.auto_synchronize true;
        let status = Graphics.wait_next_event [Graphics.Poll; Graphics.Key_pressed] in
        if not status.Graphics.keypressed then main now st1

(* Program entry point: open the window, then run the frame loop starting
   from the current wall-clock time and the shared initial states. *)
let () =
        init_graphics ();
        let start_time = Unix.gettimeofday () in
        main start_time initial_states