aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Sjit/Compile.dcl19
-rw-r--r--Sjit/Compile.icl81
-rw-r--r--Sjit/Run.icl4
-rw-r--r--Sjit/Syntax.dcl1
-rw-r--r--Sjit/Syntax.icl7
-rw-r--r--sjit_c.c35
-rw-r--r--test/fib.result1
-rw-r--r--test/fib.test2
8 files changed, 114 insertions, 36 deletions
diff --git a/Sjit/Compile.dcl b/Sjit/Compile.dcl
index 7cb74e3..fd4b2ee 100644
--- a/Sjit/Compile.dcl
+++ b/Sjit/Compile.dcl
@@ -12,6 +12,8 @@ from Sjit.Syntax import :: Function
| Pop !Int
| Call !Int
+ | Jmp !Int
+ | JmpTrue !Int
| Ret
| Halt
@@ -20,16 +22,19 @@ from Sjit.Syntax import :: Function
| ISubRet
| IDivRet
+ | PlaceHolder !Int !Int // only used during compilation
+
:: Program :== {!Instr}
:: CompileState =
- { vars :: !Map String Int
- , funs :: !Map String Int
- , sp :: !Int
- , pc :: !Int
- , blocks :: ![!Program!]
- , new_block :: ![!Instr!]
- , jitst :: !JITState
+ { vars :: !Map String Int
+ , funs :: !Map String Int
+ , sp :: !Int
+ , pc :: !Int
+ , blocks :: ![!Program!]
+ , new_block :: ![!Instr!]
+ , placeholder :: !Int
+ , jitst :: !JITState
}
:: JITState =
diff --git a/Sjit/Compile.icl b/Sjit/Compile.icl
index 4f10096..55b8ff3 100644
--- a/Sjit/Compile.icl
+++ b/Sjit/Compile.icl
@@ -2,14 +2,15 @@ implementation module Sjit.Compile
import StdEnv
import StdGeneric
-import StdMaybe
import StdOverloadedList
import Control.Applicative
import Control.Monad
import Data.Either
from Data.Func import mapSt, $
+import Data.Functor
from Data.Map import :: Map(..), get, put, newMap, fromList
+import Data.Maybe
import Sjit.Syntax
@@ -40,13 +41,14 @@ bootstrap
# is = {i \\ i <- flatten [is \\ (_,is) <- header]}
=
( is,
- { vars = newMap
- , funs = fromList bs_funs
- , sp = 0
- , pc = len_bs
- , blocks = [!is!]
- , new_block = [!!]
- , jitst = appendProgram False is (initJITState 1000)
+ { vars = newMap
+ , funs = fromList bs_funs
+ , sp = 0
+ , pc = len_bs
+ , blocks = [!is!]
+ , new_block = [!!]
+ , placeholder = 0
+ , jitst = appendProgram False is (initJITState 1000)
})
where
bootstrap_funs :: (!Int, ![(String, Int)])
@@ -94,32 +96,51 @@ where
add i cs = {cs & new_block=[!i:cs.new_block!], sp=sp, pc=cs.pc+1}
where
sp = cs.sp + case i of
- PushRef _ -> 1
- PushI _ -> 1
- Put _ -> -1
- Pop n -> 0-n
- Call _ -> 1
- JmpRelTrue _ -> 0
- Ret -> -1
- Halt -> -2
- IAddRet -> -1
- IMulRet -> -1
- ISubRet -> -1
- IDivRet -> -1
+ PushRef _ -> 1
+ PushI _ -> 1
+ Put _ -> -1
+ Pop n -> 0-n
+ Call _ -> 0
+ Jmp _ -> 0
+ JmpTrue _ -> 0
+ Ret -> -1
+ Halt -> -2
+ IAddRet -> -1
+ IMulRet -> -1
+ ISubRet -> -1
+ IDivRet -> -1
+ PlaceHolder _ n -> n
+
+reserve :: !Int !CompileState -> m (!Int, !CompileState) | Monad m
+reserve stack_effect cs=:{placeholder=p} =
+ gen (PlaceHolder p stack_effect) {cs & placeholder=p+1} >>= \cs -> pure (p,cs)
+
+fillPlaceHolder :: !Int !Instr !CompileState -> Either String CompileState
+fillPlaceHolder p newi cs = case replace cs.new_block of
+ Nothing -> Left "internal error with placeholder"
+ Just nb -> Right {cs & new_block=nb}
+where
+ replace :: ![!Instr!] -> Maybe [!Instr!]
+ replace [!PlaceHolder n _:is!] | n==p = Just [!newi:is!]
+ replace [!i:is!] = (\is -> [!i:is!]) <$> replace is
+ replace [!!] = Nothing
compile :: !Function !CompileState -> Either String CompileState
compile f cs
+# cs & sp = 0
# cs & funs = put f.fun_name cs.pc cs.funs
# vars = cs.vars
-# cs & vars = foldr (uncurry put) cs.vars [(v,sp) \\ v <- f.fun_args & sp <- [cs.sp+1..]]
-= case expr f.fun_expr cs of
+# cs & vars = foldr (uncurry put) cs.vars [(v,sp) \\ v <- f.fun_args & sp <- [1..]]
+# nargs = max 1 (length f.fun_args)
+= case expr f.fun_expr cs >>= gen [Ret,Put nargs] of
Left e -> Left e
Right cs
- # is = {i \\ i <|- Reverse [!Ret:Put (max 1 (length f.fun_args)+1):cs.new_block!]}
+ | cs.sp <> -1 -> Left ("sp was " +++ toString cs.sp +++ " after compiling '" +++ f.fun_name +++ "'")
+ # is = {i \\ i <|- Reverse cs.new_block}
-> Right
{ cs
& vars = vars
- , pc = cs.pc+2
+ , pc = cs.pc
, blocks = cs.blocks ++| [!is!]
, new_block = [!!]
, jitst = appendProgram (f.fun_name == "main") is cs.jitst
@@ -129,13 +150,21 @@ where
expr (Int i) cs = gen (PushI i) cs
expr (Bool b) cs = gen (PushI (if b 1 0)) cs
expr (Var v) cs = case get v cs.vars of
- Just i -> gen (PushRef (i-cs.sp)) cs
+ Just i -> gen (PushRef (i+cs.sp)) cs
Nothing -> Left ("undefined variable '" +++ v +++ "'")
expr (App f args) cs
# args = if (args=:[]) [Int 0] args
- = foldM (flip expr) {cs & sp=cs.sp+1} (reverse args) >>= \cs -> case get f cs.funs of
+ = foldM (flip expr) cs (reverse args) >>= \cs -> case get f cs.funs of
Nothing -> Left ("undefined function '" +++ toString f +++ "'")
Just f -> gen [Pop (length args-1),Call f] cs
+ expr (If b t e) cs =
+ expr b cs >>=
+ reserve -1 >>= \(jmptrue,cs=:{sp=orgsp}) ->
+ expr e cs >>=
+ reserve 0 >>= \(jmpend,cs) ->
+ fillPlaceHolder jmptrue (JmpTrue cs.pc) {cs & sp=orgsp} >>=
+ expr t >>= \cs ->
+ fillPlaceHolder jmpend (Jmp cs.pc) cs
generic gEncodedSize a :: !a -> Int
gEncodedSize{|Int|} _ = 1
diff --git a/Sjit/Run.icl b/Sjit/Run.icl
index b5858ec..e3e623f 100644
--- a/Sjit/Run.icl
+++ b/Sjit/Run.icl
@@ -23,6 +23,10 @@ where
[val:stack] -> exec (i+1) (take (n-1) stack ++ [val:drop n stack])
Pop n -> exec (i+1) (drop n stack)
Call f -> exec f [i+1:stack]
+ Jmp f -> exec f stack
+ JmpTrue f -> case stack of
+ [0:stack] -> exec (i+1) stack
+ [_:stack] -> exec f stack
Ret -> case stack of
[ret:stack] -> exec ret stack
Halt -> case stack of
diff --git a/Sjit/Syntax.dcl b/Sjit/Syntax.dcl
index 6a9f056..64895f6 100644
--- a/Sjit/Syntax.dcl
+++ b/Sjit/Syntax.dcl
@@ -7,6 +7,7 @@ from Data.Either import :: Either
| Bool !Bool
| Var !String
| App !String ![Expr]
+ | If !Expr !Expr !Expr
:: Function =
{ fun_name :: !String
diff --git a/Sjit/Syntax.icl b/Sjit/Syntax.icl
index 7412d7c..90aeb39 100644
--- a/Sjit/Syntax.icl
+++ b/Sjit/Syntax.icl
@@ -16,6 +16,8 @@ import Text.Parsers.Simple.Core
| TTrue
| TFalse
+ | TIf
+
| TEq
| TComma
@@ -32,6 +34,7 @@ where
TInt n -> toString n
TTrue -> "True"
TFalse -> "False"
+ TIf -> "if"
TEq -> "="
TComma -> ","
TParenOpen -> "("
@@ -51,6 +54,7 @@ where
# tk = case n of
"True" -> TTrue
"False" -> TFalse
+ "if" -> TIf
n -> TIdent n
-> lex [tk:tks] i e s
@@ -112,7 +116,8 @@ where
noInfix :: Parser Token Expr
noInfix =
- liftM2 App ident (pToken TParenOpen *> pSepBy expr (pToken TComma) <* pToken TParenClose)
+ liftM2 App ident (pToken TParenOpen *> pSepBy expr (pToken TComma) <* pToken TParenClose)
+ <|> liftM3 If (pToken TIf *> expr) expr expr
<|> Var <$> ident
<|> Int <$> int
<|> Bool <$> bool
diff --git a/sjit_c.c b/sjit_c.c
index 879ec18..325de47 100644
--- a/sjit_c.c
+++ b/sjit_c.c
@@ -11,6 +11,8 @@ enum instr {
Pop,
Call,
+ Jmp,
+ JmpTrue,
Ret,
Halt,
@@ -28,6 +30,8 @@ static inline uint32_t instr_size(enum instr instr) {
case Pop: return 4;
case Call: return 5;
+ case Jmp: return 5;
+ case JmpTrue: return 1+3+6;
case Ret: return 1;
case Halt: return 1+1;
@@ -52,7 +56,7 @@ static inline void gen_instr(char *full_code, char **code_p, uint64_t **pgm_p, u
case PushRef:
arg=pgm[1];
#ifdef DEBUG_JIT_INSTRUCTIONS
- fprintf(stderr,"PushRef %lu\n",arg);
+ fprintf(stderr,"PushRef %ld\n",arg);
#endif
code[0]='\x48'; /* mov rcx,[rsp+ARG*8] */
code[1]='\x8b';
@@ -86,7 +90,7 @@ static inline void gen_instr(char *full_code, char **code_p, uint64_t **pgm_p, u
code[2]='\x89';
code[3]='\x4c';
code[4]='\x24';
- code[5]=(unsigned char)(arg-1)*8;
+ code[5]=(unsigned char)arg*8;
pgm+=2;
code+=6;
break;
@@ -113,6 +117,31 @@ static inline void gen_instr(char *full_code, char **code_p, uint64_t **pgm_p, u
pgm+=2;
code+=5;
break;
+ case Jmp:
+ arg=pgm[1];
+#ifdef DEBUG_JIT_INSTRUCTIONS
+ fprintf(stderr,"Jmp %lu -> %d\n",arg,mapping[arg]-(uint32_t)(&code[5]-full_code));
+#endif
+ code[0]='\xe9'; /* jmpq ARG */
+ *(uint32_t*)&code[1]=mapping[arg]-(&code[5]-full_code);
+ pgm+=2;
+ code+=5;
+ break;
+ case JmpTrue:
+ arg=pgm[1];
+#ifdef DEBUG_JIT_INSTRUCTIONS
+ fprintf(stderr,"JmpTrue %lu -> %d\n",arg,mapping[arg]-(uint32_t)(&code[10]-full_code));
+#endif
+ code[0]='\x59'; /* pop rcx */
+ code[1]='\x48'; /* test rcx,rcx */
+ code[2]='\x85';
+ code[3]='\xc9';
+ code[4]='\x0f'; /* jne ARG */
+ code[5]='\x85';
+ *(uint32_t*)&code[6]=mapping[arg]-(&code[10]-full_code);
+ pgm+=2;
+ code+=10;
+ break;
case Ret:
#ifdef DEBUG_JIT_INSTRUCTIONS
fprintf(stderr,"Ret\n");
@@ -206,6 +235,8 @@ char *jit_append(char *code_block, uint32_t code_len, char *code_ptr,
case Put:
case Pop:
case Call:
+ case Jmp:
+ case JmpTrue:
pgm_p+=2;
break;
default:
diff --git a/test/fib.result b/test/fib.result
new file mode 100644
index 0000000..08c2ab3
--- /dev/null
+++ b/test/fib.result
@@ -0,0 +1 @@
+1346269
diff --git a/test/fib.test b/test/fib.test
new file mode 100644
index 0000000..cb5a5fa
--- /dev/null
+++ b/test/fib.test
@@ -0,0 +1,2 @@
+fib n = if n (if (n-1) (fib(n-1) + fib(n-2)) 1) 1
+fib(30)