module Resolution where
import qualified Unification
import Data.Maybe
import Data.List
import Data.Tuple
import PrettyPrinter as PrettyPrinter
import Control.Arrow
import Control.Monad
import Data.Semigroup
import qualified Data.Set as Set

{- ========================================
    Definitions
   ======================================== -}

-- | A payload tagged with a polarity. 'Pos' and 'Neg' are the two signed
-- polarities; 'NC' carries no sign. The derived 'Ord' instance compares the
-- constructor first (@Pos < Neg < NC@) and then the payload.
data Polarized a
  = Pos a
  | Neg a
  | NC a
  deriving (Eq, Ord)

-- | A ray: a 'Unification.Term' tagged with a polarity.
type Ray = Polarized Unification.Term

-- | A star: a list of rays.
newtype Star = Star {get_star :: [Ray]}

-- | A constellation: a list of stars.
newtype Constellation = Const {get_const :: [Star]}

-- | An edge between two rays, together with the indices of the two stars
-- the rays belong to.
newtype Link = Link {get_link :: ((Ray, Ray), Int, Int)}

-- | A graph whose vertices are stars and whose edges are links.
newtype StarGraph = StarGraph {get_graph :: ([Star], [Link])}

-- | 'StarGraph' named after the role it plays: the unification graph of a
-- constellation ('UGraph') or a diagram under construction ('Diagram').
type UGraph = StarGraph

type Diagram = StarGraph

-- | Map over the payload while keeping the polarity tag unchanged.
instance Functor Polarized where
  fmap f p = case p of
    Pos t -> Pos (f t)
    Neg t -> Neg (f t)
    NC t -> NC (f t)

-- | Two stars are equal when they hold the same rays in any order. Both ray
-- lists are sorted before comparing, so this is multiset equality.
instance Eq Star where
  a == b = canon a == canon b
    where
      canon (Star rs) = sort rs

-- | Two constellations are equal when they hold the same stars in any order.
-- Both star lists are sorted (using 'Ord' 'Star') before comparing.
instance Eq Constellation where
  a == b = canon a == canon b
    where
      canon (Const stars) = sort stars

-- | Links are undirected edges. Two links are equal when they join the same
-- pair of (ray, star index) endpoints, in the same or the opposite
-- orientation.
--
-- This relation is reflexive, symmetric and transitive. Because it compares
-- the rays as well as the indices, 'nub' keeps distinct parallel edges
-- between two stars. It only merges a link with its exact duplicate or its
-- swapped copy (e.g. the @(r1, r2)@ and @(r2, r1)@ self-links that
-- 'get_links' produces when a star is linked with itself).
instance Eq Link where
  Link ((r1, r2), i, j) == Link ((r1', r2'), i', j') =
    sameOrientation || swappedOrientation
    where
      sameOrientation = r1 == r1' && r2 == r2' && i == i' && j == j'
      swappedOrientation = r1 == r2' && r2 == r1' && i == j' && j == i'

-- | Stars are ordered by their sorted ray lists, i.e. as multisets of rays.
-- This keeps 'compare' consistent with the order-insensitive 'Eq' instance:
-- @compare a b == EQ@ exactly when @a == b@. That consistency matters
-- because 'Eq' 'Constellation' sorts stars and then compares them pairwise.
instance Ord Star where
  compare (Star s1) (Star s2) = compare (sort s1) (sort s2)

-- | Positive values are prefixed with @+@ and negative ones with @-@; 'NC'
-- values are shown bare.
instance Show a => Show (Polarized a) where
  show p = case p of
    Pos t -> '+' : show t
    Neg t -> '-' : show t
    NC t -> show t

-- | A star is displayed as its underlying list of rays.
instance Show Star where
  show star = case star of
    Star rs -> show rs

-- | Shows each star, passes the resulting strings together with @'+'@ to
-- 'PrettyPrinter.addop', and shows what that returns.
instance Show Constellation where
  show (Const stars) =
    let shownStars = [show s | s <- stars]
     in show (PrettyPrinter.addop '+' shownStars)

-- | A link is displayed as its raw @((ray, ray), i, j)@ triple.
instance Show Link where
  show link = case link of
    Link triple -> show triple

-- | Multi-line rendering: a separator line, then each vertex as
-- @index:star@ on its own line, then each edge on its own line.
instance Show StarGraph where
  show (StarGraph (vertices, edges)) =
    concat
      [ "------------------------\n"
      , "Vertices:\n"
      , unlines [show k ++ ":" ++ show s | (k, s) <- zip [0 :: Int ..] vertices]
      , "Edges:\n"
      , unlines (map show edges)
      ]

-- | Constellations combine by concatenating their star lists.
instance Semigroup Constellation where
  Const xs <> Const ys = Const (xs <> ys)

-- | Like 'map', but the function also receives each element's 0-based
-- position in the list. Works on infinite lists.
mapInd :: (a -> Int -> b) -> [a] -> [b]
mapInd f xs = zipWith (flip f) [0 ..] xs

-- | @change_color s s' r@ renames the head symbol of a polarised ray.
-- If @r@ is a 'Pos' or 'Neg' ray whose term is @Unification.Func s ts@,
-- the head @s@ is replaced by @s'@; the polarity and the arguments @ts@ are
-- kept. Every other ray is returned unchanged: a different head, a term
-- that is not a 'Unification.Func', or an 'NC' ray.
change_color :: String -> String -> Ray -> Ray
change_color s s' (Neg (Unification.Func f ts))
  | f == s = Neg (Unification.Func s' ts)
change_color s s' (Pos (Unification.Func f ts))
  -- Polarity must be preserved: a positive ray stays positive.
  | f == s = Pos (Unification.Func s' ts)
change_color _ _ r = r

-- | Two rays are dual when one is 'Pos' and the other 'Neg'. The terms are
-- not inspected, and an 'NC' ray is never dual to anything.
dual :: Ray -> Ray -> Bool
dual r1 r2 = case (r1, r2) of
  (Pos _, Neg _) -> True
  (Neg _, Pos _) -> True
  _ -> False

-- | Strip the polarity tag from a ray and return its term.
get_term :: Ray -> Unification.Term
get_term ray = case ray of
  Pos t -> t
  Neg t -> t
  NC t -> t

-- | Variables occurring in a ray's term (as reported by 'Unification.vars'),
-- whatever the ray's polarity.
vars_ray :: Ray -> [Unification.Id]
vars_ray ray =
  Unification.vars $ case ray of
    Pos t -> t
    Neg t -> t
    NC t -> t

-- | Variables of every ray of a star, concatenated in ray order. Duplicates
-- are kept.
vars_star :: Star -> [Unification.Id]
vars_star (Star rs) = concatMap vars_ray rs

-- | Apply a substitution to a ray's term, keeping its polarity.
subst_ray :: Unification.Subst -> Ray -> Ray
subst_ray u ray = Unification.subst u <$> ray

-- | Apply the same substitution to every ray of a list.
subst_rays :: Unification.Subst -> [Ray] -> [Ray]
subst_rays u rs = [subst_ray u r | r <- rs]

-- | Rename the variables of a ray's term with 'Unification.extends_varname'
-- using the given string, keeping the polarity.
extends_ray :: String -> Ray -> Ray
extends_ray suffix ray = fmap (Unification.extends_varname suffix) ray

-- | Apply 'extends_ray' with the same string to every ray of a star.
extends_star :: String -> Star -> Star
extends_star suffix (Star rs) = Star [extends_ray suffix r | r <- rs]

-- | All rays of a constellation, star by star, in order.
rays :: Constellation -> [Ray]
rays (Const cs) = concatMap get_star cs

{- ========================================
    Unification graph
   ======================================== -}

-- | Two rays can be linked when they are 'dual' and their terms satisfy
-- 'Unification.matchable'. Duality is checked first, so the terms are only
-- examined for dual rays.
matchable_rays :: Ray -> Ray -> Bool
matchable_rays r1 r2
  | not (dual r1 r2) = False
  | otherwise = Unification.matchable (get_term r1, get_term r2)

-- | Every (ray of the first star, ray of the second star) pair that is
-- 'matchable_rays'. Pairs are ordered by the first star's ray, then by the
-- second's.
get_links :: Star -> Star -> [(Ray, Ray)]
get_links (Star s1) (Star s2) =
  filter (uncurry matchable_rays) [(r1, r2) | r1 <- s1, r2 <- s2]

-- | Build the unification graph of a constellation.
--
-- The vertices are the constellation's stars. For every pair of star
-- indices @i <= j@ (a star is also paired with itself), there is one 'Link'
-- per matchable ray pair returned by 'get_links'. Duplicate links, as
-- decided by the 'Eq' 'Link' instance, are then removed with 'nub'.
ugraph :: Constellation -> UGraph
ugraph (Const cs) = StarGraph (cs, nub edges)
  where
    indexed = zip [0 ..] cs
    -- One group of links per index pair, in (i, j) lexicographic order.
    -- Stars are paired through 'zip' rather than repeated O(n) '!!' lookups.
    groups =
      [ [Link (e, i, j) | e <- get_links si sj]
      | (i, si) <- indexed
      , (j, sj) <- indexed
      , i <= j
      ]
    -- Groups are concatenated last-pair-first, which reproduces the edge
    -- order of the former left fold that prepended each group to its
    -- accumulator. This keeps 'nub' choosing the same representatives.
    edges = concat (reverse groups)

{- ========================================
    Actualisation
   ======================================== -}

-- | @make_edge stars (s1, r1) (s2, r2)@ links ray @r1@ of star @s1@ with
-- ray @r2@ of star @s2@, recording the star indices @s1@ and @s2@.
-- Indices are read with '!!', so they must be in range.
make_edge :: [Star] -> (Int, Int) -> (Int, Int) -> Link
make_edge stars (is1, ir1) (is2, ir2) =
  Link ((rayAt is1 ir1, rayAt is2 ir2), is1, is2)
  where
    rayAt si ri = let Star rs = stars !! si in rs !! ri

-- | One unification equation per link of the diagram.
--
-- For a link @((t, u), i, j)@, the variables of @t@'s term are renamed with
-- @Unification.extends_varname (show i)@ and those of @u@'s term with
-- @show j@. Rays from different star occurrences therefore do not share
-- variable names. The vertex list is ignored.
equations :: Diagram -> [Unification.Equation]
equations (StarGraph (_, links)) =
  [ (rename i (get_term t), rename j (get_term u))
  | Link ((t, u), i, j) <- links
  ]
  where
    rename k = Unification.extends_varname (show k)

-- | Whether the ray is one of the two endpoints of the link. Star indices
-- are not considered.
appears_in_link :: Ray -> Link -> Bool
appears_in_link r (Link ((t, u), _, _)) = any (== r) [t, u]

-- | Whether any link in the list shares an endpoint with the given link.
-- An endpoint is a (ray, star index) pair, and endpoints are compared in
-- both orientations.
already_linked :: Link -> [Link] -> Bool
already_linked (Link ((t, u), i, j)) ls = any sharesEnd ls
  where
    mine = [(t, i), (u, j)]
    sharesEnd (Link ((t', u'), i', j')) = any (`elem` mine) [(t', i'), (u', j')]

-- | Rays of the diagram that are not used by any link.
--
-- Each ray is tagged with the index of its star. Its variables are renamed
-- with 'extends_ray' using that index, matching the renaming in
-- 'equations'. One occurrence per link endpoint is then removed with '\\',
-- and the remaining renamed rays are returned in vertex order.
free_rays :: Diagram -> [Ray]
free_rays (StarGraph (stars, links)) = map fst (everyRay \\ linkedEnds)
  where
    tag k r = (extends_ray (show k) r, k)
    linkedEnds = concat [[tag i r, tag j r'] | Link ((r, r'), i, j) <- links]
    everyRay = concat [map (tag k) rs | (Star rs, k) <- zip stars [0 :: Int ..]]

-- | A diagram is correct when it has at least one free ray and the
-- equations induced by its links have a solution.
correct :: Diagram -> Bool
correct d
  | null (free_rays d) = False
  | otherwise = isJust (Unification.solution (equations d))

-- | Solve the diagram's equations and, on success, return the star formed by
-- its free rays with the solution applied. Returns 'Nothing' when
-- 'Unification.solution' fails.
actualise :: Diagram -> Maybe Star
actualise d = toStar <$> Unification.solution (equations d)
  where
    toStar u = Star (subst_rays u (free_rays d))

-- | All single-edge fusions of two stars.
--
-- For every matchable pair (ray1 from the first star, ray2 from the
-- second) whose terms have a 'Unification.solution', produce the star made
-- of the remaining rays of both stars with that substitution applied.
-- 'delete' removes the first occurrence equal to ray1 (resp. ray2).
fusions :: Star -> Star -> [Star]
fusions (Star star1) (Star star2) =
  [ Star (subst_rays u (delete ray1 star1 ++ delete ray2 star2))
  | ray1 <- star1
  , ray2 <- star2
  , matchable_rays ray1 ray2
  , Just u <- [Unification.solution [(get_term ray1, get_term ray2)]]
  ]

{- ========================================
    Execution
   ======================================== -}

-- | Every link between a ray of the first star (at index @i@) and a
-- matchable ray of the second star (at index @j@).
connect :: (Star, Int) -> (Star, Int) -> [Link]
connect (Star s1, i) (Star s2, j) =
  [Link ((r1, r2), i, j) | r1 <- s1, r2 <- s2, matchable_rays r1 r2]

-- | A diagram is saturated with respect to a constellation when it has at
-- least one free ray and none of its free rays is matchable with any ray of
-- the constellation.
saturated :: Constellation -> Diagram -> Bool
saturated cs d = not (null free) && not (any canStillConnect free)
  where
    free = free_rays d
    targets = rays cs
    canStillConnect r = any (matchable_rays r) targets

-- | Add @k@ to both star indices of a link, leaving its rays untouched.
shift_link :: Int -> Link -> Link
shift_link k (Link (e, i, j)) = Link (e, bump i, bump j)
  where
    bump = (+ k)

-- | All diagrams obtained by adding one star of the constellation to @d@.
--
-- The new star is prepended, so it takes index 0 and every existing star
-- and link index is shifted up by one. It is joined by a single link to one
-- of the existing stars. A candidate link is rejected when 'already_linked'
-- reports that it shares an endpoint with an existing (shifted) link.
singleStep :: Constellation -> Diagram -> [Diagram]
singleStep (Const cs) (StarGraph (stars, links)) =
  [ StarGraph (s1 : stars, link : shifted)
  | s1 <- cs
  , (s2, j) <- zip stars [0 ..]
  , link <- connect (s1, 0) (s2, j + 1)
  , not (already_linked link shifted)
  ]
  where
    shifted = map (shift_link 1) links

-- | Breadth-first enumeration of diagrams, starting from the diagram that
-- contains only the first star of the constellation.
--
-- Each dequeued diagram is emitted. Its 'singleStep' extensions that are
-- 'correct' are then enqueued. The starting diagram itself is emitted
-- without a 'correct' check. The list is produced lazily, and nothing here
-- bounds the search, so for some constellations it may be infinite.
allCorrectDiagrams :: Constellation -> [Diagram]
allCorrectDiagrams (Const []) = []
allCorrectDiagrams constellation@(Const (firstStar : _)) =
  explore [StarGraph ([firstStar], [])]
  where
    explore [] = []
    explore (d : queue) =
      d : explore (queue ++ filter correct (singleStep constellation d))

-- | The diagrams of 'allCorrectDiagrams' that are 'saturated' with respect to
-- the same constellation, in enumeration order.
allCorrectSaturatedDiagrams :: Constellation -> [Diagram]
allCorrectSaturatedDiagrams cs = [d | d <- allCorrectDiagrams cs, saturated cs d]

-- | Execute a constellation. Every correct saturated diagram is actualised
-- into a star; diagrams whose equations have no solution are dropped.
execution :: Constellation -> Constellation
execution cs = Const (mapMaybe actualise (allCorrectSaturatedDiagrams cs))


-- | Placeholder for @skeleton@.
--
-- NOTE(review): the file declared only this type signature, with no
-- binding. GHC rejects that ("type signature lacks an accompanying
-- binding"), so the module did not compile. Its intended semantics are not
-- visible here. This stub makes the module compile and fails loudly if
-- called; replace it with the real definition.
skeleton :: Constellation -> Constellation
skeleton _ = error "Resolution.skeleton: not implemented"


-- | Remove the element at the given 0-based index.
--
-- Total: if the index is negative or past the end, the list is returned
-- unchanged. Previously an out-of-range index crashed on an irrefutable
-- pattern, and a negative index silently removed the head.
deleteAt :: Int -> [a] -> [a]
deleteAt idx xs
  | idx < 0 = xs
  | otherwise = case splitAt idx xs of
      (lft, _ : rgt) -> lft ++ rgt
      (_, []) -> xs

-- | Fuse two stars along ray @i@ of the first and ray @j@ of the second.
--
-- If 'Unification.solution' returns a substitution for the equation between
-- the two rays' terms, the result is a single star: the remaining rays of
-- both stars with that substitution applied. Otherwise the result is empty.
-- Rays are read with '!!', so the indices must be in range.
fusion :: (Star, Int) -> (Star, Int) -> [Star]
fusion (Star s1, i) (Star s2, j) =
  maybe [] (\u -> [merged u]) unifier
  where
    unifier = Unification.solution [(get_term (s1 !! i), get_term (s2 !! j))]
    merged u = Star (subst_rays u (deleteAt i s1 ++ deleteAt j s2))

-- | One round of fusion between the constellation and the current stars.
--
-- Pairs are formed from every star of the constellation and every current
-- star whose ray lists differ. The ray lists are compared as plain lists,
-- so ray order matters. For each such pair, every matchable ray pair
-- (taken by position) is fused with 'fusion', and all results are collected.
step :: Constellation -> [Star] -> [Star]
step (Const cs) stars =
  [ fused
  | candidate <- cs
  , current <- stars
  , let rs1 = get_star candidate
  , let rs2 = get_star current
  , rs1 /= rs2
  , (i, r1) <- zip [0 ..] rs1
  , (j, r2) <- zip [0 ..] rs2
  , matchable_rays r1 r2
  , fused <- fusion (candidate, i) (current, j)
  ]

-- | Iterate 'step' at most @k@ times from @mem@ and return the concatenation
-- of every generation, starting with @mem@ itself.
--
-- Iteration stops early once a step produces no stars. A bound of 0 returns
-- @mem@ unchanged. NOTE(review): a negative bound never reaches the 0 base
-- case, so iteration then runs until 'step' returns an empty list.
steps :: Int -> Constellation -> [Star] -> [Star]
steps 0 _ mem = mem
steps k cs mem
  | null next = mem
  | otherwise = mem ++ steps (k - 1) cs next
  where
    next = step cs mem

-- | Run 'steps' on a constellation, seeded with its own stars.
--
-- NOTE(review): the step bound is 0, so 'steps' returns the seed as-is and
-- this currently yields the input stars unchanged. Confirm whether a
-- different bound was intended.
exec :: Constellation -> Constellation
exec c@(Const seeds) = Const (steps 0 c seeds)