/**
 * sjit: compiles a tiny first-order expression language to a stack-machine
 * instruction set (`Instr`). The result can be run by the reference
 * interpreter `interpret` or by a JIT whose native part lives in the foreign
 * object file `sjit_c` (linked via `import code`, not visible here).
 *
 * Stack/calling convention shared by the compiler, the interpreter and the
 * built-in arithmetic instructions:
 * - the caller pushes the arguments last-to-first, so argument 1 ends up on
 *   top; a call to a nullary function pushes one placeholder slot instead;
 * - `Call` pushes the return address;
 * - the callee overwrites the deepest argument slot (or the placeholder) with
 *   its result and returns;
 * - the caller then pops the remaining n-1 argument slots, leaving only the
 *   result.
 */
module sjit

import StdEnv
import StdGeneric
import StdMaybe
import StdOverloadedList

from Data.Func import mapSt, $
from Data.Map import :: Map(..), get, put, newMap, fromList
import System.CommandLine

import code from "sjit_c."

/** Source expressions: integer literals, variables and first-order applications. */
:: Expr
	= Int !Int
	| Var !String
	| App !String ![Expr] // function name and arguments

/** A top-level function definition. */
:: Function =
	{ fun_name :: !String
	, fun_args :: ![String]
	, fun_expr :: !Expr
	}

/**
 * Stack-machine instructions. The semantics noted here are those implemented
 * by `interpret`; depth 0 is the top of the stack.
 */
:: Instr
	= PushRef !Int // push a copy of the element at the given depth
	| PushI !Int   // push an integer constant
	| Put !Int     // pop the top value v, then overwrite the element at depth n-1 with v
	| Pop !Int     // drop n elements
	| Call !Int    // push the return address (next pc) and jump to the given address
	| Ret          // pop a return address and jump to it
	| Halt         // stop; the stack must contain exactly one value, the result
	| IAddRet      // [ret,a,b:s] -> jump to ret with [a,a+b:s]
	| IMulRet      // [ret,a,b:s] -> jump to ret with [a,a*b:s]
	| ISubRet      // [ret,a,b:s] -> jump to ret with [a,a-b:s]
	| IDivRet      // [ret,a,b:s] -> jump to ret with [a,a/b:s]

:: Program :== {!Instr}

/** State threaded through compilation. */
:: CompileState =
	{ vars   :: !Map String Int   // variable -> compile-time stack position (see `compile`)
	, funs   :: !Map String Int   // function name -> code address
	, sp     :: !Int              // compile-time stack position of the current top of stack
	, pc     :: !Int              // address of the next instruction to be emitted
	, blocks :: ![!Program!]      // compiled code blocks, bootstrap first
	, jitst  :: !JITState
	}

/**
 * Handles into the native JIT, as returned by `init_jit` / `jit_append`.
 * The meaning of the pointer-valued fields is defined on the C side
 * (sjit_c); the comments below are assumptions to be confirmed there.
 */
:: JITState =
	{ n_instr    :: !Int // number of instructions appended so far
	, code_start :: !Int // presumably the start of the native code buffer — TODO confirm
	, code_len   :: !Int // size passed to init_jit as the code buffer length
	, code_ptr   :: !Int // value returned by the last jit_append (presumably the write position)
	, mapping    :: !Int // presumably an instruction -> native address table — TODO confirm
	}

/**
 * Encode `prog` (see `encode`) and hand it to the foreign `jit_append`
 * function together with the current JIT state.
 *
 * @param is_main Whether `prog` is the code of `main`. It is passed on to the
 *   C side (presumably so it can patch the bootstrap call to `main` — TODO
 *   confirm in sjit_c).
 * @param prog The instructions to append.
 * @param jitst The current JIT state.
 * @result The JIT state with the code pointer returned by `jit_append`, and
 *   the instruction count increased by `size prog`.
 */
appendProgram :: !Bool !Program !JITState -> JITState
appendProgram is_main prog jitst
# new_code_ptr = append jitst.code_start jitst.code_len jitst.code_ptr jitst.mapping jitst.n_instr (encode prog) is_main
=
	{ jitst
	& code_ptr = new_code_ptr
	, n_instr  = jitst.n_instr + size prog
	}
where
	append :: !Int !Int !Int !Int !Int !{#Int} !Bool -> Int
	append _ _ _ _ _ _ _ = code {
		ccall jit_append "pIppIAI:p"
	}

/**
 * The bootstrap code block and the initial compile state.
 *
 * The block is the concatenation of the `header` entries, laid out from
 * address 0. Entry `_` pushes a result slot, calls `main` and halts; the
 * `Call 0` at address 1 is a placeholder that `get_program` patches for the
 * interpreter. The arithmetic builtins follow it. The builtins are registered
 * in `funs` under their operator names.
 */
bootstrap :: (!Program, !CompileState)
bootstrap
# (len_bs, bs_funs) = bootstrap_funs
# is = {i \\ i <- flatten [is \\ (_,is) <- header]}
=
	( is
	, { vars   = newMap
	  , funs   = fromList bs_funs
	  , sp     = 0
	  , pc     = len_bs
	  , blocks = [!is!]
	  , jitst  = appendProgram False is (initJITState 1000)
	  }
	)
where
	// Total length of the header, and the start address of each entry.
	bootstrap_funs :: (!Int, ![(String, Int)])
	bootstrap_funs = iter 0 header
	where
		iter :: !Int ![(String, [Instr])] -> (!Int, ![(String, Int)])
		iter pc [] = (pc, [])
		iter pc [(name,is):rest]
		# fun = (name,pc)
		# (pc,funs) = iter (pc+length is) rest
		= (pc,[fun:funs])

	header :: [(!String, ![Instr])]
	header =
		[ ("_", [PushI 0,Call 0 /* main address */,Halt])
		, ("+", [IAddRet])
		, ("*", [IMulRet])
		, ("-", [ISubRet])
		, ("/", [IDivRet])
		]

/**
 * Initialise the native JIT through the foreign `init_jit` function.
 *
 * @param maxlen Passed to `init_jit`. `maxlen*10` is passed as the code
 *   buffer length and stored in `code_len`.
 * @result A JIT state with no instructions and `code_ptr` at `code_start`.
 */
initJITState :: !Int -> JITState
initJITState maxlen
# (code_start,mapping) = init maxlen (maxlen*10)
=
	{ n_instr    = 0
	, code_start = code_start
	, code_len   = maxlen*10
	, code_ptr   = code_start
	, mapping    = mapping
	}
where
	init :: !Int !Int -> (!Int, !Int)
	init _ _ = code {
		ccall init_jit "II:Vpp"
	}

/**
 * Compile one function and append its code both to `blocks` (for the
 * interpreter) and to the JIT.
 *
 * The function is registered in `funs` before its body is compiled, so it may
 * call itself. Otherwise it may only call functions compiled before it.
 *
 * Compile-time stack positions: at entry, the return address is the top of
 * the stack, at position `base` (the incoming `sp`). Argument k (1-based) lies
 * k slots below it, at position `base-k`. Every push increments `sp`, so a
 * variable at position `i` is found at depth `sp-i`.
 */
compile :: !Function !CompileState -> CompileState
compile f cs
# cs & funs = put f.fun_name cs.pc cs.funs
# vars = cs.vars
# base = cs.sp
# cs & vars = foldr (uncurry put) cs.vars [(v,base-k) \\ v <- f.fun_args & k <- [1..]]
# (is,cs) = expr f.fun_expr cs
// The body leaves its value on top. Write it into the deepest argument slot
// (or the caller's placeholder for nullary functions), then return.
# is = {i \\ i <- reverse [Ret:Put (max 1 (length f.fun_args)+1):is]}
=
	{ cs
	& vars   = vars
	, sp     = base
	, pc     = cs.pc+2
	, blocks = cs.blocks ++| [!is!]
	, jitst  = appendProgram (f.fun_name == "main") is cs.jitst
	}
where
	// Returns the instructions in REVERSE order. Each expression nets exactly
	// one pushed value, so `sp` grows by one.
	expr :: !Expr !CompileState -> (![Instr], !CompileState)
	expr (Int i) cs = ([PushI i], {cs & sp=cs.sp+1, pc=cs.pc+1})
	expr (Var v) cs = case get v cs.vars of
		Just i  -> ([PushRef (cs.sp-i)], {cs & sp=cs.sp+1, pc=cs.pc+1})
		Nothing -> abort "undefined variable\n"
	expr (App fname args) cs
	# sp0   = cs.sp
	# nargs = length args
	// Arguments run last-to-first, so compile them in that order to keep the
	// compile-time stack positions in step with the runtime stack.
	# (iss,cs) = mapSt expr (reverse args) cs
	// A nullary callee still writes its result below its return address, so
	// reserve a placeholder slot for it.
	# (arg_is,cs) = if (nargs == 0)
		([PushI 0], {cs & sp=cs.sp+1, pc=cs.pc+1})
		(flatten (reverse iss), cs)
	= case get fname cs.funs of
		// After the call and popping the n-1 leftover argument slots, exactly one
		// value (the result) remains above the stack position before the call.
		Just addr -> ([Pop (max 0 (nargs-1)):Call addr:arg_is], {cs & sp=sp0+1, pc=cs.pc+2})
		Nothing   -> abort "undefined function\n"

/**
 * Compile a list of functions in order.
 *
 * @param mcs The state to continue from, or `Nothing` to start from the
 *   bootstrap state.
 * @param funs The functions to compile.
 */
compile_all :: !(Maybe CompileState) ![Function] -> CompileState
compile_all mcs funs
# cs = case mcs of
	Just cs -> cs
	Nothing -> snd bootstrap
= foldl (flip compile) cs funs

/**
 * Run the compiled program in the reference interpreter, starting at address
 * 0 with an empty stack.
 *
 * @result The single value left on the stack at `Halt`. The interpreter aborts
 *   on an out-of-bounds pc or when `Halt` finds a different number of values.
 */
interpret :: !CompileState -> Int
interpret cs = exec 0 []
where
	prog = get_program cs
	sz = size prog

	exec :: !Int ![Int] -> Int
	exec i stack
	| i < 0 || i >= sz = abort "out of bounds\n"
	| otherwise = case prog.[i] of
		PushI n -> exec (i+1) [n:stack]
		PushRef r -> exec (i+1) [stack!!r:stack]
		Put n -> case stack of
			[val:stack] -> exec (i+1) (take (n-1) stack ++ [val:drop n stack])
		Pop n -> exec (i+1) (drop n stack)
		Call f -> exec f [i+1:stack]
		Ret -> case stack of
			[ret:stack] -> exec ret stack
		Halt -> case stack of
			[r] -> r
			_ -> abort (toString (length stack) +++ " values left on stack\n")
		IAddRet -> case stack of
			[ret:a:b:stack] -> exec ret [a:a+b:stack]
		IMulRet -> case stack of
			[ret:a:b:stack] -> exec ret [a:a*b:stack]
		ISubRet -> case stack of
			[ret:a:b:stack] -> exec ret [a:a-b:stack]
		IDivRet -> case stack of
			[ret:a:b:stack] -> exec ret [a:a/b:stack]

/**
 * Concatenate all compiled blocks into one program. The placeholder call at
 * address 1 (see `bootstrap`) is patched to call `main`. This fails (through
 * `fromJust`) if no function named `main` was compiled.
 */
get_program :: !CompileState -> Program
get_program cs
# prog = loop 0 cs.blocks (createArray (sum [size b \\ b <|- cs.blocks]) Halt)
# prog & [1] = Call (fromJust (get "main" cs.funs))
= prog
where
	loop :: !Int ![!Program!] !*Program -> .Program
	loop i [!b:bs!] prog
	# (i,prog) = copy i 0 (size b-1) b prog
	= loop i bs prog
	where
		// Copy the n+1 elements of b starting at b.[bi] to prog, starting at prog.[i].
		copy :: !Int !Int !Int !Program !*Program -> *(!Int, !*Program)
		copy i _ -1 _ prog = (i, prog)
		copy i bi n b prog = copy (i+1) (bi+1) (n-1) b {prog & [i]=b.[bi]}
	loop _ [!!] prog = prog

/**
 * Number of Ints `gEncode` writes for a value: 1 per Int, 1 per constructor
 * index, and 1 length prefix per array.
 */
generic gEncodedSize a :: !a -> Int
gEncodedSize{|Int|} _ = 1
gEncodedSize{|{!}|} fx xs = 1 + sum [fx x \\ x <-: xs]
gEncodedSize{|UNIT|} _ = 0
gEncodedSize{|PAIR|} fx fy (PAIR x y) = fx x + fy y
gEncodedSize{|EITHER|} fl _ (LEFT l) = fl l
gEncodedSize{|EITHER|} _ fr (RIGHT r) = fr r
gEncodedSize{|CONS|} fx (CONS x) = fx x + 1
gEncodedSize{|OBJECT|} fx (OBJECT x) = fx x

derive gEncodedSize Instr

/**
 * Flat Int encoding, written starting at the given index. An Int is written as
 * itself, a constructor as its index followed by its fields, and an array as
 * its length followed by its elements.
 *
 * @result The next free index and the updated array.
 */
generic gEncode a :: !a !Int !*{#Int} -> (!Int, !*{#Int})
gEncode{|Int|} n i arr = (i+1, {arr & [i]=n})
gEncode{|{!}|} fx xs i arr = walk 0 (i+1) {arr & [i]=sz}
where
	sz = size xs
	walk ai i arr
	| ai >= sz = (i,arr)
	# (i,arr) = fx xs.[ai] i arr
	= walk (ai+1) i arr
gEncode{|UNIT|} _ i arr = (i,arr)
gEncode{|PAIR|} fx fy (PAIR x y) i arr
# (i,arr) = fx x i arr
= fy y i arr
gEncode{|EITHER|} fl _ (LEFT l) i arr = fl l i arr
gEncode{|EITHER|} _ fr (RIGHT r) i arr = fr r i arr
gEncode{|CONS of {gcd_index}|} fx (CONS x) i arr = fx x (i+1) {arr & [i]=gcd_index}
gEncode{|OBJECT|} fx (OBJECT x) i arr = fx x i arr

derive gEncode Instr

/** Encode a value with `gEncode` into an exactly-sized unique Int array. */
encode :: !a -> *{#Int} | gEncodedSize{|*|}, gEncode{|*|} a
encode x
# (_,arr) = gEncode{|*|} x 0 (createArray (gEncodedSize{|*|} x) -1)
= arr

/**
 * Run the JIT-compiled program through the foreign `jit_exec` function,
 * passing it `code_start`.
 */
exec :: !CompileState -> Int
exec {jitst} = exec jitst.code_start
where
	exec :: !Int -> Int
	exec _ = code {
		ccall jit_exec "p:I"
	}

import Text.GenPrint
derive gPrint Instr

/**
 * Compile a demo program, print its blocks, and print the interpreted and
 * JIT-compiled results. The exit code is 0 iff the two results agree.
 */
Start w
# (io,w) = stdio w
# io = Foldl (\io b -> io <<< " " <<< printToString b <<< "\n") (io <<< "Program blocks:\n") comp_state.blocks
# io = io <<< "Interpreted result: " <<< interpreted_result <<< "\n"
# io = io <<< "JIT-compiled result: " <<< jit_compiled_result <<< "\n"
# (_,w) = fclose io w
= setReturnCode (if (interpreted_result==jit_compiled_result) 0 1) w
where
	interpreted_result = interpret comp_state
	jit_compiled_result = exec comp_state
	comp_state =: compile_all Nothing
		[ {fun_name="id", fun_args=["x"], fun_expr=Var "x"}
		, {fun_name="const", fun_args=["x","y"], fun_expr=Var "x"}
		, {fun_name="main", fun_args=[], fun_expr=App "+" [App "const" [Int 37, Int 10], App "const" [Int 5, Int 10]]}
		]