#include #include #include #include #include enum instr { PushRef, PushI, Put, Pop, Call, Jmp, JmpCond, Ret, Halt, Op }; enum op { OAdd, OMul, OSub, ODiv }; enum cond { CEq, CNe, CLt, CLe, CGt, CGe, CTrue }; static inline uint32_t instr_size(uint64_t *pgm) { switch (*pgm) { case PushRef: return 5+1; case PushI: return 7+1; case Put: return 1+5; case Pop: return 4; case Call: return 5; case Jmp: return 5; case JmpCond: switch (pgm[1]) { case CTrue: return 1+3+6; default: return 1+1+3+6; } case Ret: return 1; case Halt: return 1+1; case Op: switch (pgm[1]) { case OAdd: case OSub: return 1+4+3+4; case OMul: return 1+4+4+4; case ODiv: return 1+4+3+3+4; default: fprintf(stderr,"unknown operator %d\n",(int)pgm[1]); exit(1); } default: fprintf(stderr,"unknown instruction %d\n",(int)*pgm); exit(1); } } static inline void gen_instr(char *full_code, char **code_p, uint64_t **pgm_p, uint32_t *mapping) { uint64_t arg; char *code=*code_p; uint64_t *pgm=*pgm_p; switch (*pgm) { case PushRef: arg=pgm[1]; #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"PushRef %ld\n",arg); #endif code[0]='\x48'; /* mov rcx,[rsp+ARG*8] */ code[1]='\x8b'; code[2]='\x4c'; code[3]='\x24'; code[4]=(unsigned char)arg*8; code[5]='\x51'; /* push rcx */ pgm+=2; code+=6; break; case PushI: arg=pgm[1]; #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"PushI %lu\n",arg); #endif code[0]='\x48'; /* mov rcx,ARG */ code[1]='\xc7'; code[2]='\xc1'; *(uint32_t*)&code[3]=(uint32_t)arg; code[7]='\x51'; /* push rcx */ pgm+=2; code+=8; break; case Put: arg=pgm[1]; #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"Put %lu\n",arg); #endif code[0]='\x59'; /* pop rcx */ code[1]='\x48'; /* mov [rsp+ARG*8],rcx */ code[2]='\x89'; code[3]='\x4c'; code[4]='\x24'; code[5]=(unsigned char)arg*8; pgm+=2; code+=6; break; case Pop: arg=pgm[1]; #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"Pop %lu\n",arg); #endif code[0]='\x48'; /* add rsp,ARG*8 */ code[1]='\x83'; code[2]='\xc4'; code[3]=(unsigned char)arg*8; pgm+=2; code+=4; break; case Call: arg=pgm[1]; #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"Call %lu -> %d\n",arg,mapping[arg]-(uint32_t)(&code[5]-full_code)); #endif code[0]='\xe8'; /* callq ARG */ *(uint32_t*)&code[1]=mapping[arg]-(&code[5]-full_code); pgm+=2; code+=5; break; case Jmp: arg=pgm[1]; #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"Jmp %lu -> %d\n",arg,mapping[arg]-(uint32_t)(&code[5]-full_code)); #endif code[0]='\xe9'; /* jmpq ARG */ *(uint32_t*)&code[1]=mapping[arg]-(&code[5]-full_code); pgm+=2; code+=5; break; case JmpCond: { enum cond cond=pgm[1]; arg=pgm[2]; #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"JmpCond %d %lu -> %d\n",(int)cond,arg,mapping[arg]-(uint32_t)(&code[10]-full_code)); #endif code[0]='\x59'; /* pop rcx */ if (cond==CTrue) { code[1]='\x48'; /* test rcx,rcx */ code[2]='\x85'; code[3]='\xc9'; code+=4; } else { code[1]='\x58'; /* pop rax */ code[2]='\x48'; /* cmp rax,rcx */ code[3]='\x3b'; code[4]='\xc1'; code+=5; } code[0]='\x0f'; /* jcc */ switch (cond) { case CEq: code[1]='\x84'; break; case CTrue: case CNe: code[1]='\x85'; break; case CLt: code[1]='\x8c'; break; case CLe: code[1]='\x8e'; break; case CGt: code[1]='\x8f'; break; case CGe: code[1]='\x8d'; break; } *(uint32_t*)&code[2]=mapping[arg]-(&code[6]-full_code); pgm+=3; code+=6; break; } case Ret: #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"Ret\n"); #endif code[0]='\xc3'; /* retq */ pgm++; code+=1; break; case Halt: #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"Halt\n"); #endif code[0]='\x58'; /* pop rax */ code[1]='\xc3'; /* retq */ pgm++; code+=2; break; case Op: arg=pgm[1]; #ifdef DEBUG_JIT_INSTRUCTIONS fprintf(stderr,"Op %d\n",(int)arg); #endif /* pop rax */ code[0]='\x58'; /* mov rcx,[rsp] */ code[1]='\x48'; code[2]='\x8b'; code[3]='\x0c'; code[4]='\x24'; switch (arg) { case OAdd: case OSub: /* {add,sub} rax,rcx */ code[5]='\x48'; code[6]=arg==OAdd ? '\x01' : '\x29'; code[7]='\xc8'; code+=8; break; case OMul: /* imul rax,rcx */ code[5]='\x48'; code[6]='\x0f'; code[7]='\xaf'; code[8]='\xc1'; code+=9; break; case ODiv: /* xor rdx,rdx */ code[5]='\x48'; code[6]='\x31'; code[7]='\xd2'; /* idiv rcx */ code[8]='\x48'; code[9]='\xf7'; code[10]='\xf9'; code+=11; break; } /* mov [rsp],rax */ code[0]='\x48'; code[1]='\x89'; code[2]='\x04'; code[3]='\x24'; pgm+=2; code+=4; break; default: fprintf(stderr,"unknown instruction %d\n",(int)pgm[-1]); exit(1); } *code_p=code; *pgm_p=pgm; } void init_jit(uint64_t max_instrs, uint64_t max_code_len, char **code_block, uint32_t **mapping) { *code_block=mmap (NULL,max_code_len,PROT_READ | PROT_EXEC,MAP_ANONYMOUS | MAP_PRIVATE,0,0); *mapping=malloc (max_instrs*sizeof(uint32_t)); } char *jit_append(char *code_block, uint32_t code_len, char *code_ptr, uint32_t *mapping, uint32_t n_instr, uint64_t *pgm, int is_main) { uint32_t len=*pgm++; uint32_t i; uint32_t code_i=code_ptr-code_block; uint64_t *pgm_p=pgm; for (i=n_instr; i