-- vim: et ts=2 sw=2 ai:
{-# LANGUAGE LambdaCase #-}
module SPL.Lex
where

import Prelude hiding(lex)

import Control.Applicative
import Control.Monad
import Data.Char
import Data.List.Utils

-- | A single lexical token of the source language, as produced by 'lex'.
data Token
  -- Literals and names; each carries the value read from the source text.
  = TIdent String   -- a letter followed by letters, digits or underscores
  | TInt Int        -- a run of decimal digits
  | TChar Char      -- one character between single quotes
  | TBool Bool      -- the words True / False

  -- Punctuation: ( ) [ ] { } :: = ; . , ->
  | TParenOpen
  | TParenClose
  | TBrackOpen
  | TBrackClose
  | TBraceOpen
  | TBraceClose
  | TColonColon
  | TEquals
  | TSemicolon
  | TDot
  | TComma
  | TArrow

  -- Reserved words: if while return var Void Char Bool Int
  | TIf
  | TWhile
  | TReturn
  | TVar
  | TVoidType
  | TCharType
  | TBoolType
  | TIntType

  -- Operators: + - * / % == < > <= >= != && || : !
  | TPlus
  | TMinus
  | TAsterisk
  | TSlash
  | TPercent
  | TEqEq
  | TLt
  | TGt
  | TLtEq
  | TGtEq
  | TExclamEq
  | TAmpAmp
  | TPipePipe
  | TColon
  | TExclam

  -- Comments; the payload is the comment text without its delimiters.
  | TSingleComment String   -- text after "//" up to (not including) the newline
  | TBlockComment String    -- text between "/*" and "*/"
  deriving (Show, Eq)

-- | Whether a token is a comment, i.e. built with 'TSingleComment' or
-- 'TBlockComment'. Every other token yields 'False'.
isCommentToken :: Token -> Bool
isCommentToken tok = case tok of
  TSingleComment _ -> True
  TBlockComment _  -> True
  _                -> False

-- | Whether a token is an identifier, i.e. built with 'TIdent'.
isIdentToken :: Token -> Bool
isIdentToken tok = case tok of
  TIdent _ -> True
  _        -> False

-- | Split source text into a list of tokens.
--
-- Whitespace between tokens is skipped. Comments are not dropped: they are
-- returned as 'TSingleComment' / 'TBlockComment' tokens. The result is
-- 'empty' when the remaining input matches no token rule, or when a block
-- comment is opened with @/*@ but never closed with @*/@.
lex :: (Monad m, Alternative m) => String -> m [Token]
lex [] = pure []
lex (c:s) | isSpace c = lex s
lex s = next s >>= \(t, s') -> (t :) <$> lex s'
  where
    -- Lex exactly one token from the front of the input and return it together
    -- with the unconsumed remainder.
    next :: (Alternative m) => String -> m (Token, String)
    -- Input starting with "/*" is committed to being a block comment: when it
    -- is unterminated we fail here rather than falling through to 'item',
    -- which would lex the opener as TSlash followed by TAsterisk.
    next str@('/':'*':_) = comment str
    -- Order matters: comments before 'item' so "//" is not read as TSlash;
    -- 'item' and 'bool' before 'ident' so reserved words win over identifiers.
    next str =
      comment str <|> item str <|> int str
        <|> char str <|> bool str <|> ident str

    -- An identifier: one letter followed by any number of identifier characters.
    ident :: (Alternative m) => String -> m (Token, String)
    ident (c:s)
      | isAlpha c = pure (TIdent (c:cs), s')
      | otherwise = empty
      where (cs,s') = span isIdentChar s
    ident [] = empty

    -- A non-negative integer literal: a maximal run of decimal digits.
    -- 'read' cannot fail to parse here because it only ever sees digits.
    -- NOTE(review): literals outside the range of Int are not rejected —
    -- confirm whether that should be a lex error.
    int :: (Alternative m) => String -> m (Token, String)
    int (c:s)
      | isDigit c = pure (TInt $ read (c:cs), s')
      | otherwise = empty
      where (cs,s') = span isDigit s
    int [] = empty

    -- A character literal: exactly one character between single quotes.
    -- No escape sequences are interpreted.
    char :: (Alternative m) => String -> m (Token, String)
    char ('\'':c:'\'':s) = pure (TChar c, s)
    char _ = empty

    -- The boolean literals; rejected when directly followed by an identifier
    -- character (e.g. "Truest"), so that 'ident' can take the whole word.
    bool :: (Alternative m) => String -> m (Token, String)
    bool ('F':'a':'l':'s':'e':s) = noIdentifier (TBool False) s
    bool ('T':'r':'u':'e':s) = noIdentifier (TBool True) s
    bool _ = empty

    -- Line comments run up to (not including) the next newline or end of
    -- input. Block comments run up to the first "*/"; both delimiters are
    -- consumed and excluded from the token text.
    comment :: (Alternative m) => String -> m (Token, String)
    comment ('/':'/':s) = pure (TSingleComment cs, s')
      where
        (cs, s') = span (/= '\n') s
    comment ('/':'*':s) =
      -- 'spanList' stops at the first suffix that starts with "*/". If the
      -- input runs out first the remainder is empty, which means the comment
      -- was never closed: report failure instead of crashing on a partial
      -- pattern match.
      case spanList insideComment s of
        (cs, _:_:s') -> pure (TBlockComment cs, s')
        _            -> empty
      where
        insideComment :: String -> Bool
        insideComment ('*':'/':_) = False
        insideComment _           = True
    comment _ = empty

    -- Reserved words, punctuation and operators. Two-character symbols are
    -- listed before their one-character prefixes so the longer one wins.
    item :: (Alternative m) => String -> m (Token, String)
    item ('i':'f':s) = noIdentifier TIf s
    item ('w':'h':'i':'l':'e':s) = noIdentifier TWhile s
    item ('r':'e':'t':'u':'r':'n':s) = noIdentifier TReturn s
    item ('v':'a':'r':s) = noIdentifier TVar s
    item ('V':'o':'i':'d':s) = noIdentifier TVoidType s
    item ('C':'h':'a':'r':s) = noIdentifier TCharType s
    item ('B':'o':'o':'l':s) = noIdentifier TBoolType s
    item ('I':'n':'t':s) = noIdentifier TIntType s
    item ('=':'=':s) = pure (TEqEq, s)
    item ('<':'=':s) = pure (TLtEq, s)
    item ('>':'=':s) = pure (TGtEq, s)
    item ('!':'=':s) = pure (TExclamEq, s)
    item ('&':'&':s) = pure (TAmpAmp, s)
    item ('|':'|':s) = pure (TPipePipe, s)
    item (':':':':s) = pure (TColonColon, s)
    item ('-':'>':s) = pure (TArrow, s)
    item ('(':s) = pure (TParenOpen, s)
    item (')':s) = pure (TParenClose, s)
    item ('[':s) = pure (TBrackOpen, s)
    item (']':s) = pure (TBrackClose, s)
    item ('{':s) = pure (TBraceOpen, s)
    item ('}':s) = pure (TBraceClose, s)
    item ('=':s) = pure (TEquals, s)
    item (';':s) = pure (TSemicolon, s)
    item ('.':s) = pure (TDot, s)
    item (',':s) = pure (TComma, s)
    item ('+':s) = pure (TPlus, s)
    item ('-':s) = pure (TMinus, s)
    item ('*':s) = pure (TAsterisk, s)
    item ('/':s) = pure (TSlash, s)
    item ('%':s) = pure (TPercent, s)
    item ('<':s) = pure (TLt, s)
    item ('>':s) = pure (TGt, s)
    item (':':s) = pure (TColon, s)
    item ('!':s) = pure (TExclam, s)
    item _       = empty

    -- Accept a word-like token only when the next character cannot continue
    -- an identifier, so that e.g. "iffy" is not split into TIf and "fy".
    noIdentifier :: (Alternative m) => Token -> String -> m (Token, String)
    noIdentifier t [] = pure (t, [])
    noIdentifier t s@(c:_)
      | isIdentChar c = empty
      | otherwise     = pure (t, s)

    -- Characters allowed after the first letter of an identifier: letters,
    -- digits and underscore. 'liftM2' is used in the function monad, i.e.
    -- this is @\c -> isAlphaNum c || c == '_'@.
    isIdentChar :: Char -> Bool
    isIdentChar = liftM2 (||) isAlphaNum (== '_')