Peano
Loading...
Searching...
No Matches
printers.py
Go to the documentation of this file.
1# -----------------------------------------------------------------------------
2# BSD 3-Clause License
3#
4# Copyright (c) 2024, Harrison Fullwood and Maurice Jamieson
5# All rights reserved.
6#
7# Redistribution and use in source and binary forms, with or without
8# modification, are permitted provided that the following conditions are met:
9#
10# * Redistributions of source code must retain the above copyright notice, this
11# list of conditions and the following disclaimer.
12#
13# * Redistributions in binary form must reproduce the above copyright notice,
14# this list of conditions and the following disclaimer in the documentation
15# and/or other materials provided with the distribution.
16#
17# * Neither the name of the copyright holder nor the names of its
18# contributors may be used to endorse or promote products derived from
19# this software without specific prior written permission.
20#
21# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
25# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
27# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
31# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32# POSSIBILITY OF SUCH DAMAGE.
33# -----------------------------------------------------------------------------
34
35from __future__ import annotations
36from abc import ABC, abstractmethod
37from typing_extensions import override
38import types
39import numpy as np
40from sympy import tensor, core, Range
41from sympy.core import numbers
42from sympy.codegen import ast
43from exahype.sympy import SymPyToMLIR, TypedFunction
44
class CodePrinter(ABC):
    """Abstract base for printers that turn a kernel description into code.

    Subclasses generate their output (normally in ``__init__``) and keep it
    in ``self.code``; ``loop`` produces one level of a loop nest and ``here``
    displays the generated result.
    """

    def __init__(self, kernel, name: str):
        """Store the kernel description.

        Args:
            kernel: kernel object the concrete printers read their data from.
            name: name of the generated function; kept as ``self.name``.
        """
        self.kernel = kernel
        self.name = name
        # Generated output; concrete printers overwrite it after super().
        self.code = ""

    @abstractmethod
    def loop(self, expr, direction, below, struct_inclusion):
        """Produce the loop nest for the statement ``expr``."""

    @abstractmethod
    def here(self):
        """Display the generated code."""

    def file(self, name: str, header=None):
        """Write ``self.code`` to the file ``name``, overwriting it.

        ``header`` is not used here; it is accepted so that subclasses can
        override ``file`` with the same signature. Assumes ``self.code`` is
        a ``str``.
        """
        with open(name, 'w', encoding='utf-8') as handle:
            handle.write(self.code)
61
62
64
@override
def __init__(self, kernel, name: str = "time_step"):
    """Generate the C++ source of the kernel function ``name`` into ``self.code``.

    The emitted text is, in order: the signature built from
    ``kernel.input_types``/``kernel.inputs``, the kernel literals, a
    ``new double[...]`` for every temporary array, a ``double`` declaration
    per directional constant, one assignment or loop nest per
    ``kernel.LHS``/``kernel.RHS`` pair, and the matching ``delete[]`` calls.
    The result is finally post-processed by ``parse``.
    """
    super().__init__(kernel, name)

    self.INDENT = 1

    signature = ', '.join(
        f'{kernel.input_types[i]} {kernel.inputs[i]}'
        for i in range(len(kernel.inputs))
    )
    self.code = f'void {name}({signature})' + ' {\n'

    if len(kernel.literals) > 0:
        for literal in kernel.literals:
            self.indent()
            self.code += literal + '\n'
        self.code += '\n'

    # Temporary arrays: indexed objects that are neither kernel inputs nor
    # members of a parent object (kernel.parents). Computed once and reused
    # for both allocation and deletion.
    temporaries = [
        item for item in kernel.all_items.values()
        if str(item) not in kernel.inputs
        and isinstance(item, tensor.indexed.IndexedBase)
        and str(item) not in kernel.parents
    ]

    for item in temporaries:
        self.alloc(item)

    # Directional constants are scalars assigned outside any loop.
    for const in kernel.directional_consts:
        self.indent()
        self.code += f'double {kernel.all_items[const]};\n'
    self.code += '\n'

    for lhs, rhs, direction, inclusion in zip(
            kernel.LHS, kernel.RHS, kernel.directions, kernel.struct_inclusion):
        if str(lhs) in kernel.directional_consts:
            self.indent()
            self.code += f'{lhs} = {rhs};\n'
        else:
            self.loop([lhs, rhs], direction, kernel.dim + 1, inclusion)

    self.code += '\n'
    for item in temporaries:
        self.indent()
        self.code += f'delete[] {item};\n'
    self.code += '}\n'
    self.parse()
114
def indent(self, val=0, force=False):
    """Adjust the indentation depth and optionally emit it.

    Adds ``val`` to ``self.INDENT``. When ``val`` is 0 (a plain "indent
    here" call) or ``force`` is true, the resulting depth is appended to
    ``self.code`` as that many tab characters.
    """
    self.INDENT += val
    if val == 0 or force:
        self.code += "\t" * self.INDENT
119
120 @override
121 def loop(self,expr,direction,below,struct_inclusion):
122 level = self.kernel.dim + 1 - below
123 idx = self.kernel.indexes[level]
124
125 extra_condition = ('i +' in str(expr[0]) or 'i -' in str(expr[0]) or 'j +' in str(expr[0]) or 'j -' in str(expr[0]) or 'k +' in str(expr[0]) or 'k -' in str(expr[0]) or 'i +' in str(expr[1]) or 'i -' in str(expr[1]) or 'j +' in str(expr[1]) or 'j -' in str(expr[1]) or 'k +' in str(expr[1]) or 'k -' in str(expr[1]))
126 #set loop range using direction and struct_inclusion
127 if level == 0:
128 r = [0,self.kernel.n_patches]
129 # print(str(expr))
130 elif below == 0:
131 k = [val for key,val in self.kernel.item_struct.items() if key in str(expr)] + [struct_inclusion]
132 match min(k):
133 case 0:
134 r = [0,1]
135 case 1:
136 r = [0, self.kernel.n_real]
137 case 2:
138 r = [0, self.kernel.n_real+self.kernel.n_aux]
139 elif expr[0] == self.kernel.LHS[-1]:
140 r = [self.kernel.halo_size, self.kernel.patch_size + self.kernel.halo_size]
141 elif direction == -1:
142 # r = [self.kernel.halo_size, self.kernel.patch_size + self.kernel.halo_size]
143 r = [0, self.kernel.patch_size + 2*self.kernel.halo_size]
144 elif direction == level and direction >=0 and extra_condition:
145 r = r = [self.kernel.halo_size, self.kernel.patch_size + self.kernel.halo_size]
146 elif direction == level and direction >= 0:
147 r = [0, self.kernel.patch_size + 2*self.kernel.halo_size]
148 else:
149 r = [self.kernel.halo_size, self.kernel.patch_size + self.kernel.halo_size]
150
151
152 #add loop code
153 if str(idx) == 'var' and r[1] == 1:
154 self.indent(-1)
155 else:
156 self.indent()
157 self.codecode += f"for (int {idx} = {r[0]}; {idx} < {r[1]}; {idx}++)" + " {\n"
158 if below > 0: #next loop if have remaining loops
159 self.indent(1)
160 self.looploop(expr,direction,below-1,struct_inclusion)
161 self.indent(-1)
162 else: #print loop interior
163 self.indent(1,True)
164 if expr[1] == '':
165 self.codecode += f'{self.Cppify(expr[0])};\n'
166 else:
167 self.codecode += f'{self.Cppify(expr[0])} = {self.Cppify(expr[1])};\n'
168 self.indent(-1)
169 if str(idx) == 'var' and r[1] == 1: #removing unnecessary 'vars'
170 self.indent(1)
171 i = len(self.codecode) - 2
172 while self.codecode[i] != '\n':
173 if self.codecode[i:i+6] == ' + var':
174 self.codecode = self.codecode[:i] + self.codecode[i+6:]
175 i -= 1
176 None
177 else:
178 self.indent()
179 self.codecode += "}\n"
180
def alloc(self, item):
    """Append a heap allocation line for the temporary array ``item``.

    Emits ``double *<item> = new double[<product>];`` where the product is
    ``n_patches``, then ``patch_size + 2*halo_size`` once per dimension,
    then the per-cell entry count: nothing extra when
    ``item_struct[str(item)] == 0``, ``n_real`` when ``item`` is not in
    ``kernel.items``, otherwise ``n_real + n_aux``. The factors are written
    as a literal ``a*b*c`` product.
    """
    kernel = self.kernel
    self.indent()
    factors = [kernel.n_patches]
    factors += [kernel.patch_size + 2 * kernel.halo_size] * kernel.dim
    if kernel.item_struct[str(item)] != 0:
        if str(item) not in kernel.items:
            factors.append(kernel.n_real)
        else:
            factors.append(kernel.n_real + kernel.n_aux)
    size = '*'.join(str(factor) for factor in factors)
    self.code += f'double *{item} = new double[{size}];\n'
193
def heritage(self, item):
    """Prefix every alphabetic name in ``item`` that has a parent object.

    ``item`` is split into maximal runs of alphabetic characters; any run
    found in ``self.kernel.parents`` is replaced by ``<parent><run>`` when
    the parent string ends in ``:`` (e.g. a ``Class::`` scope) and by
    ``<parent>.<run>`` otherwise. All other characters are copied unchanged.
    """
    parents = self.kernel.parents
    pieces = []
    word = ''
    # The trailing sentinel '1' flushes the last word; it is cut off below.
    for ch in item + '1':
        if ch.isalpha():
            word += ch
            continue
        if word in parents:
            prefix = parents[word]
            joiner = '' if prefix[-1] == ":" else '.'
            pieces.append(f'{prefix}{joiner}{word}')
        else:
            pieces.append(word)
        pieces.append(ch)
        word = ''
    return ''.join(pieces)[:-1]
214
def Cppify(self,item):
    """Translate a SymPy expression into a C++ expression string.

    ``str(item)`` is split into tokens at every ``[`` and ``]``. Tokens
    outside brackets are passed through ``heritage`` (parent-object
    prefixes); once a token mentions one of ``self.kernel.functions`` (and
    until a token containing ``)``), names from ``kernel.items`` and
    ``kernel.directional_items`` get a leading ``&``. Tokens inside
    brackets are treated as a comma-separated multi-index and flattened to
    ``s0*idx0 + s1*idx1 + ... + idxN``, where the strides depend on the
    ``item_struct`` of the array named before the bracket and on
    ``kernel.dim``; the last index gets no stride factor.

    Returns:
        The C++ expression as a ``str``.
    """
    expr = [str(item)]
    active = True
    # Repeatedly partition tokens on '[' / ']' until every bracket is a
    # token of its own (partition keeps the separator; '' tokens may appear).
    while active:
        active = False
        n = []
        for a in expr:
            if '[' in a and len(a) > 1:
                active = True
                for b in a.partition('['):
                    n.append(b)
            elif ']' in a and len(a) > 1:
                active = True
                for b in a.partition(']'):
                    n.append(b)
            else:
                n.append(a)
        expr = n
    out = ''
    unpack = False  # True while the next token is bracket contents
    in_func = False  # True while inside a call to a kernel function

    for a in expr:
        if a == '[':
            out += a
            unpack = True
        elif unpack == False:
            if ')' in a:
                in_func = False

            # Remember the text before a bracket: it names the indexed array.
            item = a
            k = [str(val) for val in self.kernel.functions if val in a]
            if len(k) != 0:
                in_func = True


            if in_func:
                # Arguments of kernel functions are passed by address.
                for b in self.kernel.items + self.kernel.directional_items:
                    if b in a:
                        a = a.replace(b,f'&{str(b)}')
                        # break

            out += self.heritage(a)
        else:
            unpack = False
            # NOTE(review): substring match against the preceding token; a key
            # contained in another array's name may be chosen, and k[0]
            # raises IndexError when nothing matches.
            k = [key for key,val in self.kernel.item_struct.items() if key in item]
            # NOTE(review): no default case -- an item_struct value other than
            # 0, 1 or 2 leaves `leap` unbound.
            match self.kernel.item_struct[k[0]]:
                case 0:
                    leap = 1
                case 1:
                    leap = self.kernel.n_real
                case 2:
                    leap = self.kernel.n_real + self.kernel.n_aux
            # NOTE(review): the array kernel.items[1] is strided with
            # patch_size only (no halo) -- presumably stored without ghost
            # cells; confirm.
            if k[0] == self.kernel.items[1]:
                size = self.kernel.patch_size
            else:
                size = self.kernel.patch_size + 2*self.kernel.halo_size
            strides = [leap*size**2,leap*size,leap]
            if self.kernel.dim == 3:
                strides = [leap*size**3] + strides
            i = 0
            for char in a.split(','):
                char = char.strip()
                if i != 0:
                    out += ' + '
                # Indices beyond len(strides) (the innermost) are unscaled.
                if i < len(strides):
                    out += f'{strides[i]}*'
                # Known names are emitted bare; anything else is parenthesised.
                if char in self.kernel.all_items:
                    out += f'{char}'
                else:
                    out += f'({char})'

                i += 1

    return out
290
def parse(self):
    """Rewrite member accesses on the first kernel input to index by patch.

    Only text after the first ``{`` (i.e. the function body) is scanned.
    With ``mother = str(self.kernel.inputs[0])``, each ``<mother>.<field>``
    is rewritten in ``self.code``:

    * if ``[`` follows the field name, the text from after ``[`` through the
      first ``+`` and the character following it (presumably the
      ``<stride>*patch + `` term emitted by ``Cppify``) is replaced by
      ``patch][``;
    * otherwise ``[patch]`` is inserted after the field name.

    NOTE(review): the name match is not anchored at an identifier start
    (``xQ.`` matches when mother is ``Q``) and a partial mismatch does not
    restart the match on the current character.
    """
    mother = str(self.kernel.inputs[0])
    # Iterate over the unmodified text; `offset` maps its indices into the
    # progressively rewritten self.code.
    original = self.code
    j = 0
    begin = False
    offset = 0

    for i, char in enumerate(original):
        if char == '{' and not begin:
            begin = True
        if not begin:
            continue
        pos = i + offset
        if j < len(mother):
            j = j + 1 if char == mother[j] else 0
            continue
        # j == len(mother): the preceding characters spelled `mother`.
        if char != '.':
            j = 0
            continue
        code = self.code
        l = pos + 1
        while l < len(code) and code[l].isalpha():
            l += 1
        if l < len(code) and code[l] == '[':
            l += 1
            k = code.index('+', l) + 2
            self.code = code[:l] + "patch][" + code[k:]
        else:
            k = l
            self.code = code[:l] + "[patch]" + code[k:]
        # Both inserts are 7 characters long; (k - l) characters were removed.
        offset += l - k + 7
        j = 0
331
@override
def file(self, name: str = 'test.cpp', header=None):
    """Write the generated C++ to ``name`` preceded by the ExaHyPE 2 includes.

    A fixed block of ``#include`` lines is prepended (and, when ``header``
    is given, ``#include "<header>"`` plus a blank line before that). The
    prefix is added only for the write: ``self.code`` is restored
    afterwards, so repeated calls do not accumulate duplicate includes.
    """
    headers = (
        "exahype2/UserInterface.h",
        "observers/CreateGrid.h",
        "observers/CreateGridAndConvergeLoadBalancing.h",
        "observers/CreateGridButPostponeRefinement.h",
        "observers/InitGrid.h",
        "observers/PlotSolution.h",
        "observers/TimeStep.h",
        "peano4/peano.h",
        "repositories/DataRepository.h",
        "repositories/SolverRepository.h",
        "repositories/StepRepository.h",
        "tarch/accelerator/accelerator.h",
        "tarch/accelerator/Device.h",
        "tarch/logging/CommandLineLogger.h",
        "tarch/logging/Log.h",
        "tarch/logging/LogFilter.h",
        "tarch/logging/Statistics.h",
        "tarch/multicore/Core.h",
        "tarch/multicore/multicore.h",
        "tarch/multicore/otter.h",
        "tarch/NonCriticalAssertions.h",
        "tarch/timing/Measurement.h",
        "tarch/timing/Watch.h",
        "tasks/FVRusanovSolverEnclaveTask.h",
        "toolbox/loadbalancing/loadbalancing.h",
    )
    prefix = ''.join(f'#include "{h}"\n' for h in headers) + '\n'
    if header is not None:
        prefix = f'#include "{header}"\n\n' + prefix

    body = self.code
    self.code = prefix + body
    try:
        # The base class performs the actual write of self.code.
        super().file(name, header)
    finally:
        self.code = body
341
@override
def here(self):
    """Print the generated C++ source to stdout."""
    print(self.code)
345
346
348
@override
def __init__(self, kernel, name: str = "time_step"):
    """Build a SymPy ``FunctionDefinition`` named ``name`` into ``self.code``.

    Parameters: the first kernel input becomes a real ``IndexedBase``, all
    others real ``Symbol``s. The body contains, in order:

    * a ``Declaration`` for each temporary ``IndexedBase`` in
      ``kernel.all_items`` (shape ``(n_patches, cells, ..., n_vars)``);
    * a ``Declaration`` for each directional constant;
    * per ``kernel.LHS``/``kernel.RHS`` pair, an ``Assignment`` (directional
      constants) or the loop nest returned by ``loop``;
    * an assignment of ``none`` to each temporary, from which the MLIR
      lowering generates a ``memref.dealloc`` op.
    """
    super().__init__(kernel, name)

    # NOTE: for now, we assume that the first input is a double array and
    #       that all other variables are doubles unless stated. We don't
    #       need the dimensions of the double array
    params = [tensor.indexed.IndexedBase(kernel.inputs[0], real=True)]
    params += [ast.Symbol(inp, real=True) for inp in kernel.inputs[1:]]

    # NOTE(review): unlike the C++ printer, entries in kernel.parents are not
    # excluded from the temporaries here -- confirm this is intended.
    temporaries = [
        item for item in kernel.all_items.values()
        if str(item) not in kernel.inputs
        and isinstance(item, tensor.indexed.IndexedBase)
    ]

    body = []
    cells = kernel.patch_size + 2 * kernel.halo_size
    for item in temporaries:
        if str(item) not in kernel.items:
            n_vars = kernel.n_real
        else:
            n_vars = kernel.n_real + kernel.n_aux
        # Writes the private _shape attribute so the Declaration carries the
        # array shape.
        item._shape = (kernel.n_patches, *([cells] * kernel.dim), n_vars)
        body.append(ast.Declaration(item))

    for const in kernel.directional_consts:
        symbol = kernel.all_items[const]
        if not isinstance(symbol, ast.Symbol):
            symbol = ast.Symbol(symbol, real=True)
        body.append(ast.Declaration(symbol))

    for lhs, rhs, direction, inclusion in zip(
            kernel.LHS, kernel.RHS, kernel.directions, kernel.struct_inclusion):
        if str(lhs) in kernel.directional_consts:
            body.append(ast.Assignment(lhs, rhs))
        else:
            # TODO: might be best to offload the loop, per Harrison's code
            body.append(self.loop([lhs, rhs], direction, kernel.dim + 1, inclusion))

    # Assigning `none` to the array symbol (args[0]) is turned into a
    # 'memref.dealloc' op.
    for item in temporaries:
        body.append(ast.Assignment(item.args[0], ast.none))

    prototype = ast.FunctionPrototype(None, name, params)
    self.code = ast.FunctionDefinition.from_FunctionPrototype(prototype, body)
411
412 @override
413 def loop(self,expr,direction,below,struct_inclusion):
414 level = self.kernel.dim + 1 - below
415 idx = self.kernel.indexes[level]
416
417 #set loop range using direction and struct_inclusion
418 if level == 0:
419 r = [0,self.kernel.n_patches]
420 elif below == 0:
421 k = [val for key,val in self.kernel.item_struct.items() if key in str(expr)] + [struct_inclusion]
422 match min(k):
423 case 0:
424 r = [0,1]
425 case 1:
426 r = [0, self.kernel.n_real]
427 case 2:
428 r = [0, self.kernel.n_real+self.kernel.n_aux]
429 elif direction == -1:
430 r = [0, self.kernel.patch_size + 2*self.kernel.halo_size]
431 elif direction != level and direction >= 0:
432 r = [0, self.kernel.patch_size + 2*self.kernel.halo_size]
433 else:
434 r = [self.kernel.halo_size, self.kernel.patch_size + self.kernel.halo_size]
435
436 #add loop code
437 if below > 0: #next loop if have remaining loops
438 body = self.looploop(expr,direction,below-1,struct_inclusion)
439 else: #print loop interior
440 if expr[1] == '':
441 body = expr[0]
442 else:
443 body = ast.Assignment(expr[0], expr[1])
444
445 return ast.For(idx, Range(r[0], r[1]), body=[ body ])
446
447
@override
def here(self):
    """Lower ``self.code`` with ``SymPyToMLIR`` and print the resulting module."""
    module = SymPyToMLIR().apply(self.code)
    print(module)
455
456
457
458
459
460
461
462
463
464
465
466
file(cpp_printer self, str name, header=None)
Definition printers.py:58
__init__(CodePrinter self, kernel, str name)
Definition printers.py:47
loop(CodePrinter self, expr, direction, below, struct_inclusion)
Definition printers.py:51
loop(self, expr, direction, below, struct_inclusion)
Definition printers.py:413
__init__(self, kernel, str name="time_step")
Definition printers.py:350
file(cpp_printer self, str name='test.cpp', header=None)
Definition printers.py:333
Cppify(self, item)
Definition printers.py:215
heritage(self, item)
Definition printers.py:194
__init__(self, kernel, str name="time_step")
Definition printers.py:66
loop(self, expr, direction, below, struct_inclusion)
Definition printers.py:121
alloc(self, item)
Definition printers.py:181
indent(self, val=0, force=False)
Definition printers.py:115