diff options
-rw-r--r-- | Sjit/Compile.dcl | 19 | ||||
-rw-r--r-- | Sjit/Compile.icl | 81 | ||||
-rw-r--r-- | Sjit/Run.icl | 4 | ||||
-rw-r--r-- | Sjit/Syntax.dcl | 1 | ||||
-rw-r--r-- | Sjit/Syntax.icl | 7 | ||||
-rw-r--r-- | sjit_c.c | 35 | ||||
-rw-r--r-- | test/fib.result | 1 | ||||
-rw-r--r-- | test/fib.test | 2 |
8 files changed, 114 insertions, 36 deletions
diff --git a/Sjit/Compile.dcl b/Sjit/Compile.dcl index 7cb74e3..fd4b2ee 100644 --- a/Sjit/Compile.dcl +++ b/Sjit/Compile.dcl @@ -12,6 +12,8 @@ from Sjit.Syntax import :: Function | Pop !Int | Call !Int + | Jmp !Int + | JmpTrue !Int | Ret | Halt @@ -20,16 +22,19 @@ from Sjit.Syntax import :: Function | ISubRet | IDivRet + | PlaceHolder !Int !Int // only used during compilation + :: Program :== {!Instr} :: CompileState = - { vars :: !Map String Int - , funs :: !Map String Int - , sp :: !Int - , pc :: !Int - , blocks :: ![!Program!] - , new_block :: ![!Instr!] - , jitst :: !JITState + { vars :: !Map String Int + , funs :: !Map String Int + , sp :: !Int + , pc :: !Int + , blocks :: ![!Program!] + , new_block :: ![!Instr!] + , placeholder :: !Int + , jitst :: !JITState } :: JITState = diff --git a/Sjit/Compile.icl b/Sjit/Compile.icl index 4f10096..55b8ff3 100644 --- a/Sjit/Compile.icl +++ b/Sjit/Compile.icl @@ -2,14 +2,15 @@ implementation module Sjit.Compile import StdEnv import StdGeneric -import StdMaybe import StdOverloadedList import Control.Applicative import Control.Monad import Data.Either from Data.Func import mapSt, $ +import Data.Functor from Data.Map import :: Map(..), get, put, newMap, fromList +import Data.Maybe import Sjit.Syntax @@ -40,13 +41,14 @@ bootstrap # is = {i \\ i <- flatten [is \\ (_,is) <- header]} = ( is, - { vars = newMap - , funs = fromList bs_funs - , sp = 0 - , pc = len_bs - , blocks = [!is!] - , new_block = [!!] - , jitst = appendProgram False is (initJITState 1000) + { vars = newMap + , funs = fromList bs_funs + , sp = 0 + , pc = len_bs + , blocks = [!is!] + , new_block = [!!] + , placeholder = 0 + , jitst = appendProgram False is (initJITState 1000) }) where bootstrap_funs :: (!Int, ![(String, Int)]) @@ -94,32 +96,51 @@ where add i cs = {cs & new_block=[!i:cs.new_block!], sp=sp, pc=cs.pc+1} where sp = cs.sp + case i of - PushRef _ -> 1 - PushI _ -> 1 - Put _ -> -1 - Pop n -> 0-n - Call _ -> 1 - JmpRelTrue _ -> 0 - Ret -> -1 - Halt -> -2 - IAddRet -> -1 - IMulRet -> -1 - ISubRet -> -1 - IDivRet -> -1 + PushRef _ -> 1 + PushI _ -> 1 + Put _ -> -1 + Pop n -> 0-n + Call _ -> 0 + Jmp _ -> 0 + JmpTrue _ -> 0 + Ret -> -1 + Halt -> -2 + IAddRet -> -1 + IMulRet -> -1 + ISubRet -> -1 + IDivRet -> -1 + PlaceHolder _ n -> n + +reserve :: !Int !CompileState -> m (!Int, !CompileState) | Monad m +reserve stack_effect cs=:{placeholder=p} = + gen (PlaceHolder p stack_effect) {cs & placeholder=p+1} >>= \cs -> pure (p,cs) + +fillPlaceHolder :: !Int !Instr !CompileState -> Either String CompileState +fillPlaceHolder p newi cs = case replace cs.new_block of + Nothing -> Left "internal error with placeholder" + Just nb -> Right {cs & new_block=nb} +where + replace :: ![!Instr!] -> Maybe [!Instr!] + replace [!PlaceHolder n _:is!] | n==p = Just [!newi:is!] + replace [!i:is!] = (\is -> [!i:is!]) <$> replace is + replace [!!] = Nothing compile :: !Function !CompileState -> Either String CompileState compile f cs +# cs & sp = 0 # cs & funs = put f.fun_name cs.pc cs.funs # vars = cs.vars -# cs & vars = foldr (uncurry put) cs.vars [(v,sp) \\ v <- f.fun_args & sp <- [cs.sp+1..]] -= case expr f.fun_expr cs of +# cs & vars = foldr (uncurry put) cs.vars [(v,sp) \\ v <- f.fun_args & sp <- [1..]] +# nargs = max 1 (length f.fun_args) += case expr f.fun_expr cs >>= gen [Ret,Put nargs] of Left e -> Left e Right cs - # is = {i \\ i <|- Reverse [!Ret:Put (max 1 (length f.fun_args)+1):cs.new_block!]} + | cs.sp <> -1 -> Left ("sp was " +++ toString cs.sp +++ " after compiling '" +++ f.fun_name +++ "'") + # is = {i \\ i <|- Reverse cs.new_block} -> Right { cs & vars = vars - , pc = cs.pc+2 + , pc = cs.pc , blocks = cs.blocks ++| [!is!] , new_block = [!!] , jitst = appendProgram (f.fun_name == "main") is cs.jitst @@ -129,13 +150,21 @@ where expr (Int i) cs = gen (PushI i) cs expr (Bool b) cs = gen (PushI (if b 1 0)) cs expr (Var v) cs = case get v cs.vars of - Just i -> gen (PushRef (i-cs.sp)) cs + Just i -> gen (PushRef (i+cs.sp)) cs Nothing -> Left ("undefined variable '" +++ v +++ "'") expr (App f args) cs # args = if (args=:[]) [Int 0] args - = foldM (flip expr) {cs & sp=cs.sp+1} (reverse args) >>= \cs -> case get f cs.funs of + = foldM (flip expr) cs (reverse args) >>= \cs -> case get f cs.funs of Nothing -> Left ("undefined function '" +++ toString f +++ "'") Just f -> gen [Pop (length args-1),Call f] cs + expr (If b t e) cs = + expr b cs >>= + reserve -1 >>= \(jmptrue,cs=:{sp=orgsp}) -> + expr e cs >>= + reserve 0 >>= \(jmpend,cs) -> + fillPlaceHolder jmptrue (JmpTrue cs.pc) {cs & sp=orgsp} >>= + expr t >>= \cs -> + fillPlaceHolder jmpend (Jmp cs.pc) cs generic gEncodedSize a :: !a -> Int gEncodedSize{|Int|} _ = 1 diff --git a/Sjit/Run.icl b/Sjit/Run.icl index b5858ec..e3e623f 100644 --- a/Sjit/Run.icl +++ b/Sjit/Run.icl @@ -23,6 +23,10 @@ where [val:stack] -> exec (i+1) (take (n-1) stack ++ [val:drop n stack]) Pop n -> exec (i+1) (drop n stack) Call f -> exec f [i+1:stack] + Jmp f -> exec f stack + JmpTrue f -> case stack of + [0:stack] -> exec (i+1) stack + [_:stack] -> exec f stack Ret -> case stack of [ret:stack] -> exec ret stack Halt -> case stack of diff --git a/Sjit/Syntax.dcl b/Sjit/Syntax.dcl index 6a9f056..64895f6 100644 --- a/Sjit/Syntax.dcl +++ b/Sjit/Syntax.dcl @@ -7,6 +7,7 @@ from Data.Either import :: Either | Bool !Bool | Var !String | App !String ![Expr] + | If !Expr !Expr !Expr :: Function = { fun_name :: !String diff --git a/Sjit/Syntax.icl b/Sjit/Syntax.icl index 7412d7c..90aeb39 100644 --- a/Sjit/Syntax.icl +++ b/Sjit/Syntax.icl @@ -16,6 +16,8 @@ import Text.Parsers.Simple.Core | TTrue | TFalse + | TIf + | TEq | TComma @@ -32,6 +34,7 @@ where TInt n -> toString n TTrue -> "True" TFalse -> "False" + TIf -> "if" TEq -> "=" TComma -> "," TParenOpen -> "(" @@ -51,6 +54,7 @@ where # tk = case n of "True" -> TTrue "False" -> TFalse + "if" -> TIf n -> TIdent n -> lex [tk:tks] i e s @@ -112,7 +116,8 @@ where noInfix :: Parser Token Expr noInfix = - liftM2 App ident (pToken TParenOpen *> pSepBy expr (pToken TComma) <* pToken TParenClose) + liftM2 App ident (pToken TParenOpen *> pSepBy expr (pToken TComma) <* pToken TParenClose) + <|> liftM3 If (pToken TIf *> expr) expr expr <|> Var <$> ident <|> Int <$> int <|> Bool <$> bool @@ -11,6 +11,8 @@ enum instr { Pop, Call, + Jmp, + JmpTrue, Ret, Halt, @@ -28,6 +30,8 @@ static inline uint32_t instr_size(enum instr instr) { case Pop: return 4; case Call: return 5; + case Jmp: return 5; + case JmpTrue: return 1+3+6; case Ret: return 1; case Halt: return 1+1; @@ -52,7 +56,7 @@ static inline void gen_instr(char *full_code, char **code_p, uint64_t **pgm_p, u case PushRef: arg=pgm[1]; #ifdef DEBUG_JIT_INSTRUCTIONS - fprintf(stderr,"PushRef %lu\n",arg); + fprintf(stderr,"PushRef %ld\n",arg); #endif code[0]='\x48'; /* mov rcx,[rsp+ARG*8] */ code[1]='\x8b'; @@ -86,7 +90,7 @@ static inline void gen_instr(char *full_code, char **code_p, uint64_t **pgm_p, u code[2]='\x89'; code[3]='\x4c'; code[4]='\x24'; - code[5]=(unsigned char)(arg-1)*8; + code[5]=(unsigned char)arg*8; pgm+=2; code+=6; break; @@ -113,6 +117,31 @@ static inline void gen_instr(char *full_code, char **code_p, uint64_t **pgm_p, u pgm+=2; code+=5; break; + case Jmp: + arg=pgm[1]; +#ifdef DEBUG_JIT_INSTRUCTIONS + fprintf(stderr,"Jmp %lu -> %d\n",arg,mapping[arg]-(uint32_t)(&code[5]-full_code)); +#endif + code[0]='\xe9'; /* jmpq ARG */ + *(uint32_t*)&code[1]=mapping[arg]-(&code[5]-full_code); + pgm+=2; + code+=5; + break; + case JmpTrue: + arg=pgm[1]; +#ifdef DEBUG_JIT_INSTRUCTIONS + fprintf(stderr,"JmpTrue %lu -> %d\n",arg,mapping[arg]-(uint32_t)(&code[10]-full_code)); +#endif + code[0]='\x59'; /* pop rcx */ + code[1]='\x48'; /* test rcx,rcx */ + code[2]='\x85'; + code[3]='\xc9'; + code[4]='\x0f'; /* jne ARG */ + code[5]='\x85'; + *(uint32_t*)&code[6]=mapping[arg]-(&code[10]-full_code); + pgm+=2; + code+=10; + break; case Ret: #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"Ret\n"); @@ -206,6 +235,8 @@ char *jit_append(char *code_block, uint32_t code_len, char *code_ptr, case Put: case Pop: case Call: + case Jmp: + case JmpTrue: pgm_p+=2; break; default: diff --git a/test/fib.result b/test/fib.result new file mode 100644 index 0000000..08c2ab3 --- /dev/null +++ b/test/fib.result @@ -0,0 +1 @@ +1346269 diff --git a/test/fib.test b/test/fib.test new file mode 100644 index 0000000..cb5a5fa --- /dev/null +++ b/test/fib.test @@ -0,0 +1,2 @@ +fib n = if n (if (n-1) (fib(n-1) + fib(n-2)) 1) 1 +fib(30) |