open Unification

(* ========================================
     Definitions
   ======================================== *)

(* Polarity attached to a function symbol: positive, negative, or none. *)
type pol = Pos | Neg | Npol

(* A ray is a first-order term whose function symbols carry a polarity.
   Its [Var]/[Func] constructors shadow the ones of [term] (presumably the
   term type brought in by [open Unification]); [ray_to_term] converts a
   ray into such a term. *)
type ray = Var of id | Func of (id * pol * ray list)

(* A star is a list of rays; a constellation is a list of stars. *)
type star = ray list
type constellation = star list

(* NOTE(review): as written this parses as
   (int * int) * ((ray * ray) list), i.e. ONE index pair together with a
   list of ray pairs. [dgraph] builds lists of ((i, j), (r1, r2)) links,
   whose shape would be ((int * int) * (ray * ray)) list. The alias is not
   used in the visible code; confirm the intended shape before using it. *)
type graph = (int * int) * (ray * ray) list

(* List monad: used below by [dgraph] to enumerate combinations. *)

(* [return v] embeds [v] into the list monad as the singleton [[v]]. *)
let return v = v :: []

(* [items >>= f] applies [f] to every element of [items], left to right,
   and concatenates the resulting lists in order. *)
let ( >>= ) items f = List.concat (List.map f items)

(* [guard cond v] is [return v] when [cond] holds and [[]] otherwise.
   Being a plain function, [v] is evaluated by the caller either way. *)
let guard cond v =
  match cond with
  | true -> return v
  | false -> []

(* ========================================
     Useful functions
   ======================================== *)

(* Render symbol [id] with its polarity [pol] folded into the name:
   "+id" for [Pos], "-id" for [Neg], and [id] unchanged for [Npol]. *)
let pol_to_string pol id =
  match pol with
  | Pos -> "+" ^ id
  | Neg -> "-" ^ id
  | Npol -> id

(* Translate a polarised [ray] into an unpolarised [term].
   Variables map to variables. For a function symbol, the polarity is
   folded into its name with [pol_to_string], and the arguments are
   converted recursively. The [: term] annotations select the [term]
   constructors, which are shadowed by those of [ray]. *)
let rec ray_to_term = function
  | Var name -> (Var name : term)
  | Func (sym, p, args) ->
    let args' = List.map ray_to_term args in
    (Func (pol_to_string p sym, args') : term)

(* Swap [Pos] and [Neg]; [Npol] is its own inverse. *)
let inv_pol = function
  | Pos -> Neg
  | Neg -> Pos
  | Npol -> Npol

(* Flip the polarity of every function symbol in [ray], recursively, so
   that two dual rays can be compared by ordinary unification.
   Variables are returned unchanged. *)
let rec inv_pol_ray ray =
  match ray with
  | Var _ -> ray
  | Func (sym, p, args) ->
    Func (sym, inv_pol p, List.map inv_pol_ray args)

(* [is_polarised r] holds when some function symbol in [r] (at any depth)
   has polarity [Pos] or [Neg]. A variable is never polarised. *)
let rec is_polarised r =
  match r with
  | Var _ -> false
  | Func (_, p, args) -> p <> Npol || List.exists is_polarised args

(* Decide whether rays [r1] and [r2] are dual.
   Both must be polarised. The polarities of [r1] are then flipped, the
   variables of the two sides are tagged with "0" and "1" through
   [extends_varname] (presumably to rename them apart), and the single
   resulting equation is handed to [solve] with an empty initial
   substitution.
   Returns [solve]'s result, or [None] when either ray is unpolarised. *)
let dual_check r1 r2 =
  if not (is_polarised r1 && is_polarised r2) then None
  else
    let lhs = extends_varname (ray_to_term (inv_pol_ray r1)) "0" in
    let rhs = extends_varname (ray_to_term r2) "1" in
    solve [ (lhs, rhs) ] []

(* Pair every star with its 0-based position:
   [s0; s1; ...] becomes [(0, s0); (1, s1); ...]. *)
let index_constellation const =
  List.mapi (fun pos star -> (pos, star)) const

(* ========================================
     Constellation graph
   ======================================== *)

(* Build the dependency graph ("dgraph") of a constellation.

   Visits every star pair (i, j) with j >= i, in index order, then every
   ray r1 of star i and r2 of star j. Each visited combination yields one
   entry:
   - [((i, j), (r1, r2))] when [dual_check r1 r2] succeeds,
   - [] otherwise.
   The result is thus a list of 0- or 1-element lists; [clean_dgraph]
   removes the empty ones. Pairs with i = j are included, so a star (and
   even a ray with itself) may be linked. *)
let dgraph const =
  let indexed_const = index_constellation const in
  indexed_const >>= fun (i, il) ->
  indexed_const >>= fun (j, jl) ->
  (* Reject the star pair BEFORE enumerating its rays. [guard]'s payload is
     evaluated eagerly, so checking j >= i last (as before) ran the
     unification for every discarded pair. The output is unchanged. *)
  guard (j >= i) () >>= fun () ->
  il >>= fun r1 ->
  jl >>= fun r2 ->
  return
    (if Option.is_some (dual_check r1 r2) then [ ((i, j), (r1, r2)) ]
     else [])

(* Render a list of links (one dgraph entry) as text.
   Each link ((i, j), (r1, r2)) becomes "(i, j),(t1, t2)", where t1/t2 are
   the rays printed as terms. Links are joined with "+"; an empty list
   renders as "". *)
let link_to_string dg =
  let show_ray r = term_to_string (ray_to_term r) in
  let show_link ((i, j), (r1, r2)) =
    "(" ^ string_of_int i ^ ", " ^ string_of_int j ^ ")" ^ ","
    ^ "(" ^ show_ray r1 ^ ", " ^ show_ray r2 ^ ")"
  in
  String.concat "+" (List.map show_link dg) ;;

(* Print a dgraph on stdout, one entry per line (via [link_to_string]).
   Empty entries print as empty lines; no trailing newline is written. *)
let print_dgraph dg =
  dg
  |> List.map link_to_string
  |> String.concat "\n"
  |> print_string ;;

(* Drop the empty entries of a dgraph (non-dual ray pairs), keeping the
   remaining entries in their original order. *)
let clean_dgraph g =
  List.filter (function [] -> false | _ :: _ -> true) g

(* _________ Examples _________ *)
(* Constant (nullary function symbol) [c] carrying polarity [pol]. *)
let make_const_pol pol c = Func (c, pol, [])

(* Unpolarised constant [c]. *)
let make_const c = make_const_pol Npol c

(* Variables shared by the example constellations. *)
let y, x, z, r = (Var "y", Var "x", Var "z", Var "r")

(* Peano zero and successor, both unpolarised. *)
let zero = make_const "0"
let s n = Func ("s", Npol, [ n ])

(* Ternary "add" symbol with polarity [p]. *)
let add p a b c = Func ("add", p, [ a; b; c ])

(* Convert a non-negative int to its Peano numeral as a ray:
   [enat 0] is [zero] and [enat (n + 1)] is [s (enat n)].
   Built with an accumulator, so large [i] does not grow the stack.
   @raise Invalid_argument if [i] is negative. Such input previously
   recursed without ever reaching 0. *)
let enat i =
  if i < 0 then invalid_arg "enat: negative argument";
  let rec wrap k acc = if k = 0 then acc else wrap (k - 1) (s acc) in
  wrap i zero

(* Constellation computing n + m with the Peano rules for addition:
   - base star:      +add(0, y, y)
   - inductive star: -add(x, y, z), +add(s x, y, s z)
   - query star:     -add(n, m, r), r *)
let make_const_add n m =
  let base = [ add Pos zero y y ] in
  let step = [ add Neg x y z; add Pos (s x) y (s z) ] in
  let query = [ add Neg (enat n) (enat m) r; r ] in
  [ base; step; query ]

(* Example: the constellation for 1 + 3, and its printed dgraph. *)
let constellation = make_const_add 1 3

let () = print_dgraph (dgraph constellation)

(* First attempt at executing a dgraph. Note that the later top-level
   [exec const] definition in this file shadows this one.

   [graph] is walked entry by entry; each entry is a list of links
   ((i, j), (ri, rj)) as produced by [dgraph]. The state threaded through
   is a pair (substitution option, collected rays), starting from
   (Some [], []). For every link:
   - the current substitution [sola] is applied with [substit] to the
     terms of [ri] and [rj], and [solve] is called on that equation with
     [sola] as the starting substitution;
   - the rays of star [i] other than [ri] are prepended to the collected
     rays.
   Returns [Some state] once the graph is exhausted. After a failed
   [solve], the next link processed (in this or a later entry) makes the
   whole result [None]. If no link follows, the failure is returned as
   [Some (None, rays)].

   NOTE(review): unlike [dual_check], [ri] is not passed through
   [inv_pol_ray] and variables are not renamed apart with
   [extends_varname]. The "+"/"-" prefixes added by [pol_to_string]
   presumably make dual rays non-unifiable here; confirm.
   NOTE(review): [j] and the rays of star [j] are never used. *)
let exec graph const = 
  let rec aux graph sol = 
    match graph with
    | [] -> Some sol
    | h::tail -> 
      let rec aux2 glink sol = 
        match glink, sol with
        | [], _ -> aux tail sol
        | ((i,j),(ri, rj))::t,(Some sola, solb) ->
        aux2 t (solve [(substit (ray_to_term ri) sola,substit (ray_to_term rj) sola)] sola, (List.filter (fun a -> a <> ri) (List.nth const i))@solb )
        | (_,(None,_)) -> None (* a previous solve failed *)
        in aux2 h sol
  in aux graph (Some [],[]) ;;

(* Second execution attempt. Instead of threading a substitution, it keeps
   the list of ALL equations collected so far and re-solves that list from
   scratch ([solve eq []]) after every new link. Not called in the visible
   code.

   State: (Some equations, collected rays), starting from (Some [], []).
   For each link ((i, j), (ri, rj)) of each graph entry:
   - the equation (ray_to_term ri, ray_to_term rj) is prepended;
   - the rays of star [i] other than [ri] are prepended to the collected
     rays.
   Returns [Some state] once the graph is exhausted.

   @raise Failure "marchpo" (French slang for "doesn't work") as soon as
     the accumulated equations have no solution.
   Because failure raises, the first state component is always [Some _],
   so the final [None] branch is unreachable.
   NOTE(review): [ri] is not passed through [inv_pol_ray], variables are
   not renamed apart with [extends_varname], and the rays of star [j] are
   unused. Confirm these are intended. *)
let exec2 graph const = 
    let rec aux graph sol = 
      match graph with
      | [] -> Some sol
      | h::tail -> 
        let rec aux2 glink sol = 
          match glink, sol with
          | [], _ -> aux tail sol
          | ((i,j),(ri, rj))::t,(Some sola, solb) ->
          aux2 t (let eq = ((ray_to_term ri),(ray_to_term rj))::sola in 
            if (solve eq []) = None then 
              failwith "marchpo"
            else Some eq,
            (List.filter (fun a -> a <> ri) (List.nth const i))@solb )
          | (_,(None,_)) -> None (* unreachable: failures raise above *)
          in aux2 h sol
    in aux graph (Some [],[]) ;;

(* A token pairs a family number with the index of a star in the
   constellation: (family, star index). [divide_token] builds tokens as
   (fam, j), and [exec] starts from family 0. *)
type token = int * int

(* A list of tokens; presumably the tokens still to be expanded (this
   alias is not referenced elsewhere in the visible code). *)
type process = token list

(* Return the star at 0-based index [i] of constellation [const].
   Out-of-range indices raise whatever [List.nth] raises
   ([Failure "nth"] if too short, [Invalid_argument] if negative). *)
let get_star i const =
  let star = List.nth const i in
  star

(* Expand token (fam, n_star): follow every link leaving star [n_star] and
   check that the equations those links induce are jointly solvable.

   [graph] is a list of link lists (as returned by [clean_dgraph]); each
   link is ((i, j), (ri, rj)). Every link with i = n_star is processed in
   graph order:
   - the equation (ray_to_term (inv_pol_ray ri), ray_to_term rj) is added
     to the problem;
   - the token (fam, j) is added to the tokens to visit next.
   Returns [Some (tokens, equations)], both in reverse discovery order.
   Returns [None] as soon as the accumulated equations are no longer
   solvable by [solve]. [const] is not used.

   Fix: previously, once the links of a graph entry had been processed,
   the tokens and equations just accumulated were discarded (the outer,
   unchanged accumulators were returned) and the remaining graph entries
   were skipped.

   NOTE(review): only links stored as (n_star, j) are followed. [dgraph]
   stores links with j >= i, so a link (i, n_star) with i < n_star is not
   traversed from here; confirm this direction is intended.
   NOTE(review): unlike [dual_check], variables are not renamed apart
   between stars (no [extends_varname]); confirm. *)
let divide_token (fam, n_star) graph const =
  (* Walk the graph entries, carrying the accumulators along. *)
  let rec aux g toklist prob =
    match g with
    | [] -> Some (toklist, prob)
    | h :: t ->
      let links = List.filter (fun ((i, _), _) -> i = n_star) h in
      add_links t links toklist prob
  (* Add each link of one entry, then resume with the remaining entries. *)
  and add_links rest links toklist prob =
    match links with
    | [] -> aux rest toklist prob
    | ((_, j), (ri, rj)) :: tl ->
      let eq = (ray_to_term (inv_pol_ray ri), ray_to_term rj) in
      let prob' = eq :: prob in
      if Option.is_some (solve prob' []) then
        add_links rest tl ((fam, j) :: toklist) prob'
      else None
  in
  aux graph [] []

(* Deterministic execution of a constellation (work in progress). This
   definition shadows the earlier [exec graph const].

   Builds the cleaned dgraph of [const] and takes the source star [i] of
   its very first link. It then repeatedly expands the head token with
   [divide_token], starting from token (0, i). Returns the equation list
   held when the token list becomes empty.

   Precondition: the graph must not be empty, i.e. [const] must contain at
   least one pair of dual rays.
   @raise Failure "hd" if the graph is empty.
   @raise Invalid_argument if [divide_token] returns [None]
     (via [Option.get]).

   NOTE(review): only the head token [h] is expanded. The remaining tokens
   [t] and the equations held so far are dropped, because the state is
   replaced wholesale by [divide_token]'s result. The returned list is
   therefore whatever the last [divide_token] call produced; confirm this
   is intended.
   NOTE(review): no visited set is kept. If [divide_token] keeps
   producing tokens along a cycle of links, this loop does not
   terminate. *)
let exec const =
  let graph = clean_dgraph (dgraph const) in
  let rec aux (toklist,prob) =
    begin match toklist with
      | [] -> prob
      | h::t -> aux (Option.get (divide_token h graph const))
    end              
  in let ((i,_),(_,_)) = (List.hd (List.hd graph)) in aux ([(0,i)],[])

(* Test: deterministic cyclic constellation. *)
let test =
  let first = [ Func ("c", Neg, [ x ]); x ] in
  let second =
    [ Func ("c", Pos, [ Func ("f", Npol, [ y ]) ]); Func ("c", Npol, [ x ]) ]
  in
  [ first; second ]
;;

let () = print_dgraph (dgraph test) ;;

exec test ;;