module Unification where
import PrettyPrinter
import Data.List
import Data.Maybe
import Control.Arrow

{- ========================================
    Definitions
   ======================================== -}

-- | Names of variables and function symbols.
type Id = String

-- | First-order terms: a variable, or a function symbol applied to a
-- (possibly empty) list of argument terms. Constants are represented as
-- nullary applications such as @Func "a" []@.
data Term
    = Var Id
    | Func Id [Term]
    deriving (Eq)

-- | A substitution, as an association list from variable names to terms.
type Subst = [(Id, Term)]

-- | An equation @(s, t)@ between two terms, to be solved by unification.
type Equation = (Term, Term)

-- | Render a term in prefix notation: a variable or constant prints as its
-- bare name, and an application prints as the symbol followed by its
-- parenthesised arguments, joined with 'addcomma' (from PrettyPrinter).
instance Show Term where
    show term = case term of
        Var x       -> x
        Func f []   -> f
        Func f args -> concat [f, "(", addcomma (map show args), ")"]

-- | Total order on terms, consistent with the derived 'Eq':
-- every variable sorts before every function term, variables compare by
-- name, and function terms compare by symbol first and then
-- lexicographically by their argument lists.
instance Ord Term where
    compare lhs rhs = case (lhs, rhs) of
        (Var x,     Var y)     -> compare x y
        (Var _,     Func _ _)  -> LT
        (Func _ _,  Var _)     -> GT
        -- Tuple comparison is lexicographic: symbol first, then arguments.
        (Func f fs, Func g gs) -> compare (f, fs) (g, gs)

{- ========================================
    Predicates
   ======================================== -}

-- | Is variable @x@ bound (in the domain of) the given substitution?
indom :: Id -> Subst -> Bool
indom x = any ((== x) . fst)

-- | Occurs check: does variable @x@ appear anywhere inside the term?
occurs :: Id -> Term -> Bool
occurs x = go
  where
    go term = case term of
        Var y     -> y == x
        Func _ ts -> any go ts

{- ========================================
    Renaming
   ======================================== -}

-- | Append the suffix @e@ to the name of every variable in a term.
-- Function symbols are left untouched. 'matchable' uses this to rename two
-- terms apart with different suffixes.
extends_varname :: String -> Term -> Term
extends_varname e = rename
  where
    rename (Var x)     = Var (x ++ e)
    rename (Func f ts) = Func f (map rename ts)

-- | Every variable occurrence in a term, in left-to-right order.
--
-- Duplicates are kept: a variable occurring twice is listed twice.
vars :: Term -> [Id]
vars (Var x) = [x]
vars (Func _ xs) = concatMap vars xs

{- ========================================
    Substitution
   ======================================== -}

-- | Image of variable @x@ under substitution @s@.
--
-- If @x@ is bound more than once, the first binding wins. Variables outside
-- the domain of @s@ are mapped to themselves (as 'subst' does), rather than
-- failing with a non-exhaustive pattern error on the empty list.
apply :: Subst -> Id -> Term
apply s x = fromMaybe (Var x) (lookup x s)

-- | Apply a substitution to every variable occurrence in a term.
--
-- Unbound variables are left unchanged, and the first binding wins when a
-- variable is bound more than once. The substitution is applied once: the
-- images it inserts are not themselves rewritten.
subst :: Subst -> Term -> Term
-- A single 'lookup' replaces the former 'indom' + 'apply' pair, which
-- walked the list twice and relied on the guard to avoid 'apply' failing.
subst s (Var x) = fromMaybe (Var x) (lookup x s)
subst s (Func f ts) = Func f (map (subst s) ts)

{- ========================================
    Unification
   ======================================== -}

-- | Worker for 'solution': solve the pending equations, threading through
-- the substitution accumulated so far.
--
-- Returns 'Nothing' on a symbol clash (different function symbols or
-- different arities), or when 'elim' fails the occurs check.
solve' :: [Equation] -> Subst -> Maybe Subst
solve' [] s = Just s
solve' ((Var x, t) : ps) s
    | Var x == t = solve' ps s          -- trivial equation x = x: drop it
    | otherwise  = elim x t ps s
-- The left side is not a variable here, so orient the equation as x = t.
solve' ((t, Var x) : ps) s = elim x t ps s
solve' ((Func f fs, Func g gs) : ps) s
    -- Decompose only when both symbol and arity agree. Otherwise 'zip'
    -- silently drops the surplus arguments and would accept f(x) = f(x, y).
    | f == g && length fs == length gs = solve' (zip fs gs ++ ps) s
    | otherwise                        = Nothing

-- | Eliminate variable @x@ by binding it to @t@.
--
-- Fails with 'Nothing' if @x@ occurs in @t@ (occurs check). Otherwise it
-- applies @[x := t]@ to both sides of every remaining equation and to the
-- range of the accumulated substitution, records @x := t@, and continues
-- solving.
elim :: Id -> Term -> [Equation] -> Subst -> Maybe Subst
elim x t ps s
    | occurs x t = Nothing
    | otherwise  = solve' pending bindings
  where
    sigma    = subst [(x, t)]
    pending  = [(sigma lhs, sigma rhs) | (lhs, rhs) <- ps]
    bindings = (x, t) : [(y, sigma u) | (y, u) <- s]

-- | Solve a list of equations, starting from the empty substitution.
-- Returns the resulting substitution, or 'Nothing' if the equations have
-- no solution.
solution :: [Equation] -> Maybe Subst
solution = flip solve' []

-- | Does the list of equations have a solution?
solvable :: [Equation] -> Bool
solvable = isJust . solution

-- | Rename the two sides of an equation apart, then test whether they unify.
--
-- Every variable on the left gets suffix "0" and every variable on the
-- right gets suffix "1", so the two sides never share a variable.
-- NOTE(review): despite the name, this checks unifiability of the
-- renamed-apart pair, not one-way matching; confirm the intended semantics.
matchable :: Equation -> Bool
matchable (lhs, rhs) =
    solvable [(extends_varname "0" lhs, extends_varname "1" rhs)]