aboutsummaryrefslogtreecommitdiff
path: root/snug-clean/src/Snug/Compile.icl
blob: 601acb089c7395bda9488e4dfcef97bb1bae676a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
implementation module Snug.Compile

import Control.Monad
import Control.Monad.Fail
import Control.Monad.State
import Control.Monad.Trans
import Data.Error
import Data.Func
import Data.Functor
import Data.List
import qualified Data.Map
from Data.Map import :: Map
import qualified Data.Set
import Data.Tuple
import StdEnv
from Text import concat3, concat4

import MIPS.MIPS32
import Snug.Compile.ABI
import Snug.Compile.Simulate
import Snug.Syntax

/**
 * The state that is threaded through all `CompileM` computations in this
 * module.
 */
:: CompileState =
	{ namespace   :: !Namespace // the namespace passed to `compile`; used to qualify labels and lookups
	, fresh_ident :: !Int       // counter used by `freshLabel`; incremented for every generated label
	, globals     :: !Globals   // known constructors and functions, keyed by namespaced identifier
	}

/**
 * The location of a local symbol. `compileDefinition` assigns `FrontPtrArg i`
 * to the `i`th (0-based) argument of a function; `compileCase` reads such an
 * argument at offset `4+4*i` from the pointer it loads via `FrontEvalPtr`.
 */
:: LocalLocation = FrontPtrArg !Int

// Two namespaced identifiers are equal iff both the identifier and the
// namespace are equal.
instance == (Namespaced id) | == id
where
	(==) a b
		| a.id <> b.id = False
		| otherwise    = a.ns == b.ns

// Lexicographic ordering: first on the identifier, then on the namespace.
instance < (Namespaced id) | < id
where
	(<) l r = l.id < r.id || (not (r.id < l.id) && l.ns < r.ns)

// Adds the `NoType` alternative to `Type`. In this module it is used by
// `liftCases` for generated functions and arguments for which no type is
// recorded.
:: Type | NoType

/**
 * Compiles all definitions of a module to a list of assembly lines.
 *
 * First `liftCases` is applied to every definition (which may generate extra
 * definitions), then every resulting definition is compiled with
 * `compileDefinition` and the generated lines are concatenated.
 *
 * @param The namespace of the module that is compiled
 * @param The definitions in the module
 * @result The generated lines, or an error message
 */
compile :: !Namespace ![Definition] -> MaybeError String [Line]
compile ns defs =
	flatten <$> evalStateT
		(
			mapM liftCases defs <$&> flatten >>= \defs ->
			mapM compileDefinition defs
		)
		init
where
	// The initial state. NB: the globals are gathered from the definitions as
	// given to `compile`, i.e. before `liftCases`; functions generated by
	// `liftCases` have to be registered with `addGlobalFunction`.
	init =
		{ namespace = ns
		, fresh_ident = 0
		, globals = combineGlobals [builtin, gatherGlobals ns defs]
		}

	// Globals that are always available: the `INT` constructor (without
	// arguments) in the empty namespace.
	builtin =
		{ constructors = 'Data.Map'.fromList
			[ ({ns="", id="INT"}, ConstructorDef "INT" [])
			]
		, functions = 'Data.Map'.newMap
		}

	// Merges several sets of globals with `'Data.Map'.unions`.
	// NOTE(review): which entry wins for duplicate keys depends on `unions` — confirm.
	combineGlobals :: ![Globals] -> Globals
	combineGlobals sets =
		{ constructors = 'Data.Map'.unions [g.constructors \\ g <- sets]
		, functions = 'Data.Map'.unions [g.functions \\ g <- sets]
		}

	// Collects the constructors of all `DataDef`s and the arity and type of
	// all `FunDef`s, qualified with the given namespace.
	gatherGlobals :: !Namespace ![Definition] -> Globals
	gatherGlobals ns defs =
		{ constructors = 'Data.Map'.fromList
			[ ({ns=ns, id=id}, cons)
			\\ DataDef _ _ conses <- defs
			, cons=:(ConstructorDef id _) <- conses
			]
		, functions = 'Data.Map'.fromList
			// The type is built by folding `TyApp` over the reversed argument
			// types, starting from the result type
			[ ({ns=ns, id=id}, {arity=length args, type=foldr TyApp ret (map snd (reverse args))})
			\\ FunDef id args ret _ <- defs
			]
		}

/**
 * Stores the information for a function under the given namespaced symbol in
 * the `functions` map of the globals in the compile state.
 */
addGlobalFunction :: !(Namespaced SymbolIdent) !FunctionInfo -> CompileM ()
addGlobalFunction sym fi = modify register
where
	register st = {st & globals.functions = 'Data.Map'.put sym fi st.globals.functions}

/**
 * This pass ensures that `Case` only appears as the toplevel expression of a
 * `FunDef`, and that the expression that is matched on is a `Symbol` (namely,
 * one of the arguments to the function). It does this by creating new
 * `FunDef`s for cases deeper in the rhs. For now we assume that the resulting
 * `Case` expressions always make the expression that is matched on strict
 * (which is strictly speaking not correct for expressions like `case x of _ ->
 * ...`).
 *
 * A `TestDef` is rewritten to a fresh function without arguments that holds
 * the tested expression (with its cases lifted), and a `TestDef` that refers
 * to that function. All other definitions are returned unchanged.
 *
 * Every generated function is registered with `addGlobalFunction`, because
 * the globals in the compile state are gathered before this pass runs.
 *
 * TODO: the remaining work for this changeset consists of:
 * - Generating the correct matching code in `compileCase`
 * - Adding a fallthrough case in the base case of `compileCase`
 *
 * @param The definition to transform
 * @result The transformed definition, followed by any generated definitions
 */
liftCases :: !Definition -> CompileM [Definition]
liftCases (FunDef name args ret expr) =
	liftCasesFromExpr True expr <$&> \(defs,expr) -> [FunDef name args ret expr:defs]
where
	// Returns the generated definitions and the rewritten expression.
	// NB: the `toplevel` argument is currently not inspected: every `Case` is
	// lifted, including one at the top level of the function.
	liftCasesFromExpr toplevel e = case e of
		BasicValue _ ->
			pure ([], e)
		Symbol _ ->
			pure ([], e)
		Constructor _ ->
			pure ([], e)
		Case e alts ->
			// First lift the cases from the alternatives, then the case itself
			mapM liftCasesFromAlt alts <$&> unzip <$&> appFst flatten >>= \(defs,alts) ->
			liftCase e alts <$&> appFst ((++) defs)
		ExpApp e1 e2 ->
			liftCasesFromExpr False e1 >>= \(ds1,e1) ->
			liftCasesFromExpr False e2 >>= \(ds2,e2) ->
			pure (ds1 ++ ds2, ExpApp e1 e2)

	liftCasesFromAlt (CaseAlternative pat expr) =
		liftCasesFromExpr False expr <$&> appSnd (CaseAlternative pat)

	// Creates a new function `_lcaseN` whose body is a `Case` on its last
	// argument (`_casearg`), and returns an application of that function to
	// the lifted arguments and the original scrutinee.
	liftCase e alts =
		gets (\s -> s.namespace) >>= \ns ->
		freshLabel "case" >>= \funName ->
		addGlobalFunction {ns=ns, id=funName} {arity=length argsToLift + 1, type=NoType} $>
		( [FunDef funName (argsToLift ++ [("_casearg", NoType)]) NoType (Case (Symbol "_casearg") alts)]
		, foldl ExpApp (Symbol funName) ([Symbol sym \\ (sym,_) <- argsToLift] ++ [e])
		)
	where
		usedSyms = usedSymbols (e, alts)
		// Only the arguments of the enclosing function that are used in the
		// case expression are passed on to the new function
		argsToLift = [arg \\ arg=:(sym,_) <- args | 'Data.Set'.member sym usedSyms]
liftCases (TestDef name ty expr expected) =
	gets (\s -> s.namespace) >>= \ns ->
	freshLabel "test" >>= \funName ->
	// The generated function did not exist when the globals were gathered, so
	// it must be registered; otherwise looking up `Symbol funName` fails.
	addGlobalFunction {ns=ns, id=funName} {arity=0, type=ty} >>|
	liftCases (FunDef funName [] ty expr) <$&> \defs ->
	[TestDef name ty (Symbol funName) expected : defs]
liftCases def =
	pure [def]

/**
 * Generates a label of the form `_l<prefix><n>`, where `n` is the current
 * value of the `fresh_ident` counter, and increments that counter.
 */
freshLabel :: !String -> CompileM Label
freshLabel prefix = state next
where
	next st =
		let n = st.fresh_ident in
		(concat3 "_l" prefix (toString n), {st & fresh_ident=n+1})

/**
 * Monadic variant of `lookupConstructor` that takes the namespace and the
 * globals from the compile state.
 */
lookupConstructorM :: !ConstructorIdent -> CompileM ConstructorDef
lookupConstructorM id =
	gets (\st -> st.namespace) >>= \ns ->
	gets (\st -> st.globals) >>= \globals ->
	liftT (lookupConstructor ns id globals)

/**
 * Finds the definition of a constructor in the given namespace.
 *
 * @result The definition, or an error `Unknown constructor <ns>.<id>`
 */
lookupConstructor :: !Namespace !ConstructorIdent !Globals -> MaybeError String ConstructorDef
lookupConstructor ns id globals =
	case 'Data.Map'.get {ns=ns, id=id} globals.constructors of
		?Just def -> Ok def
		?None -> Error (concat4 "Unknown constructor " ns "." id)

/**
 * Finds the information of a global function in the given namespace.
 *
 * @result The function information, or an error `Unknown symbol <ns>.<id>`
 */
lookupFunction :: !Namespace !SymbolIdent !Globals -> MaybeError String FunctionInfo
lookupFunction ns id globals =
	case 'Data.Map'.get {ns=ns, id=id} globals.functions of
		?Just info -> Ok info
		?None -> Error (concat4 "Unknown symbol " ns "." id)

/**
 * Finds a local symbol by its identifier.
 *
 * @result The symbol, or an error `Unknown local <id>`
 */
lookupLocal :: !SymbolIdent !Locals -> MaybeError String Symbol
lookupLocal id locals =
	case 'Data.Map'.get id locals of
		?Just sym -> Ok sym
		?None -> Error ("Unknown local " +++ id)

/**
 * Generates the lines for a single definition:
 *
 * - A `TypeDef` generates nothing.
 * - A `DataDef` generates a data section with the descriptors of its
 *   constructors (see `compileConstructor`).
 * - A `FunDef` generates, if the function has arguments, a data section with
 *   one closure descriptor for every number of supplied arguments
 *   `0..arity-1`, followed by a word holding the node entry label. After that
 *   comes a text section with a word describing strictness and arity, the
 *   node entry label, and the code for the body.
 *
 * NOTE(review): there is no alternative for `TestDef` here although
 * `liftCases` produces `TestDef`s — confirm how these are meant to be handled.
 */
compileDefinition :: !Definition -> CompileM [Line]
compileDefinition (TypeDef _ _) = pure
	[]
compileDefinition (DataDef _ _ constructors) =
	(++) [StartSection "data"] <$>
	flatten <$> mapM compileConstructor constructors
compileDefinition (FunDef id args ret expr) =
	gets (\s -> s.namespace) >>= \ns ->
	let n_label = functionLabel ns NodeEntry id in
	(++)
		// The closure descriptors are only generated for functions with arguments
		(if (isEmpty args) [] (
			[ StartSection "data"
			, Align 2
			] ++ flatten
			// One descriptor for a closure holding `i` arguments already
			[[ Label (closureLabel ns id i)
			// TODO: Ideally we would use the following here:
			//, RawByte i // pointer arity
			//, RawByte 0 // basic value arity
			//, RawByte (length args-i-1) // number of arguments that still have to be curried in minus 1
			//, RawByte 0 // reserved
			// But since SPIM does not allow .byte in the text section, we use:
			, RawWord
				(i bitor // pointer arity
				((length args-i-1) << 16)) // number of arguments that still have to be curried in minus 1
			] \\ i <- [0..length args-1]
			] ++
			[ RawWordLabel n_label
			])) <$>
	(++)
		[ StartSection "text"
		, Global n_label
		// TODO: Ideally we would use the following here:
		//, Align 1
		//, RawByte (sum [2^i \\ i <- [0..] & _ <- args]) // all strict for now, TODO change
		//, RawByte (length args) // arity
		// But since SPIM does not allow .byte in the text section, we use:
		, Align 2
		, RawWord
			(sum [2^i \\ i <- [0..] & _ <- args] bitor // all strict for now, TODO change
			(length args << 8)) // arity
		// instead... (end modification)
		, Label n_label
		] <$>
	case expr of
		// due to liftCases, all Case expressions are on the top level and have
		// a Symbol as expression which is the last argument to the function
		Case (Symbol local) alts ->
			compileCase locals local alts
		_ ->
			map Instr <$> compileExpr locals expr
where
	// Every argument is a local symbol, located by its 0-based position in
	// the argument list
	locals = 'Data.Map'.fromList
		[ (id, LocalSymbol (FrontPtrArg offset))
		\\ (id,_) <- args
		& offset <- [0..]
		]

/**
 * Generates the code for the alternatives of a toplevel `Case` expression.
 * The code for the alternatives is emitted in order, one after the other.
 *
 * @param The local symbols that are in scope
 * @param The symbol that is matched on
 * @param The alternatives that remain to be compiled
 * @result The generated lines; nothing is generated when no alternatives are
 *   left, so a partial case has no fallthrough code yet
 */
compileCase :: !Locals !SymbolIdent ![CaseAlternative] -> CompileM [Line]
compileCase _ _ [] =
	pure [] // TODO: add catch for partial cases
compileCase locals exprSymbol [CaseAlternative pat rhs:rest] =
	liftM2 (++) caseAlt (compileCase locals exprSymbol rest)
where
	// NB: we can assume that the expression has been evaluated; cases are strict
	caseAlt = case pat of
		// A wildcard always matches: only the code for the rhs is generated
		Wildcard ->
			map Instr <$>
			compileExpr locals rhs
		BasicValuePattern bv ->
			abort "compileCase: BasicValuePattern\n" // TODO
		// An identifier always matches; it becomes an alias for the symbol
		// that is matched on while compiling the rhs
		IdentPattern sym ->
			liftT (lookupLocal exprSymbol locals) >>= \exprSymbol ->
			map Instr <$> compileExpr ('Data.Map'.put sym exprSymbol locals) rhs
		ConstructorPattern cons args ->
			gets (\s -> s.namespace) >>= \ns ->
			freshLabel "match" >>= \match ->
			freshLabel "nomatch" >>= \nomatch ->
			liftT (lookupLocal exprSymbol locals) >>= \(LocalSymbol (FrontPtrArg i)) ->
			compileExpr locals rhs >>= \rhs -> // TODO: add args to locals
			pure $
				map Instr
				// Temporary 0 holds the address of the constructor of the
				// pattern; temporary 1 is loaded in three steps: the word at
				// `FrontEvalPtr`, then the word at offset `4+4*i` from that
				// (argument `i`), then the word at offset 0 from that, which
				// is compared with the constructor address.
				[ LoadAddress (TempImm 0) (Address 0 (constructorLabel ns cons))
				, LoadWord (TempImm 1) 0 FrontEvalPtr
				, LoadWord (TempImm 1) (4+4*i) (TempImm 1)
				, LoadWord (TempImm 1) 0 (TempImm 1)
				// TODO: if the rhs is small enough we can use `bne t0,t1,nomatch` instead of a jump
				, BranchOn2 BCEq (TempImm 0) (TempImm 1) match 4 // TODO what is the right offset to `match`?
				, Nop
				, Jump NoLink (Direct (Address 0 nomatch))
				, Nop
				] ++
				[ Label match
				: map Instr rhs
				] ++
				// The code for the next alternative follows this label
				[ Label nomatch
				]

/**
 * Generates the descriptor of a constructor: a global label followed by two
 * bytes, holding the number of arguments of the constructor and 0.
 */
compileConstructor :: !ConstructorDef -> CompileM [Line]
compileConstructor (ConstructorDef id args) =
	descriptor <$> gets (\st -> st.namespace)
where
	descriptor ns =
		[ Global label
		, Align 1
		, Label label
		, RawByte (length args) // pointer arity
		, RawByte 0 // basic value arity
		//, RawByte -1 // number of arguments still to be curried in (unused for constructors)
		]
	where
		label = constructorLabel ns id

/**
 * Generates the instructions for an expression by running the `simulator` for
 * that expression, followed by `indirectAndEval`, on an initial stack that
 * only contains `SVRegOffset FrontEvalPtr 0`.
 *
 * Fails with `Compiling an expression failed: <error>` when the simulation
 * returns an error.
 */
compileExpr :: !Locals !Expression -> CompileM [Instruction]
compileExpr locals expr =
	gets (\st -> (st.namespace, st.globals)) >>= \(ns,globals) ->
	case simulate [SVRegOffset FrontEvalPtr 0] (program ns globals) of
		Ok instrs -> pure instrs
		Error e -> fail ("Compiling an expression failed: " +++ e)
where
	program ns globals = simulator ns globals locals expr >>| indirectAndEval

/**
 * Builds the simulator program that constructs the value of an expression.
 *
 * Arguments of an application are always simulated in reverse order, so that
 * the first argument ends up on top of the simulated stack when the node is
 * built with `buildThunk` or `buildCons`.
 *
 * @param The namespace in which global symbols are looked up
 * @param The known global constructors and functions
 * @param The local symbols that are in scope
 * @param The expression to build
 */
simulator :: !Namespace !Globals !Locals !Expression -> Simulator ()
simulator _ _ _ (BasicValue bv) =
	pushBasicValue bv >>|
	buildCons (constructorLabel "" (label bv)) 1
where
	label (BVInt _) = "INT"
	label (BVChar _) = "CHAR"
simulator ns globals locals (Symbol id) =
	case 'Data.Map'.get id locals of
		?Just (LocalSymbol (FrontPtrArg i)) ->
			stackSize >>= \n ->
			pushArg (n-1) i
		?None ->
			liftT (lookupFunction ns id globals) >>= \info -> case info.arity of
				0 ->
					buildThunk (functionLabel ns NodeEntry id) 0
				_ ->
					fail "symbol with arity > 0" // TODO implement
simulator ns globals _ (Constructor id) =
	// A constructor that is not applied to anything is handled like an
	// application to zero arguments (cf. the `Constructor` alternative for
	// `ExpApp` below)
	liftT (lookupConstructor ns id globals) >>= \(ConstructorDef _ arg_types)
		| isEmpty arg_types ->
			buildCons (constructorLabel ns id) 0
		| otherwise -> fail ("arity mismatch in application of " +++ id) // TODO implement
simulator ns globals locals expr=:(ExpApp _ _) =
	case f of
		Symbol id -> // TODO include locals
			liftT (lookupFunction ns id globals) >>= \info
				// Saturated application: build a thunk for the function
				| info.arity == length args ->
					mapM_ (simulator ns globals locals) (reverse args) >>|
					buildThunk (functionLabel ns NodeEntry id) info.arity
				// Partial application: build a closure
				| info.arity > length args ->
					mapM_ (simulator ns globals locals) (reverse args) >>|
					buildCons (closureLabel ns id (length args)) (length args)
				// Over-application: build the saturated application and apply
				// the result to the remaining arguments one by one with `ap`
				| info.arity < length args ->
					let
						(closure_args,extra_args) = splitAt info.arity args
						closure = foldl ExpApp f closure_args
					in
					// The extra arguments must be pushed in reverse order, so
					// that the first extra argument is directly below the
					// closure and is consumed by the innermost `ap`
					mapM_ (simulator ns globals locals) (reverse extra_args) >>|
					simulator ns globals locals closure >>|
					mapM_ (\_ -> buildThunk (functionLabel "" NodeEntry "ap") 2) extra_args
		Constructor id ->
			liftT (lookupConstructor ns id globals) >>= \(ConstructorDef _ arg_types)
				| length arg_types == length args ->
					mapM_ (simulator ns globals locals) (reverse args) >>|
					buildCons (constructorLabel ns id) (length args)
				| otherwise -> fail ("arity mismatch in application of " +++ id) // TODO implement
		_ -> // TODO
			fail "unexpected lhs of function application"
where
	(f, args) = linearizeApp expr []

	// Splits a nested application into the applied expression and the list of
	// arguments in source order
	linearizeApp (ExpApp f x) xs = linearizeApp f [x:xs]
	linearizeApp e xs = (e, xs)
simulator ns globals locals (Case e alts) =
	// Cases are compiled by `compileCase` and should not reach the simulator
	liftT (fail "case in simulator")
/*
	simulator ns locals e >>|
	//eval >>| // TODO
	liftT freshLabel >>= \end ->
	mapM (simulateAlternative end) alts >>|
	label end
where
	simulateAlternative end (CaseAlternative pattern expr) =
		liftT freshLabel >>= \no_match ->
		//simulatePattern no_match locals pattern >>= \new_locals ->
		//simulator ns new_locals expr >>|
		jump end >>|
		label no_match
simulator _ _ _ _ = // TODO
	pushBasicValue (BVInt 0) >>|
	buildCons (constructorLabel "" "INT") 1
*/
//	= BasicValue !BasicValue
//	| Symbol !SymbolIdent
//	| Constructor !ConstructorIdent
//	| Case !Expression ![CaseAlternative]
//	| ExpApp !Expression !Expression

//	= Wildcard
//	| BasicValuePattern !BasicValue
//	| IdentPattern !SymbolIdent
//	| ConstructorPattern !ConstructorIdent ![SymbolIdent]