Peano
Loading...
Searching...
No Matches
SyntaxTree.py
Go to the documentation of this file.
1from abc import ABC, abstractmethod
2import os
3import copy
4from enum import Enum
5
class Node(ABC):
    """Abstract base class of every syntax-tree element.

    A node renders itself through one printer per output back end (``cpp``,
    ``omp``, ``sycl``, ``mlir``) plus a human-readable tree dump. Every printer
    receives an ``indent_level`` measured in tabs of ``spaces_per_tab`` spaces.
    """

    # Not referenced in the visible part of this file; kept as shared class state.
    output_format = None
    # Number of spaces one indentation level expands to in every printer.
    spaces_per_tab = 4

    @abstractmethod
    def print_cpp(self, indent_level = 0):
        """Return this node rendered for the ``cpp`` back end."""

    @abstractmethod
    def print_mlir(self, indent_level = 0):
        """Return this node rendered as MLIR text."""

    @abstractmethod
    def print_omp(self, indent_level = 0):
        """Return this node rendered for the ``omp`` back end."""

    @abstractmethod
    def print_sycl(self, indent_level = 0):
        """Return this node rendered for the ``sycl`` back end."""

    @abstractmethod
    def print_tree(self, indent_level = 0):
        """Return an indented, human-readable dump of this node."""
29
class Statement(Node):
    """Base class for nodes that can be appended to a function body.

    ``FunctionDefinition.add_statement`` is annotated with this type. The
    class adds no behaviour; all printers stay abstract.
    """
32
33
35 def multidimensional_index(self, index_list, start_index=0):
36 return self
37
38 @abstractmethod
39 def get_type(self):
40 pass
41
42
class Type(Node):
    """Base class for nodes that describe types.

    Used, for example, as the ``return_type`` of a function definition and as
    the type of an argument. It adds no behaviour; all printers stay abstract.
    """
45
46
48 def __init__(self, id, type: Type):
49 self.id = id
50 self._type = type
51
52 def print_cpp(self, indent_level = 0):
53 return self._type.print_cpp() + " " + self.id
54
55 def print_omp(self, indent_level = 0):
56 return self._type.print_omp() + " " + self.id
57
58 def print_sycl(self, indent_level = 0):
59 return self._type.print_sycl() + " " + self.id
60
61 def print_mlir(self, indent_level = 0):
62 return "%" + self.id + ": " + self._type.print_mlir()
63
64 def print_tree(self, indent_level = 0):
65 return ' ' * indent_level * Node.spaces_per_tab + f"Argument: {self.id}"
66
67
69 _stateless = False
70
71 def __init__(self, id, return_type: Type, template = None, namespaces = [], stateless = False):
72 self.id = id
73 self._return_type = return_type
74 self._template = template
75 self._arguments = []
76 self._body = []
77 self._namespaces = namespaces
78 FunctionDefinition._stateless = stateless
79
80 def add_statement(self, statement: Statement):
81 self._body.append(statement)
82
83 def add_argument(self, argument: Argument):
84 self._arguments.append(argument)
85
86 def print_declaration(self, indent_level = 0):
87 namespace_header = f"""namespace {"::".join(self._namespaces)} {{"""
88 namespace_footer = "}"
89 template_string = ""
90 if self._template is not None:
91 template_strings = [argument.print_cpp() for argument in self._template]
92 template_string = "template <" + ",".join(template_strings) + ">\n"
93
94 current_indent = indent_level + 1 # Indentation increases inside namespace
95 argument_prints = [argument.print_cpp() for argument in self._arguments]
96
97 return f"""#include <fstream>\n{namespace_header}
98{template_string}{' ' * current_indent * Node.spaces_per_tab}{self._return_type.print_cpp()} {self.id}({', '.join(argument_prints)});
99{namespace_footer}
100"""
101
102 def print_cpp(self, indent_level=0):
103 namespace_header = f"""namespace {"::".join(self._namespaces)} {{"""
104 namespace_footer = "}"
105 template_string = ""
106 if self._template is not None:
107 template_strings = [argument.print_cpp() for argument in self._template]
108 template_string = "template <" + ",".join(template_strings) + ">\n"
109
110 current_indent = indent_level + 1 # Indentation increases inside namespace
111 statement_prints = [statement.print_cpp(current_indent + 1) for statement in self._body]
112 argument_prints = [argument.print_cpp() for argument in self._arguments]
113
114 return f"""{namespace_header}
115{template_string}{' ' * current_indent * Node.spaces_per_tab}{self._return_type.print_cpp()} {self.id}({', '.join(argument_prints)}) {{
116std::ofstream log;
117{os.linesep.join(statement_prints)}
118{' ' * current_indent * Node.spaces_per_tab}}}
119{namespace_footer}
120"""
121
122 def print_omp(self, indent_level=0):
123 namespace_header = f"""namespace {"::".join(self._namespaces)} {{"""
124 namespace_footer = "}"
125 template_string = ""
126 if self._template is not None:
127 template_strings = [argument.print_omp() for argument in self._template]
128 template_string = "template <" + ",".join(template_strings) + ">\n"
129
130 current_indent = indent_level + 1 # Indentation increases inside namespace
131 statement_prints = [statement.print_omp(current_indent + 1) for statement in self._body]
132 argument_prints = [argument.print_omp() for argument in self._arguments]
133
134 return f"""{namespace_header}
135{template_string}{' ' * current_indent * Node.spaces_per_tab}{self._return_type.print_omp()} {self.id}({', '.join(argument_prints)}) {{
136{os.linesep.join(statement_prints)}
137{' ' * current_indent * Node.spaces_per_tab}}}
138{namespace_footer}
139"""
140
141 def print_sycl(self, indent_level=0):
142 namespace_header = f"""namespace {"::".join(self._namespaces)} {{"""
143 namespace_footer = "}"
144 template_string = ""
145 if self._template is not None:
146 template_strings = [argument.print_sycl() for argument in self._template]
147 template_string = "template <" + ",".join(template_strings) + ">\n"
148
149 current_indent = indent_level + 1 # Indentation increases inside namespace
150 statement_prints = [statement.print_sycl(current_indent + 1) for statement in self._body]
151 argument_prints = [argument.print_sycl() for argument in self._arguments]
152 #statement_prints.append("::sycl::queue Queue;")
153
154 return f"""{namespace_header}
155{template_string}{' ' * current_indent * Node.spaces_per_tab}{self._return_type.print_sycl()} {self.id}({', '.join(argument_prints)}) {{
156{os.linesep.join(statement_prints)}
157{' ' * current_indent * Node.spaces_per_tab}}}
158{namespace_footer}
159"""
160
161 def print_mlir(self, indent_level = 0):
162 statement_prints = [statement.print_mlir(indent_level + 1) for statement in self._body]
163 argument_prints = [argument.print_mlir() for argument in self._arguments]
164 return f"""
165func.func @{self.id}({', '.join(argument_prints)}) -> ({self._return_type.print_mlir()}) {{
166{os.linesep.join(statement_prints)}
167}}
168"""
169
170 def print_tree(self, indent_level = 0):
171 statement_prints = [statement.print_tree(indent_level + 1) for statement in self._body]
172 argument_prints = [argument.print_tree(indent_level + 1) for argument in self._arguments]
173 return f"""
174FunctionDefinition: {self.id}:
175{os.linesep.join(argument_prints)}
176{os.linesep.join(statement_prints)}"""
177
178 def print_definition_with_timer(self, indent_level = 0):
179 namespace_header = f"""namespace {"::".join(self._namespaces)} {{"""
180 namespace_footer = "}"
181 template_string = ""
182 function_call_string = self.id
183 if self._template is not None:
184 template_strings = [argument.print_cpp() for argument in self._template]
185 template_string = "template <" + ",".join(template_strings) + ">\n"
186 function_call_string += "<" + ",".join([template.split(' ')[1] for template in template_strings]) + ">"
187
188 arguments = copy.deepcopy(self._arguments)
189 arguments.append(Argument("measurement", TCustom("tarch::timing::Measurement&")))
190
191 current_indent = indent_level + 1 # Indentation increases inside namespace
192 argument_prints = [argument.print_cpp() for argument in arguments]
193
194 return f"""{namespace_header}
195{template_string}{' ' * current_indent * Node.spaces_per_tab}{self._return_type.print_cpp()} {self.id}({', '.join(argument_prints)}) {{
196tarch::timing::Watch watch("{"::".join(self._namespaces)}", "{self.id}", false, true);
197{function_call_string}({",".join([argument.print_cpp().split(' ')[-1] for argument in self._arguments])});
198watch.stop();
199measurement.setValue(watch.getCalendarTime());
200{' ' * current_indent * Node.spaces_per_tab}}}
201{namespace_footer}
202"""
203
204 def print_declaration_with_timer(self, indent_level = 0):
205 namespace_header = f"""namespace {"::".join(self._namespaces)} {{"""
206 namespace_footer = "}"
207 template_string = ""
208 if self._template is not None:
209 template_strings = [argument.print_cpp() for argument in self._template]
210 template_string = "template <" + ",".join(template_strings) + ">\n"
211
212 arguments = copy.deepcopy(self._arguments)
213 arguments.append(Argument("measurement", TCustom("tarch::timing::Measurement&")))
214
215 current_indent = indent_level + 1 # Indentation increases inside namespace
216 argument_prints = [argument.print_cpp() for argument in arguments]
217
218 return f"""{namespace_header}
219{template_string}{' ' * current_indent * Node.spaces_per_tab}{self._return_type.print_cpp()} {self.id}({', '.join(argument_prints)});
220{namespace_footer}
221"""
222
224 def __init__(self, statementToLog, outputStream):
225 self._statement = statementToLog
226 self._output = outputStream
227
228 def print_cpp(self, indent_level=0):
229 return f"""{self._output.print_cpp()} << {self._statement.print_cpp()};"""
230
231 def print_omp(self, indent_level=0):
232 return f"""{self._output.print_omp()} << {self._statement.print_omp()};"""
233
234 def print_sycl(self, indent_level=0):
235 return f"""{self._output.print_sycl()} << {self._statement.print_sycl()};"""
236
237 def print_mlir(self, indent_level=0):
238 pass
239
240 def print_tree(self, indent_level=0):
241 return "LogToFile"
242
243
244# Expressions
246 variables = dict()
247
248 def __init__(self, id, type = None):
249 self.id = id
250 self.type = type
251 if self.type is None:
252 self.type = Name.variables[self.id].get_type()
253
254 def multidimensional_index(self, index_list, start_index=0):
255 if type(Name.variables[self.id].get_type()) is TDataBlock:
256 return Name.variables[self.id].multidimensional_index(index_list, start_index)
257 else:
258 return self
259
260 def __add__(self, rhs):
261 return BinaryOperation("+", self, rhs)
262
263 def __sub__(self, rhs):
264 return BinaryOperation("-", self, rhs)
265
266 def __mul__(self, rhs):
267 return BinaryOperation("*", self, rhs)
268
269 def __div__(self, rhs):
270 return BinaryOperation("/", self, rhs)
271
272 def __neg__(self):
273 return UnaryOperation("-", self)
274
275 def __gt__(self, rhs):
276 if type(rhs) is Name:
277 return (Name.variables[self.id] > Name.variables[rhs.id])
278 return (Name.variables[self.id] > rhs)
279
280 def __eq__(self, rhs):
281 if type(rhs) is Name:
282 if rhs.id in Name.variables and self.id in Name.variables:
283 return (self._value == Name.variables[rhs.id])
284 else:
285 return False
286 else:
287 if self.id in Name.variables:
288 return (Name.variables[self.id] == rhs)
289 else:
290 return False
291
292 def get_type(self):
293 return self.type
294
295 def print_cpp(self, indent_level=0):
296 return ' ' * indent_level * Node.spaces_per_tab + self.id
297
298 def print_omp(self, indent_level=0):
299 return ' ' * indent_level * Node.spaces_per_tab + self.id
300
301 def print_sycl(self, indent_level=0):
302 return ' ' * indent_level * Node.spaces_per_tab + self.id
303
304 def print_mlir(self, indent_level=0):
305 return ' ' * indent_level * Node.spaces_per_tab + "%" + self.id
306
307 def print_tree(self, indent_level=0):
308 return ' ' * indent_level * Node.spaces_per_tab + "Name:" + self.id
309
311 def __init__(self, value, string = None):
312 self._value = int(value)
313 self._string = string
314
315 def __add__(self, rhs):
316 if self._value == 0:
317 return rhs
318 if type(rhs) is Integer:
319 return Integer(self._value + rhs._value)
320 return BinaryOperation("+", self, rhs)
321
322 def __sub__(self, rhs):
323 if self._value == 0:
324 if type(rhs) is Integer:
325 return Integer(-rhs._value)
326 if type(rhs) is UnaryOperation and rhs._operation == "-":
327 return rhs._operand
328 return -rhs
329 if type(rhs) is Integer:
330 return Integer(self._value - rhs._value)
331 return BinaryOperation("-", self, rhs)
332
333 def __mul__(self, rhs):
334 return BinaryOperation("*", self, rhs)
335
336 def __div__(self, rhs):
337 return BinaryOperation("/", self, rhs)
338
339 def __neg__(self):
340 if self._value == 0:
341 return self
342 else:
343 return UnaryOperation("-", self)
344
345 def __lt__(self, rhs):
346 if type(rhs) is Name:
347 if rhs.id in Name.variables:
348 return (self._value < Name.variables[rhs.id])
349 else:
350 return False
351 return (self._value < rhs)
352
353 def __gt__(self, rhs):
354 if type(rhs) is Name:
355 if rhs.id in Name.variables:
356 return (self._value > Name.variables[rhs.id])
357 else:
358 return False
359 return (self._value > rhs)
360
361 def __eq__(self, rhs):
362 if type(rhs) is Integer:
363 return (self._value == rhs._value)
364
365 def __int__(self):
366 return self._value
367
368 def get_type(self):
369 return TInteger()
370
371 def print_cpp(self, indent_level=0):
372 if self._string is None:
373 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
374 else:
375 return ' ' * indent_level * Node.spaces_per_tab + str(self._string)
376
377 def print_omp(self, indent_level=0):
378 if self._string is None:
379 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
380 else:
381 return ' ' * indent_level * Node.spaces_per_tab + str(self._string)
382
383 def print_sycl(self, indent_level=0):
384 if self._string is None:
385 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
386 else:
387 return ' ' * indent_level * Node.spaces_per_tab + str(self._string)
388
389 def print_mlir(self, indent_level=0):
390 if self._string is None:
391 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
392 else:
393 return ' ' * indent_level * Node.spaces_per_tab + str(self._string)
394
395 def print_tree(self, indent_level=0):
396 if self._string is None:
397 return ' ' * indent_level * Node.spaces_per_tab + "Integer: " + str(self._value)
398 else:
399 return ' ' * indent_level * Node.spaces_per_tab + "Integer: " + str(self._string)
400
401
403 def __init__(self, value):
404 self._value = value
405
406 def __str__(self):
407 return str(self._value)
408
409 def get_type(self):
410 return TBoolean()
411
412 def print_cpp(self, indent_level=0):
413 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
414
415 def print_omp(self, indent_level=0):
416 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
417
418 def print_sycl(self, indent_level=0):
419 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
420
421 def print_mlir(self, indent_level=0):
422 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
423
424 def print_tree(self, indent_level=0):
425 return ' ' * indent_level * Node.spaces_per_tab + "Boolean: " + str(self._value)
426
428 def __init__(self, value):
429 self._value = value
430
431 def __str__(self):
432 return self._value
433
434 def get_type(self):
435 return TString()
436
437 def print_cpp(self, indent_level=0):
438 return ' ' * indent_level * Node.spaces_per_tab + self._value
439
440 def print_omp(self, indent_level=0):
441 return ' ' * indent_level * Node.spaces_per_tab + self._value
442
443 def print_sycl(self, indent_level=0):
444 return ' ' * indent_level * Node.spaces_per_tab + self._value
445
446 def print_mlir(self, indent_level=0):
447 return ' ' * indent_level * Node.spaces_per_tab + self._value
448
449 def print_tree(self, indent_level=0):
450 return ' ' * indent_level * Node.spaces_per_tab + "String: " + self._value
451
452
454 def __init__(self, value, string = None, reference = False):
455 self._value = value
456 self._string = string
457 self._reference = reference
458
459
460 def __float__(self):
461 return self._value
462
463 def get_type(self):
464 return TFloat(self._reference)
465
466 def print_cpp(self, indent_level=0):
467 if self._string is None:
468 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
469 else:
470 return ' ' * indent_level * Node.spaces_per_tab + str(self._string)
471
472 def print_omp(self, indent_level=0):
473 if self._string is None:
474 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
475 else:
476 return ' ' * indent_level * Node.spaces_per_tab + str(self._string)
477
478 def print_sycl(self, indent_level=0):
479 if self._string is None:
480 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
481 else:
482 return ' ' * indent_level * Node.spaces_per_tab + str(self._string)
483
484 def print_mlir(self, indent_level=0):
485 if self._string is None:
486 return ' ' * indent_level * Node.spaces_per_tab + str(self._value)
487 else:
488 return ' ' * indent_level * Node.spaces_per_tab + str(self._string)
489
490 def print_tree(self, indent_level=0):
491 return ' ' * indent_level * Node.spaces_per_tab + "Float: " + str(self._value)
492
493
495 def __init__(self, iteration_range, internal, requires_memory_allocation, id = None, underlying_type = String("double")):
    # Create a data block over `iteration_range`, a list of [start, end] pairs
    # (one per dimension, as unpacked by the loop below).
    # `internal` is the expression printed when `id` is None; `underlying_type`
    # is passed to TDataBlock by get_type(). The default String("double") is a
    # single instance created once, when this def is executed, and shared by
    # every call that omits the argument.
496 self._iteration_range = iteration_range
497 self._internal = internal
498 self.requires_memory_allocation = requires_memory_allocation
499 self._underlying_type = underlying_type
500 self.id = id
501 self._offset = []
    # NOTE(review): original line 502 is missing from this listing, yet the
    # loop below appends to self._memory_range - presumably that line was
    # `self._memory_range = []`; confirm against the real file.
503 for [x, y] in self._iteration_range:
    # Per dimension: a zero offset and a memory extent [0, end - start].
504 self._offset.append(Integer(0))
505 self._memory_range.append([Integer(0), y - x])
506
507 def offset(self, offset):
508 output = DataBlock(copy.deepcopy(self._iteration_range), copy.deepcopy(self._internal), True)
509 output.id = self.id
510
511 for i in range(len(offset)):
512 if type(offset[i]) is list:
513 if offset[i][0] is None:
514 offset[i][0] = self._iteration_range[i][0]
515 if offset[i][1] is None:
516 offset[i][1] = self._iteration_range[i][1]
517
518 output._offset[i] = offset[i][0] - self._iteration_range[i][0]
519 output._iteration_range[i][1] = offset[i][1]
520 output._iteration_range[i][0] = offset[i][0]
521 else:
522 raise Exception("Invalid index")
523
524 for i in range(len(output._offset), len(output._iteration_range)):
525 output._offset.append(Integer(0))
526
527 return output
528
529 def linearise_index(self, indices, dimensions, offset_start_index = 0):
    # Fold all but the last entry of `indices` into one linear index
    # expression. multidimensional_index calls this with
    # dimensions = self._iteration_range[0:-1]; the last index becomes a
    # separate outer subscript there.
530 factors = [Integer(1)]
531 index_list = indices[0:-1]
    # factors[k] = product of the memory extents of dimensions 0..k-1, so
    # dimension 0 has stride 1 (the first dimension varies fastest).
532 for i in range(0, len(dimensions) - 1):
533 factors.append(factors[-1] * (self._memory_range[i][1]))
    # When len(indices) == len(dimensions) + 1, -len(index_list) selects
    # dimension 0. If that dimension's extent is the literal Integer 1, only
    # its offset contributes; otherwise index + offset, where a
    # TDataBlock-typed offset is itself indexed first.
    # NOTE(review): for len(indices) == 1, -len(index_list) is 0 and indexes
    # from the front; only reached via multidimensional_index for >= 2 dims.
534 if type(dimensions[-len(index_list)][1] - dimensions[-len(index_list)][0]) is Integer and (dimensions[-len(index_list)][1] - dimensions[-len(index_list)][0])._value == 1:
    # NOTE(review): in this branch a TDataBlock-typed offset is used as-is,
    # without multidimensional_index; confirm that is intended.
535 index = self._offset[offset_start_index]
536 else:
537 if type(self._offset[offset_start_index + 0].get_type()) is TDataBlock:
538 index = index_list[0] + self._offset[offset_start_index + 0].multidimensional_index(index_list, offset_start_index)
539 else:
540 index = index_list[0] + self._offset[offset_start_index + 0]
541
542 index = factors[-len(index_list)] * index
    # Accumulate the remaining dimensions, each scaled by its stride factor.
543 for i in range(1, len(index_list)):
544 if type(self._offset[offset_start_index + i].get_type()) is TDataBlock:
    # NOTE(review): here the nested offset is indexed with the full `indices`
    # list, while the first dimension above passes `index_list`; confirm
    # which one is intended.
545 temp = index_list[i] + self._offset[offset_start_index + i].multidimensional_index(indices, offset_start_index)
546 else:
547 temp = index_list[i] + self._offset[offset_start_index + i]
548 index = index + factors[-len(index_list) + i] * temp
549 return index
550
551 def multidimensional_index(self, index_list, start_index = 0):
552 if len(self._iteration_range) == 1:
553 return Subscript(self, index_list[-1])
554 else:
555 return Subscript(Subscript(self, index_list[-1]), self.linearise_index(index_list, self._iteration_range[0:-1], start_index))
556
557 def get_type(self):
558 return TDataBlock(len(self._iteration_range), self._underlying_type)
559
560 def print_cpp(self, indent_level=0):
561 if self.id is None:
562 return ' ' * indent_level * Node.spaces_per_tab + self._internal.print_cpp()
563 else:
564 return ' ' * indent_level * Node.spaces_per_tab + self.id
565
566 def print_omp(self, indent_level=0):
567 if self.id is None:
568 return ' ' * indent_level * Node.spaces_per_tab + self._internal.print_omp()
569 else:
570 return ' ' * indent_level * Node.spaces_per_tab + self.id
571
572 def print_sycl(self, indent_level=0):
573 if self.id is None:
574 return ' ' * indent_level * Node.spaces_per_tab + self._internal.print_sycl()
575 else:
576 return ' ' * indent_level * Node.spaces_per_tab + self.id
577
578
579 def print_mlir(self, indent_level=0):
580 if self.id is None:
581 return ' ' * indent_level * Node.spaces_per_tab + self._internal.print_mlir()
582 else:
583 return ' ' * indent_level * Node.spaces_per_tab + "%" + self.id
584
585 def print_tree(self, indent_level=0):
586 if self.id is None:
587 return ' ' * indent_level * Node.spaces_per_tab + f"""DataBlock:
588{self._internal.print_tree(indent_level + 1)}"""
589 else:
590 return ' ' * indent_level * Node.spaces_per_tab + "DataBlock: " + self.id
591
593 def __init__(self, dataBlock, filename):
594 self._filename = filename
595
596 self._dataBlock = dataBlock
597
598 loops = []
599 index = []
600
601 loops.append(For([Integer(0), self._dataBlock._iteration_range[-1][1], - self._dataBlock._iteration_range[-1][0]]))
602 for i in range(len(self._dataBlock._memory_range) - 2, -1, -1):
603 loops.append(For([Integer(0), self._dataBlock._iteration_range[i][1] - self._dataBlock._iteration_range[i][0]]))
604
605 for i in range(len(self._dataBlock._memory_range) - 1, -1, -1):
606 index.append(loops[i].get_iteration_variable())
607
608 loops[-1].add_statement(LogToFile(self._dataBlock.multidimensional_index(index), String("log")))
609 loops[-1].add_statement(LogToFile(String("\" \""), String("log")))
610 for i in range(0, len(loops) - 1):
611 loops[i].add_statement(loops[i + 1])
612
613 loops[-2].add_statement(LogToFile(String("\", \""), String("log")))
614 for i in range(0, len(self._dataBlock._memory_range) - 2):
615 loops[i].add_statement(LogToFile(String("\"\\n\""), String("log")))
616
617 for loop in loops[::-1]:
618 loop.close_scope()
619 self._loop = loops[0]
620
621 def print_cpp(self, indent_level=0):
622 return f"""log.open("{self._filename.print_cpp()}");
623{self._loop.print_cpp()}
624log.close();
625"""
626
627 def print_omp(self, indent_level=0):
628 return ""
629
630 def print_sycl(self, indent_level=0):
631 return ""
632
633 def print_mlir(self, indent_level=0):
634 return ""
635
636 def print_tree(self, indent_level=0):
637 return ""
638
640 def __init__(self, value: Expression, index: Expression):
641 self._value = value
642 self._index = index
643
644 def __neg__(self):
645 return UnaryOperation("-", self)
646
647 def get_type(self):
648 return self._value.get_type()
649
650 def print_cpp(self, indent_level=0):
651 return ' ' * indent_level * Node.spaces_per_tab + f"{self._value.print_cpp()}[{self._index.print_cpp()}]"
652
653 def print_omp(self, indent_level=0):
654 return ' ' * indent_level * Node.spaces_per_tab + f"{self._value.print_omp()}[{self._index.print_omp()}]"
655
656 def print_sycl(self, indent_level=0):
657 return ' ' * indent_level * Node.spaces_per_tab + f"{self._value.print_sycl()}[{self._index.print_sycl()}]"
658
659 def print_mlir(self, indent_level=0):
660 return ' ' * indent_level * Node.spaces_per_tab + f"{self._value.print_mlir()}[{self._index.print_mlir()}]"
661
662 def print_tree(self, indent_level=0):
663 return f"""Subscript:
664{self._value.print_tree(indent_level + 1)}
665{self._index.print_tree(indent_level + 1)}"""
666
667
669 def __init__(self, expression: Expression):
670 self._expression = expression
671
672 def get_type(self):
673 return self._expression.get_type()
674
675 def print_cpp(self, indent_level=0):
676 return ' ' * indent_level * Node.spaces_per_tab + "&" + self._expression.print_cpp()
677
678 def print_omp(self, indent_level=0):
679 return ' ' * indent_level * Node.spaces_per_tab + "&" + self._expression.print_omp()
680
681 def print_sycl(self, indent_level=0):
682 return ' ' * indent_level * Node.spaces_per_tab + "&" + self._expression.print_sycl()
683
684 def print_mlir(self, indent_level=0):
685 return ' ' * indent_level * Node.spaces_per_tab + "llvm.mlir.addressof" + self._expression.print_mlir()
686
687 def print_tree(self, indent_level=0):
688 return ' ' * indent_level * Node.spaces_per_tab + f"""Reference:
689{self._expression.print_tree(indent_level + 1)}"""
690
691
693 def __init__(self, value: Expression, index: Expression):
694 self._value = value
695 self._index = index
696
697 def get_type(self):
698 return self._value.get_type()
699
700 def print_cpp(self, indent_level=0):
701 return ' ' * indent_level * Node.spaces_per_tab + f"{self._value.print_cpp()}({self._index.print_cpp()})"
702
703 def print_omp(self, indent_level=0):
704 return ' ' * indent_level * Node.spaces_per_tab + f"{self._value.print_omp()}({self._index.print_omp()})"
705
706 def print_sycl(self, indent_level=0):
707 return ' ' * indent_level * Node.spaces_per_tab + f"{self._value.print_sycl()}({self._index.print_sycl()})"
708
709 def print_mlir(self, indent_level=0):
710 return ' ' * indent_level * Node.spaces_per_tab + f"{self._value.print_mlir()}({self._index.print_mlir()})"
711
712 def print_tree(self, indent_level=0):
713 return f"""
714VectorIndex:
715{self.id}
716{self._value.print_tree(indent_level + 1)}
717{self._index.print_tree(indent_level + 1)}"""
718
719
721 def __init__(self, operation, lhs, rhs):
722 self._operation = operation
723 self._lhs = lhs
724 self._rhs = rhs
725
726 def get_type(self):
727 return self._lhs.get_type()
728
729 def print_cpp(self, indent_level = 0):
730 return ' ' * indent_level * Node.spaces_per_tab + f"({self._lhs.print_cpp()} {self._operation} {self._rhs.print_cpp()})"
731
732 def print_omp(self, indent_level = 0):
733 return ' ' * indent_level * Node.spaces_per_tab + f"({self._lhs.print_omp()} {self._operation} {self._rhs.print_omp()})"
734
735 def print_sycl(self, indent_level = 0):
736 return ' ' * indent_level * Node.spaces_per_tab + f"({self._lhs.print_sycl()} {self._operation} {self._rhs.print_sycl()})"
737
738 def print_mlir(self, indent_level = 0):
739 if type(self._lhs.get_type()) is TFloat or type(self._rhs.get_type()) is TFloat:
740 if self._operation == "*":
741 mlir_operation = "arith.mulf"
742 elif self._operation == "/":
743 mlir_operation = "arith.divf"
744 elif self._operation == "+":
745 mlir_operation = "arith.addf"
746 elif self._operation == "-":
747 mlir_operation = "arith.subf"
748 else:
749 if self._operation == "*":
750 mlir_operation = "arith.muli"
751 elif self._operation == "/":
752 mlir_operation = "arith.divi"
753 elif self._operation == "+":
754 mlir_operation = "arith.addi"
755 elif self._operation == "-":
756 mlir_operation = "arith.subi"
757
758 return ' ' * indent_level * Node.spaces_per_tab + f"""{mlir_operation} {self._lhs.print_mlir()}, {self._rhs.print_mlir()} : {self._lhs.get_type().print_mlir()}"""
759
760 def print_tree(self, indent_level=0):
761 return f"""
762Comparison {self._operation}:
763{self._lhs.print_tree(indent_level + 1)}
764{self._rhs.print_tree(indent_level + 1)}"""
765
767 def __init__(self, operation, lhs, rhs):
768 self._operation = operation
769 self._lhs = lhs
770 self._rhs = rhs
771
772 if type(self._lhs) is BinaryOperation:
773 if self._lhs._operation == "+":
774 if self._lhs._rhs == Integer(0):
775 self._lhs = self._lhs._lhs
776 elif self._lhs._lhs == Integer(0):
777 self._lhs = self._lhs._rhs
778 elif self._lhs._operation == "-":
779 if self._lhs._rhs == Integer(0):
780 self._lhs = self._lhs._lhs
781 elif self._lhs._operation == "*":
782 if self._lhs._rhs == Integer(1):
783 self._lhs = self._lhs._lhs
784 elif self._lhs._lhs == Integer(1):
785 self._lhs = self._lhs._rhs
786
787 if type(self._rhs) is BinaryOperation:
788 if self._rhs._operation == "+":
789 if self._rhs._rhs == Integer(0):
790 self._rhs = self._rhs._lhs
791 elif self._rhs._lhs == Integer(0):
792 self._rhs = self._rhs._rhs
793 elif self._rhs._operation == "-":
794 if self._rhs._rhs == Integer(0):
795 self._rhs = self._rhs._lhs
796 elif self._rhs._operation == "*":
797 if self._rhs._rhs == Integer(1):
798 self._rhs = self._rhs._lhs
799 elif self._rhs._lhs == Integer(1):
800 self._rhs = self._rhs._rhs
801
802 if self._operation == "-" and type(self._rhs) is UnaryOperation and self._rhs._operation == "-":
803 self._operation = "+"
804 self._rhs = self._rhs._operand
805 if self._operation == "+" and type(self._rhs) is UnaryOperation and self._rhs._operation == "-":
806 self._operation = "-"
807 self._rhs = self._rhs._operand
808
809 def index(self, index):
810 return BinaryOperation(self._operation, self._lhs.index(index), self._rhs.index(index))
811
812 def multidimensional_index(self, index, start_index = 0):
813 return BinaryOperation(self._operation, self._lhs.multidimensional_index(index, start_index), self._rhs.multidimensional_index(index, start_index))
814
815
816 def __add__(self, rhs):
817 return BinaryOperation("+", self, rhs)
818
819 def __sub__(self, rhs):
820 return BinaryOperation("-", self, rhs)
821
822 def __mul__(self, rhs):
823 return BinaryOperation("*", self, rhs)
824
825 def __neg__(self):
826 return UnaryOperation("-", self)
827
828 def get_type(self):
829 return self._lhs.get_type()
830
831 def print_cpp(self, indent_level = 0):
832 return ' ' * indent_level * Node.spaces_per_tab + f"({self._lhs.print_cpp()} {self._operation} {self._rhs.print_cpp()})"
833
834 def print_omp(self, indent_level = 0):
835 return ' ' * indent_level * Node.spaces_per_tab + f"({self._lhs.print_omp()} {self._operation} {self._rhs.print_omp()})"
836
837 def print_sycl(self, indent_level = 0):
838 return ' ' * indent_level * Node.spaces_per_tab + f"({self._lhs.print_sycl()} {self._operation} {self._rhs.print_sycl()})"
839
840 def print_mlir(self, indent_level = 0):
841 if type(self._lhs.get_type()) is TFloat or type(self._rhs.get_type()) is TFloat:
842 if self._operation == "*":
843 mlir_operation = "arith.mulf"
844 elif self._operation == "/":
845 mlir_operation = "arith.divf"
846 elif self._operation == "+":
847 mlir_operation = "arith.addf"
848 elif self._operation == "-":
849 mlir_operation = "arith.subf"
850 else:
851 if self._operation == "*":
852 mlir_operation = "arith.muli"
853 elif self._operation == "/":
854 mlir_operation = "arith.divi"
855 elif self._operation == "+":
856 mlir_operation = "arith.addi"
857 elif self._operation == "-":
858 mlir_operation = "arith.subi"
859
860 return ' ' * indent_level * Node.spaces_per_tab + f"""{mlir_operation} {self._lhs.print_mlir()}, {self._rhs.print_mlir()} : {self._lhs.get_type().print_mlir()}"""
861
862 def print_tree(self, indent_level=0):
863 return f"""
864BinaryOperation {self._operation}:
865{self._lhs.print_tree(indent_level + 1)}
866{self._rhs.print_tree(indent_level + 1)}"""
867
868
870 def __init__(self, operation, operand):
871 self._operation = operation
872 self._operand = operand
873
874 def __add__(self, rhs):
875 return BinaryOperation("+", self, rhs)
876
877 def __sub__(self, rhs):
878 return BinaryOperation("-", self, rhs)
879
880 def __mul__(self, rhs):
881 return BinaryOperation("*", self, rhs)
882
883 def __neg__(self):
884 if self._operation == "-":
885 return self._operand
886 else:
887 return UnaryOperation("-", self._operand)
888
889 def get_type(self):
890 return self._operand.get_type()
891
892 def print_cpp(self, indent_level = 0):
893 return f"({self._operation} {self._operand.print_cpp()})"
894
895 def print_omp(self, indent_level = 0):
896 return f"({self._operation} {self._operand.print_omp()})"
897
898 def print_sycl(self, indent_level = 0):
899 return f"({self._operation} {self._operand.print_sycl()})"
900
901 def print_mlir(self, indent_level = 0):
902 if type(self._operand.get_type()) is TFloat:
903 if self._operation == "+":
904 return self._operand.print_mlir(indent_level)
905 elif self._operation == "-":
906 return ' ' * indent_level * Node.spaces_per_tab + f"""arith.negf {self._operand.print_mlir()} : {self._operand.get_type().print_mlir()}"""
907 else:
908 if self._operation == "+":
909 return self._operand.print_mlir(indent_level)
910 elif self._operation == "-":
911 return ' ' * indent_level * Node.spaces_per_tab + f"""arith.negf {self._operand.print_mlir()} : {self._operand.get_type().print_mlir()}"""
912
913 def print_tree(self, indent_level=0):
914 return f"""
915UnaryOperation {self._operation}:
916{self._operand.print_tree(indent_level + 1)}"""
917
def __init__(self, dataBlock, index):
    """Data-block node that carries an extra fixed index ``self._index``.

    Presumably this selects one vector component of each element; confirm
    against the callers. A ``Name`` argument is resolved through
    ``Name.variables``. The iteration and memory ranges are shared with the
    underlying block, not copied.
    """
    resolved = Name.variables[dataBlock.id] if type(dataBlock) is Name else dataBlock
    self._dataBlock = resolved
    self._index = index
    self._iteration_range = resolved._iteration_range
    self._memory_range = resolved._memory_range

def index(self, index):
    """Index the underlying block and keep the same fixed component index."""
    inner = self._dataBlock.index(index)
    return VectorIndex(inner, self._index)

def multidimensional_index(self, index, start_index = 0):
    """Multi-dimensionally index the underlying block, keeping the fixed component index."""
    inner = self._dataBlock.multidimensional_index(index, start_index)
    return VectorIndex(inner, self._index)

def get_type(self):
    """Report a data block of ``double`` over this node's iteration range."""
    return TDataBlock(self._iteration_range, String("double"))

# The printers are not implemented for this node; each returns None.
def print_cpp(self, indent_level = 0):
    return None

def print_omp(self, indent_level = 0):
    return None

def print_sycl(self, indent_level = 0):
    return None

def print_mlir(self, indent_level = 0):
    return None

def print_tree(self, indent_level=0):
    return None
952
953
class DataBlockComparison(Expression):
    """Element-wise comparison between two data blocks, or a block and a scalar.

    The node is not printed directly (every ``print_*`` returns None). Callers
    obtain per-element ``Comparison`` nodes through ``index`` /
    ``multidimensional_index``. The element type is reported as ``int``.
    """

    def __init__(self, operation, lhs, rhs):
        """
        Args:
            operation: comparison operator forwarded to ``Comparison``.
            lhs, rhs: operands. A ``Name`` of type ``TDataBlock`` is replaced
                by the block registered for it in ``Name.variables``.
        """
        self._lhs = DataBlockComparison._resolve_block(lhs)
        self._rhs = DataBlockComparison._resolve_block(rhs)
        self._operation = operation
        self._offset = None

        if type(self._lhs.get_type()) is TDataBlock and type(self._rhs.get_type()) is TDataBlock:
            # Both operands are blocks: start from copies of the lhs ranges,
            # then reconcile the first dimension with the rhs.
            self._iteration_range = copy.deepcopy(self._lhs._iteration_range)
            self._memory_range = copy.deepcopy(self._lhs._memory_range)
            self._merge_first_dimension()
        elif type(self._lhs.get_type()) is TDataBlock:
            self._iteration_range = copy.deepcopy(self._lhs._iteration_range)
            self._memory_range = copy.deepcopy(self._lhs._memory_range)
        else:
            self._iteration_range = copy.deepcopy(self._rhs._iteration_range)
            self._memory_range = copy.deepcopy(self._rhs._memory_range)

    @staticmethod
    def _resolve_block(operand):
        """Return the block behind a TDataBlock-typed ``Name``, otherwise *operand* itself."""
        if type(operand) is Name and type(operand.get_type()) is TDataBlock:
            return Name.variables[operand.id]
        return operand

    @staticmethod
    def _is_symbolic(value):
        """True for unevaluated ``BinaryOperation`` bounds, which the merge logic does not compare."""
        return type(value) is BinaryOperation

    def _merge_first_dimension(self):
        """Widen the first iteration/memory dimension towards the rhs operand.

        The first dimension adopts the rhs bound when both bounds are concrete
        and the rhs bound is larger, or when the own bound is ``Integer(1)``.
        Iteration bounds are ``[start, end]`` pairs compared on ``end``.
        Memory ranges are handled the same way, with the whole rhs pair adopted.
        """
        symbolic = DataBlockComparison._is_symbolic

        own_end = self._iteration_range[0][1]
        rhs_end = self._rhs._iteration_range[0][1]
        if not symbolic(rhs_end) and not symbolic(own_end) and rhs_end > own_end:
            self._iteration_range[0][1] = rhs_end
        elif not symbolic(own_end) and own_end == Integer(1):
            self._iteration_range[0][1] = rhs_end

        # Check both the [start, end] pair and its end bound before comparing
        # ends: comparing against the pair itself (rather than its end) could
        # never detect a length-1 dimension, and a symbolic end would reach ``>``.
        own_mem = self._memory_range[0]
        rhs_mem = self._rhs._memory_range[0]
        if (not symbolic(rhs_mem) and not symbolic(own_mem)
                and not symbolic(rhs_mem[1]) and not symbolic(own_mem[1])
                and rhs_mem[1] > own_mem[1]):
            self._memory_range[0] = rhs_mem
        elif not symbolic(own_mem) and not symbolic(own_mem[1]) and own_mem[1] == Integer(1):
            self._memory_range[0] = rhs_mem

    def index(self, index):
        """Per-element comparison at the flat *index* of both operands."""
        return Comparison(self._operation, self._lhs.index(index), self._rhs.index(index))

    def multidimensional_index(self, index, start_index = 0):
        """Per-element comparison at a multi-dimensional *index* of both operands."""
        return Comparison(self._operation, self._lhs.multidimensional_index(index, start_index), self._rhs.multidimensional_index(index, start_index))

    def get_type(self):
        """A data block of ``int`` over the merged iteration range."""
        return TDataBlock(self._iteration_range, String("int"))

    # Not printable directly; every printer returns None.
    def print_cpp(self, indent_level = 0):
        return None

    def print_omp(self, indent_level = 0):
        return None

    def print_sycl(self, indent_level = 0):
        return None

    def print_mlir(self, indent_level = 0):
        return None

    def print_tree(self, indent_level=0):
        return None
1012
def __init__(self, operation, lhs, rhs, useFunctionSyntax=False):
    """Element-wise binary operation over data blocks and/or scalar expressions.

    Args:
        operation: operator symbol or, when ``useFunctionSyntax`` is true, the
            name of a two-argument function applied per element.
        lhs, rhs: operands. A ``Name`` whose type is ``TDataBlock`` is replaced
            by the object registered for it in ``Name.variables``.
        useFunctionSyntax: make ``index``/``multidimensional_index`` build a
            ``FunctionCall`` instead of a ``BinaryOperation``.
    """
    # Resolve block-typed names to the block objects registered for them.
    if type(lhs) is Name and type(lhs.get_type()) is TDataBlock:
        self._lhs = Name.variables[lhs.id]
    else:
        self._lhs = lhs

    if type(rhs) is Name and type(rhs.get_type()) is TDataBlock:
        self._rhs = Name.variables[rhs.id]
    else:
        self._rhs = rhs

    self._operation = operation
    self._useFunctionSyntax = useFunctionSyntax

    if type(self._lhs.get_type()) is TDataBlock and type(self._rhs.get_type()) is TDataBlock:
        # Both operands are blocks: start from deep copies of the lhs ranges,
        # then reconcile the first dimension with the rhs.
        self._iteration_range = copy.deepcopy(self._lhs._iteration_range)
        self._memory_range = copy.deepcopy(self._lhs._memory_range)
        if len(self._iteration_range) == 1:
            # A 1-D lhs adopts the whole rhs iteration range. The memory range
            # keeps the lhs's entries except entry 0, which is recomputed below.
            self._iteration_range = copy.deepcopy(self._rhs._iteration_range)
        elif type(self._rhs._iteration_range[0][1]) is not BinaryOperation and type(self._iteration_range[0][1]) is not BinaryOperation and self._rhs._iteration_range[0][1] > self._iteration_range[0][1]:
            # Both first-dimension upper bounds are concrete (not symbolic
            # BinaryOperations) and the rhs one is larger: widen to it.
            self._iteration_range[0][1] = self._rhs._iteration_range[0][1]
        elif type(self._iteration_range[0][1]) is not BinaryOperation and self._iteration_range[0][1] == Integer(1):
            # lhs first dimension ends at 1 (presumably a broadcast dimension),
            # so take the rhs bound.
            self._iteration_range[0][1] = self._rhs._iteration_range[0][1]

        # First memory dimension is re-derived as [0, end - start] of the
        # (possibly updated) first iteration dimension.
        self._memory_range[0] = [Integer(0), self._iteration_range[0][1] - self._iteration_range[0][0]]

    elif type(self._lhs.get_type()) is TDataBlock:
        # Only the lhs is a block: inherit its ranges.
        self._iteration_range = copy.deepcopy(self._lhs._iteration_range)
        self._memory_range = copy.deepcopy(self._lhs._memory_range)
    else:
        # Otherwise take the ranges from the rhs (assumed to be the block
        # operand; if neither side is a block this attribute access fails).
        self._iteration_range = copy.deepcopy(self._rhs._iteration_range)
        self._memory_range = copy.deepcopy(self._rhs._memory_range)

    # Per-dimension offset: the start of each iteration range.
    self._offset = []
    for [x, y] in self._iteration_range:
        self._offset.append(x)


def index(self, index):
    """Per-element expression at flat *index*: a FunctionCall or a BinaryOperation."""
    if self._useFunctionSyntax:
        return FunctionCall(self._operation, [self._lhs.index(index), self._rhs.index(index)])
    else:
        return BinaryOperation(self._operation, self._lhs.index(index), self._rhs.index(index))

def multidimensional_index(self, index, start_index = 0):
    """Per-element expression at a multi-dimensional *index*."""
    if self._useFunctionSyntax:
        return FunctionCall(self._operation, [self._lhs.multidimensional_index(index, start_index), self._rhs.multidimensional_index(index, start_index)])
    else:
        return BinaryOperation(self._operation, self._lhs.multidimensional_index(index, start_index), self._rhs.multidimensional_index(index, start_index))

def get_type(self):
    """Block type over the merged iteration range.

    The element type is ``double`` if either operand's element type (for blocks,
    ``get_single()``) stringifies to "double"; otherwise it is ``int``.
    """
    lhs_is_double = False
    rhs_is_double = False
    if type(self._lhs.get_type()) is TDataBlock:
        lhs_is_double = (str(self._lhs.get_type().get_single()) == "double")
    else:
        lhs_is_double = (str(self._lhs.get_type()) == "double")

    if type(self._rhs.get_type()) is TDataBlock:
        rhs_is_double = (str(self._rhs.get_type().get_single()) == "double")
    else:
        rhs_is_double = (str(self._rhs.get_type()) == "double")

    if (lhs_is_double or rhs_is_double):
        return TDataBlock(self._iteration_range, String("double"))
    else:
        return TDataBlock(self._iteration_range, String("int"))

def __sub__(self, rhs):
    """Build a nested element-wise ``self - rhs``."""
    return DataBlockBinaryOperation("-", self, rhs)

def __add__(self, rhs):
    """Build a nested element-wise ``self + rhs``."""
    return DataBlockBinaryOperation("+", self, rhs)

# Not printable directly; every printer returns None.
def print_cpp(self, indent_level = 0):
    pass

def print_omp(self, indent_level = 0):
    pass

def print_sycl(self, indent_level = 0):
    pass

def print_mlir(self, indent_level = 0):
    pass

def print_tree(self, indent_level=0):
    pass
1103
1104
def __init__(self, operation, dataBlock):
    """Element-wise unary operation (*operation*, e.g. "-") over a data block.

    A TDataBlock-typed ``Name`` is resolved through ``Name.variables``. The
    ranges are deep-copied from the operand.
    """
    names_a_block = type(dataBlock) is Name and type(dataBlock.get_type()) is TDataBlock
    self._dataBlock = Name.variables[dataBlock.id] if names_a_block else dataBlock
    self._operation = operation
    self._offset = None
    source = self._dataBlock
    self._iteration_range = copy.deepcopy(source._iteration_range)
    self._memory_range = copy.deepcopy(source._memory_range)

def index(self, index):
    """Per-element ``UnaryOperation`` at flat *index*."""
    element = self._dataBlock.index(index)
    return UnaryOperation(self._operation, element)

def multidimensional_index(self, index, start_index = 0):
    """Per-element ``UnaryOperation`` at a multi-dimensional *index*."""
    element = self._dataBlock.multidimensional_index(index, start_index)
    return UnaryOperation(self._operation, element)

def get_type(self):
    """Always reported as a block of ``double``."""
    return TDataBlock(self._iteration_range, String("double"))

# Not printable directly; every printer returns None.
def print_cpp(self, indent_level = 0):
    return None

def print_omp(self, indent_level = 0):
    return None

def print_sycl(self, indent_level = 0):
    return None

def print_mlir(self, indent_level = 0):
    return None

def print_tree(self, indent_level=0):
    return None
1142
1143
def __init__(self, dataBlock):
    """Max-reduction node over *dataBlock*.

    A ``Name`` argument is resolved through ``Name.variables``. The loop nest
    is built later, by ``set_output_variable``; printing before that fails
    because ``self._loop`` does not exist yet.
    """
    if type(dataBlock) is Name:
        self._dataBlock = Name.variables[dataBlock.id]
    else:
        self._dataBlock = dataBlock

def set_output_variable(self, outputVariable):
    """Build the loop nest that writes the reduction into *outputVariable*.

    The outermost loop runs over the last iteration dimension. It first
    assigns 0.0 to the output indexed by that loop variable. It then nests
    loops over the remaining dimensions (second-to-last down to first). The
    innermost loop updates ``output = max(output, dataBlock)`` at the full
    index.
    NOTE(review): starting from 0.0 is only correct if the data is known to be
    non-negative; confirm with the callers.
    """
    self._outputVariable = Name.variables[outputVariable.id]
    self._loop = For(self._dataBlock._iteration_range[-1])
    loops = [self._loop]
    self._loop.add_statement(Assignment(self._outputVariable.multidimensional_index([self._loop.get_iteration_variable()]), Float(0.0)))
    for i in range(len(self._dataBlock._iteration_range) - 2, -1, -1):
        loops.append(For(self._dataBlock._iteration_range[i]))
        loops[-2].add_statement(loops[-1])
    # index[d] is the loop variable iterating dimension d.
    index = [loop.get_iteration_variable() for loop in reversed(loops)]
    loops[-1].add_statement(Assignment(self._outputVariable.multidimensional_index(index), Max(self._outputVariable.multidimensional_index(index), self._dataBlock.multidimensional_index(index))))
    # Release the iteration-variable names for reuse by later loops.
    for loop in loops:
        loop.close_scope()

def get_type(self):
    """A block of floats over the input's iteration range."""
    return TDataBlock(self._dataBlock._iteration_range, TFloat())

def print_cpp(self, indent_level = 0):
    return self._loop.print_cpp(indent_level)

def print_omp(self, indent_level = 0):
    # NOTE(review): delegates to the C++ printer rather than print_omp.
    return self._loop.print_cpp(indent_level)

def print_sycl(self, indent_level = 0):
    # print_sycl is abstract on Node. Without it this class cannot be
    # instantiated if it derives from Node.
    return self._loop.print_sycl(indent_level)

def print_mlir(self, indent_level = 0):
    # NOTE(review): emits C++ text, not MLIR.
    return self._loop.print_cpp(indent_level)

def print_tree(self, indent_level=0):
    return f"""DataBlockUnaryMax:
{self._dataBlock.print_tree(indent_level + 1)}"""
1181
1182
def __init__(self, lhs, rhs):
    """Element-wise maximum of two operands, built on the parent with operation "max"."""
    super().__init__("max", lhs, rhs)

def index(self, index):
    """Per-element ``Max`` of both operands at flat *index*."""
    left = self._lhs.index(index)
    right = self._rhs.index(index)
    return Max(left, right)

def multidimensional_index(self, index, start_index = 0):
    """Per-element ``Max`` of both operands at a multi-dimensional *index*."""
    left = self._lhs.multidimensional_index(index, start_index)
    right = self._rhs.multidimensional_index(index, start_index)
    return Max(left, right)

# Not printable directly; every printer returns None.
def print_cpp(self, indent_level = 0):
    return None

def print_omp(self, indent_level = 0):
    return None

def print_sycl(self, indent_level = 0):
    return None

def print_mlir(self, indent_level = 0):
    return None

def print_tree(self, indent_level=0):
    return None
1207
1208# Types
class TCustom(Type):
    """Type node that prints a caller-supplied type name verbatim in every backend."""

    def __init__(self, type_name):
        self._type = type_name

    def _indented(self, indent_level, text):
        """Prefix *text* with *indent_level* tabs' worth of spaces."""
        return ' ' * (indent_level * Node.spaces_per_tab) + text

    def print_cpp(self, indent_level=0):
        return self._indented(indent_level, self._type)

    def print_omp(self, indent_level=0):
        return self._indented(indent_level, self._type)

    def print_sycl(self, indent_level=0):
        return self._indented(indent_level, self._type)

    def print_mlir(self, indent_level=0):
        return self._indented(indent_level, self._type)

    def print_tree(self, indent_level=0):
        return self._indented(indent_level, "TCustom: " + self._type)
1227
1228
def print_cpp(self, indent_level = 0):
    """C++ spelling ``int``, indented by *indent_level* tabs."""
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}int"

def print_omp(self, indent_level = 0):
    """OpenMP spelling ``int``."""
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}int"

def print_sycl(self, indent_level = 0):
    """SYCL spelling ``int``."""
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}int"

def print_mlir(self, indent_level = 0):
    """MLIR spelling: 32-bit signless integer ``i32``."""
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}i32"

def print_tree(self, indent_level=0):
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}TInteger"
1244
def print_cpp(self, indent_level = 0):
    """C++ spelling ``bool``, indented by *indent_level* tabs."""
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}bool"

def print_omp(self, indent_level = 0):
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}bool"

def print_sycl(self, indent_level = 0):
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}bool"

def print_mlir(self, indent_level = 0):
    """MLIR spelling: 1-bit integer ``i1``."""
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}i1"

def print_tree(self, indent_level=0):
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}TBoolean"
1260
def __init__(self, reference = False):
    """Double-precision float type; *reference* appends ``&`` in the C++-family printers."""
    self._reference = reference

def _c_spelling(self):
    """``double`` or ``double&`` depending on the reference flag."""
    return "double&" if self._reference else "double"

def print_cpp(self, indent_level = 0):
    return ' ' * (indent_level * Node.spaces_per_tab) + self._c_spelling()

def print_omp(self, indent_level = 0):
    return ' ' * (indent_level * Node.spaces_per_tab) + self._c_spelling()

def print_sycl(self, indent_level = 0):
    return ' ' * (indent_level * Node.spaces_per_tab) + self._c_spelling()

def print_mlir(self, indent_level = 0):
    """MLIR spelling ``f64``; the reference flag is ignored."""
    return ' ' * (indent_level * Node.spaces_per_tab) + "f64"

def print_tree(self, indent_level=0):
    return ' ' * (indent_level * Node.spaces_per_tab) + "TFloat"
1278
1279
def print_cpp(self, indent_level = 0):
    """C++ spelling ``const char*``."""
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}const char*"

def print_omp(self, indent_level = 0):
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}const char*"

def print_sycl(self, indent_level = 0):
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}const char*"

def print_mlir(self, indent_level=0):
    # No MLIR spelling for strings: returns None.
    return None

def print_tree(self, indent_level=0):
    indent = ' ' * (indent_level * Node.spaces_per_tab)
    return f"{indent}TString"
1295
1296
def __init__(self, dimensions, underlying_type):
    """Data-block type: *dimensions* (kept but unused by the printers) plus an element type."""
    self._underlying_type = underlying_type
    self._dimensions = dimensions

def get_single(self):
    """Return the element type of the block."""
    return self._underlying_type

def print_cpp(self, indent_level=0):
    """C++ spelling: pointer to the element type."""
    return ' ' * (indent_level * Node.spaces_per_tab) + f"{self._underlying_type.print_cpp()}*"

def print_omp(self, indent_level=0):
    return ' ' * (indent_level * Node.spaces_per_tab) + f"{self._underlying_type.print_omp()}*"

def print_sycl(self, indent_level=0):
    return ' ' * (indent_level * Node.spaces_per_tab) + f"{self._underlying_type.print_sycl()}*"

def print_mlir(self, indent_level=0):
    # NOTE(review): MLIR memref types need an element type (e.g. memref<?xf64>);
    # "memref<?>" as emitted here is unlikely to parse.
    return ' ' * (indent_level * Node.spaces_per_tab) + "memref<?>"

def print_tree(self, indent_level=0):
    return ' ' * (indent_level * Node.spaces_per_tab) + "TDataBlock"
1319
1320
1321# Statements
def __init__(self, boolean):
    """Conditional emitted as ``if constexpr (<boolean>) { ... }`` by the C++-family printers.

    Args:
        boolean: condition expression node.
    """
    self._boolean = boolean
    self._statements = []

def add_statement(self, statement):
    """Append *statement* to the conditional's body."""
    self._statements.append(statement)

def _print_if_constexpr(self, indent_level, condition, statement_prints):
    """Assemble the ``if constexpr`` text from an already-printed condition and body lines."""
    pad = ' ' * indent_level * Node.spaces_per_tab
    return pad + f"""if constexpr ({condition}) {{
{os.linesep.join(statement_prints)}
""" + pad + "}"

def print_cpp(self, indent_level=0):
    statement_prints = [statement.print_cpp(indent_level + 1) for statement in self._statements]
    return self._print_if_constexpr(indent_level, self._boolean.print_cpp(), statement_prints)

def print_omp(self, indent_level=0):
    # Render the condition and children with the OpenMP printers so that
    # backend-specific output of nested nodes is not replaced by C++ output.
    statement_prints = [statement.print_omp(indent_level + 1) for statement in self._statements]
    return self._print_if_constexpr(indent_level, self._boolean.print_omp(), statement_prints)

def print_sycl(self, indent_level=0):
    # Same as print_omp, with the SYCL printers.
    statement_prints = [statement.print_sycl(indent_level + 1) for statement in self._statements]
    return self._print_if_constexpr(indent_level, self._boolean.print_sycl(), statement_prints)

def print_mlir(self, indent_level=0):
    # Not implemented for MLIR; returns None.
    return None

def print_tree(self, indent_level=0):
    statement_prints = [statement.print_tree(indent_level + 1) for statement in self._statements]
    return ' ' * indent_level * Node.spaces_per_tab + f"""If:
{os.linesep.join(statement_prints)}"""
1355
# Registry shared by every For instance: iteration-variable names currently
# bound by an open loop. Names are released by close_scope.
_inuse_iteration_variables = set()
# Candidate names, tried in order when no explicit name is requested.
_iteration_variable_names = ['i', 'j', 'k', 'l', 'n', 'm', 'a', 'b', 'c', 'd']

def __init__(self, iteration_range, iteration_variable_name = None, use_scheduler = False):
    """Counted loop over ``[iteration_range[0], iteration_range[1])``.

    Args:
        iteration_range: ``[start, end]`` pair of expression nodes.
        iteration_variable_name: explicit loop-variable name. If None, the
            first name in ``_iteration_variable_names`` not currently in use
            is taken.
        use_scheduler: accepted for compatibility but ignored; the scheduler
            path is disabled.

    Raises:
        RuntimeError: if no explicit name is given and every candidate name is
            already in use (loop nesting deeper than the name pool).
    """
    self._iteration_range = iteration_range
    self._statements = []
    self._use_scheduler = False  # deliberately disabled; was: use_scheduler

    if iteration_variable_name is None:
        iteration_variable_name = next(
            (candidate for candidate in For._iteration_variable_names
             if candidate not in For._inuse_iteration_variables),
            None,
        )
        if iteration_variable_name is None:
            raise RuntimeError(
                "For: no free iteration variable name; all of "
                f"{For._iteration_variable_names} are in use"
            )

    self._iteration_variable = Name(iteration_variable_name, TInteger())
    # Mark the name as taken until close_scope releases it.
    For._inuse_iteration_variables.add(iteration_variable_name)
1374
1376 return self._iteration_range[1] - self._iteration_range[0]
1377
1379 return self._iteration_variable
1380
def add_statement(self, statement):
    """Append *statement* to the loop body."""
    self._statements.append(statement)

def close_scope(self):
    """Release this loop's iteration-variable name so later loops can reuse it.

    Raises KeyError if the name is not registered (e.g. on a second call).
    """
    For._inuse_iteration_variables.remove(self._iteration_variable.id)

def _c_for_loop(self, indent_level, printer):
    """Render a C-style ``for`` loop, calling the method named *printer* on every node."""
    pad = ' ' * indent_level * Node.spaces_per_tab
    body = os.linesep.join([getattr(statement, printer)(indent_level + 1) for statement in self._statements])
    variable = getattr(self._iteration_variable, printer)()
    start = getattr(self._iteration_range[0], printer)()
    stop = getattr(self._iteration_range[1], printer)()
    header = f"for (int {variable} = {start}; {variable} < {stop}; {variable}++) {{"
    return pad + header + "\n" + body + "\n" + pad + "}"

def print_cpp(self, indent_level = 0):
    if not self._use_scheduler:
        return self._c_for_loop(indent_level, "print_cpp")
    # Scheduler variant. __init__ always sets _use_scheduler to False, so this
    # is only reached if the attribute is changed after construction.
    pad = ' ' * indent_level * Node.spaces_per_tab
    body = os.linesep.join([statement.print_cpp(indent_level + 1) for statement in self._statements])
    header = f"parallelForWithSchedulerInstructions({self._iteration_variable.print_cpp()}, {self._iteration_range[1].print_cpp()}, loopParallelism) {{"
    return pad + header + "\n" + body + "\n" + pad + "}\nendParallelFor"


def print_omp(self, indent_level = 0):
    return self._c_for_loop(indent_level, "print_omp")

def print_sycl(self, indent_level = 0):
    return self._c_for_loop(indent_level, "print_sycl")

def print_mlir(self, indent_level = 0):
    """Render as ``affine.for <var> = <start> to <end> { ... }``."""
    pad = ' ' * indent_level * Node.spaces_per_tab
    body = os.linesep.join([statement.print_mlir(indent_level + 1) for statement in self._statements])
    header = f"affine.for {self._iteration_variable.print_mlir()} = {self._iteration_range[0].print_mlir()} to {self._iteration_range[1].print_mlir()} {{"
    return pad + header + "\n" + body + "\n" + pad + "}"

def print_tree(self, indent_level = 0):
    body = os.linesep.join([statement.print_tree(indent_level + 1) for statement in self._statements])
    return ' ' * indent_level * Node.spaces_per_tab + "For:\n" + body
1422
1423
1424
class Comment(Statement):
    """Source comment carried into the generated code as a ``//`` line comment.

    The first character of the stored text (presumably a leading ``#``; confirm
    with the callers) is dropped and replaced by ``//``.
    """

    def __init__(self, comment):
        self._comment = comment

    def get_type(self):
        return None

    def _line_comment(self, indent_level):
        """Indented ``//`` comment line shared by every backend."""
        return ' ' * (indent_level * Node.spaces_per_tab) + "//" + self._comment[1:]

    def print_cpp(self, indent_level = 0):
        return self._line_comment(indent_level)

    def print_omp(self, indent_level = 0):
        return self._line_comment(indent_level)

    def print_sycl(self, indent_level = 0):
        return self._line_comment(indent_level)

    def print_mlir(self, indent_level = 0):
        return self._line_comment(indent_level)

    def print_tree(self, indent_level=0):
        return ' ' * (indent_level * Node.spaces_per_tab) + "Comment"
1446
1447
class FunctionCall(Statement):
    """Call of the function named ``id`` with a list of argument expression nodes.

    NOTE(review): every printer appends ``;``. That is out of place when the node
    is used as a sub-expression, as DataBlockBinaryOperation does with
    useFunctionSyntax; confirm the intended contexts.
    """

    def __init__(self, id, arguments, is_offloadable = False):
        self.id = id
        # Stored by reference: add_argument appends to the caller's list.
        self._arguments = arguments
        self._is_offloadable = is_offloadable

    def add_argument(self, argument: Expression):
        self._arguments.append(argument)

    def print_cpp(self, indent_level=0):
        argument_prints = [argument.print_cpp() for argument in self._arguments]
        return ' ' * indent_level * Node.spaces_per_tab + f"""{self.id}({", ".join(argument_prints)});"""

    def print_omp(self, indent_level=0):
        argument_prints = [argument.print_omp() for argument in self._arguments]
        return ' ' * indent_level * Node.spaces_per_tab + f"""{self.id}({", ".join(argument_prints)});"""

    def print_sycl(self, indent_level=0):
        argument_prints = [argument.print_sycl() for argument in self._arguments]
        return ' ' * indent_level * Node.spaces_per_tab + f"""{self.id}({", ".join(argument_prints)});"""

    def print_mlir(self, indent_level=0):
        # Arguments are rendered with their MLIR printers (like the other
        # backends use their own), not as C++ identifiers.
        # NOTE(review): MLIR func.call also needs a trailing function type
        # (": (arg types) -> (result types)") and has no ';' terminator.
        argument_prints = [argument.print_mlir() for argument in self._arguments]
        return ' ' * indent_level * Node.spaces_per_tab + f"""func.call @{self.id}({", ".join(argument_prints)});"""

    def print_tree(self, indent_level=0):
        argument_prints = [argument.print_tree() for argument in self._arguments]
        return ' ' * indent_level * Node.spaces_per_tab + f"""FunctionCall:
{' ' * (indent_level + 1) * Node.spaces_per_tab + ",".join(argument_prints)}"""
1477
1478
class Max(Statement):
    """Maximum of two expressions: ``std::max(...)`` in C++/OpenMP/SYCL, ``affine.max(...)`` in MLIR.

    Every printer appends a trailing ``;``.
    """

    def __init__(self, lhs, rhs):
        self._lhs, self._rhs = lhs, rhs

    def _call_text(self, indent_level, function_name, printer):
        """Render ``function_name(lhs, rhs);`` with the operand method named *printer*."""
        pad = ' ' * indent_level * Node.spaces_per_tab
        left = getattr(self._lhs, printer)()
        right = getattr(self._rhs, printer)()
        return pad + f"{function_name}({left}, {right});"

    def print_cpp(self, indent_level=0):
        return self._call_text(indent_level, "std::max", "print_cpp")

    def print_omp(self, indent_level=0):
        return self._call_text(indent_level, "std::max", "print_omp")

    def print_sycl(self, indent_level=0):
        return self._call_text(indent_level, "std::max", "print_sycl")

    def print_mlir(self, indent_level=0):
        return self._call_text(indent_level, "affine.max", "print_mlir")

    def print_tree(self, indent_level=0):
        pad = ' ' * indent_level * Node.spaces_per_tab
        left = self._lhs.print_tree(indent_level + 1)
        right = self._rhs.print_tree(indent_level + 1)
        return pad + f"Max:\n{left}\n{right}"
1500
class MemoryAllocation(Statement):
    """Heap allocation of an array for *name* using ``new``.

    A 1-D allocation prints a single ``new T[n]``. With more than one
    dimension, the output allocates ``dimensions[-1]`` row pointers (``T**``).
    A loop then allocates, per pointer, a row of
    ``dimensions[-2] * ... * dimensions[0]`` elements.
    """

    # Stack of per-scope allocation lists. The constructor appends to the last
    # (innermost) list unless add_to_stack is false. A scope list must already
    # have been pushed, otherwise ``[-1]`` raises IndexError.
    memoryAllocated = []

    def __init__(self, name: Name, object_type: Type, dimensions, specify_type = True, add_to_stack=True):
        """
        Args:
            name: variable (or subscript) receiving the allocation.
            object_type: element type node.
            dimensions: extents, either plain values or ``[start, end]`` pairs
                (only ``end`` is kept).
            specify_type: prefix the declared pointer type in the 1-D output.
            add_to_stack: register this allocation in the innermost scope list.
        """
        self._type = object_type
        if type(dimensions[0]) is list:
            self._dimensions = [dimension[1] for dimension in dimensions]
        else:
            self._dimensions = dimensions
        self._name = name
        self._specify_type = specify_type

        if add_to_stack:
            MemoryAllocation.memoryAllocated[-1].append(self)

        if len(self._dimensions) > 1:
            # One nested (untracked, untyped-declaration) 1-D allocation per row pointer.
            self._loop = For([Integer(0), self._dimensions[-1]])
            self._loop.add_statement(MemoryAllocation(Subscript(self._name, self._loop.get_iteration_variable()), self._type, [self._row_size()], False, False))
            self._loop.close_scope()

    def _row_size(self):
        """Elements per row: ``dims[-2] * dims[-3] * ... * dims[0]`` (built in that order)."""
        size = self._dimensions[-2]
        for dimension in reversed(self._dimensions[:-2]):
            size = size * dimension
        return size

    def print_cpp(self, indent_level=0):
        pad = ' ' * indent_level * Node.spaces_per_tab
        if len(self._dimensions) == 1:
            return pad + f"""{self._type.print_cpp() + "* " if self._specify_type else ""}{self._name.print_cpp()} = new {self._type.print_cpp()}[{self._dimensions[0].print_cpp()}];"""
        # The row size is already encoded in the nested allocation inside self._loop.
        return pad + f"""{self._type.print_cpp()}** {self._name.print_cpp()} = new {self._type.print_cpp()}*[{self._dimensions[-1].print_cpp()}];
{self._loop.print_cpp(indent_level)}"""

    def print_omp(self, indent_level=0):
        if len(self._dimensions) == 1:
            return ' ' * indent_level * Node.spaces_per_tab + f"""{self._type.print_omp() + "* " if self._specify_type else ""}{self._name.print_omp()} = new {self._type.print_omp()}[{self._dimensions[0].print_omp()}];"""
        else:
            return ' ' * indent_level * Node.spaces_per_tab + f"""{self._type.print_omp()}** {self._name.print_omp()} = new {self._type.print_omp()}*[{self._dimensions[-1].print_omp()}];
{self._loop.print_omp(indent_level)}"""

    def print_sycl(self, indent_level=0):
        if len(self._dimensions) == 1:
            return ' ' * indent_level * Node.spaces_per_tab + f"""{self._type.print_sycl() + "* " if self._specify_type else ""}{self._name.print_sycl()} = new {self._type.print_sycl()}[{self._dimensions[0].print_sycl()}];"""
        else:
            return ' ' * indent_level * Node.spaces_per_tab + f"""{self._type.print_sycl()}** {self._name.print_sycl()} = new {self._type.print_sycl()}*[{self._dimensions[-1].print_sycl()}];
{self._loop.print_sycl(indent_level)}"""

    def print_mlir(self, indent_level=0):
        return ' ' * indent_level * Node.spaces_per_tab + f"""{self._name.print_mlir()} = memref.alloc() : {self._name.get_type().print_mlir()}"""

    def print_tree(self, indent_level=0):
        return ' ' * indent_level * Node.spaces_per_tab + f"""MemoryAllocation:
{self._name.print_tree(indent_level + 1)}"""
1556
1557
class MemoryDeallocation(Statement):
    """``delete[]`` counterpart of a MemoryAllocation.

    For a multi-dimensional allocation, a loop first deletes every row and then
    the row-pointer array itself.
    """

    def __init__(self, allocation: MemoryAllocation):
        self._dimensions = allocation._dimensions
        self._name = allocation._name
        if len(self._dimensions) <= 1:
            return

        row_size = self._dimensions[-2]
        for extent in reversed(self._dimensions[:-2]):
            row_size = row_size * extent

        self._loop = For([Integer(0), self._dimensions[-1]])
        # A throwaway, untracked 1-D allocation describing one row is used only
        # to carry the row's name and extent into the nested deallocation.
        row = Subscript(self._name, self._loop.get_iteration_variable())
        self._loop.add_statement(MemoryDeallocation(MemoryAllocation(row, None, [row_size], False, False)))
        self._loop.close_scope()

    def _delete_line(self, indent_level, name_text):
        """Indented ``delete[] <name>;`` line."""
        return ' ' * indent_level * Node.spaces_per_tab + f"delete[] {name_text};"

    def print_cpp(self, indent_level=0):
        if len(self._dimensions) == 1:
            return self._delete_line(indent_level, self._name.print_cpp())
        return self._loop.print_cpp(indent_level) + "\n" + self._delete_line(indent_level, self._name.print_cpp())

    def print_omp(self, indent_level=0):
        if len(self._dimensions) == 1:
            return self._delete_line(indent_level, self._name.print_omp())
        return self._loop.print_omp(indent_level) + "\n" + self._delete_line(indent_level, self._name.print_omp())

    def print_sycl(self, indent_level=0):
        if len(self._dimensions) == 1:
            return self._delete_line(indent_level, self._name.print_sycl())
        return self._loop.print_sycl(indent_level) + "\n" + self._delete_line(indent_level, self._name.print_sycl())

    def print_mlir(self, indent_level=0):
        pad = ' ' * indent_level * Node.spaces_per_tab
        return pad + f"memref.dealloc {self._name.print_mlir()} : {self._name.get_type().print_mlir()}"

    def print_tree(self, indent_level=0):
        pad = ' ' * indent_level * Node.spaces_per_tab
        return pad + "MemoryDeallocation:\n" + f"{self._name.print_tree(indent_level + 1)}"
1599
1600
class Construction(Statement):
    """Declaration with initialiser, ``<type> <name> = <expression>;``.

    The constructor registers *expression* under the name's id in
    ``Name.variables``.
    """

    def __init__(self, name: Name, expression: Expression):
        self._name = name
        self._expression = expression
        Name.variables[self._name.id] = expression

    def _declaration(self, indent_level, type_text, name_text, value_text):
        """Indented ``<type> <name> = <value>;`` line."""
        return ' ' * indent_level * Node.spaces_per_tab + f"{type_text} {name_text} = {value_text};"

    def print_cpp(self, indent_level = 0):
        return self._declaration(indent_level, self._name.get_type().print_cpp(), self._name.print_cpp(), self._expression.print_cpp())

    def print_omp(self, indent_level = 0):
        return self._declaration(indent_level, self._name.get_type().print_omp(), self._name.print_omp(), self._expression.print_omp())

    def print_sycl(self, indent_level = 0):
        return self._declaration(indent_level, self._name.get_type().print_sycl(), self._name.print_sycl(), self._expression.print_sycl())

    def print_mlir(self, indent_level = 0):
        pad = ' ' * indent_level * Node.spaces_per_tab
        target = self._name.print_mlir()
        value = self._expression.print_mlir()
        type_text = self._name.get_type().print_mlir()
        return pad + f"{target} = {value} : {type_text}"

    def print_tree(self, indent_level=0):
        name_tree = self._name.print_tree(indent_level + 1)
        value_tree = self._expression.print_tree(indent_level + 1)
        return f"\nConstruction:\n{name_tree}\n{value_tree}"
1624
1625
class DataBlockConstructionFromExisting(Statement):
    """Bind a new variable to an existing data block's internal expression.

    The block is renamed to the new variable's id and registered in
    ``Name.variables``. Blocks with more than one iteration dimension get an
    extra ``*`` on the declared type.
    """

    def __init__(self, name: Name, dataBlock: DataBlock):
        self._name = name
        self._dataBlock = dataBlock
        self._dataBlock.id = name.id
        Name.variables[self._name.id] = self._dataBlock

    def _c_declaration(self, indent_level, printer):
        """``<type>[*] <name> = <internal>;`` using the node method named *printer*."""
        type_text = getattr(self._name.get_type(), printer)()
        if len(self._dataBlock._iteration_range) > 1:
            type_text += "*"
        pad = ' ' * indent_level * Node.spaces_per_tab
        name_text = getattr(self._name, printer)()
        value_text = getattr(self._dataBlock._internal, printer)()
        return pad + f"{type_text} {name_text} = {value_text};"

    def print_cpp(self, indent_level = 0):
        return self._c_declaration(indent_level, "print_cpp")

    def print_omp(self, indent_level = 0):
        return self._c_declaration(indent_level, "print_omp")

    def print_sycl(self, indent_level = 0):
        return self._c_declaration(indent_level, "print_sycl")

    def print_mlir(self, indent_level = 0):
        # NOTE(review): the variable name and the 1-D type are printed with the
        # C++ printers, so the result is not valid MLIR; confirm intent.
        type_text = self._name.get_type().print_cpp()
        if len(self._dataBlock._iteration_range) > 1:
            type_text = "!llvm.ptr"
        pad = ' ' * indent_level * Node.spaces_per_tab
        return pad + f"{self._name.print_cpp()} = {self._dataBlock._internal.print_mlir()} : {type_text}"

    def print_tree(self, indent_level=0):
        name_tree = self._name.print_tree(indent_level + 1)
        internal_tree = self._dataBlock._internal.print_tree(indent_level + 1)
        return f"\nDataBlockConstructionFromExisting:\n{name_tree}\n{internal_tree}"
1662
1663
class DataBlockConstructionFromOperation(Statement):
    """Declare a new data block *name*, allocate it, and fill it from a block expression.

    Derives from Statement, like the other data-block construction statements.
    It implements every printer Node declares abstract.
    """

    # Operation nodes that are kept as-is rather than unwrapped via ``_internal``.
    _LAZY_OPERATION_TYPES = None

    def __init__(self, name: Name, dataBlockOperation):
        """
        Args:
            name: variable receiving the new block (registered in ``Name.variables``).
            dataBlockOperation: a data-block operation node, or a block whose
                ``_internal`` expression supplies the values.
        """
        self._name = name
        self._dataBlock = DataBlock(dataBlockOperation._iteration_range, None, False, self._name.id, underlying_type=dataBlockOperation.get_type().get_single())
        Name.variables[self._name.id] = self._dataBlock

        lazy_operation_types = (DataBlockBinaryOperation, DataBlockUnaryOperation, DataBlockMax, DataBlockComparison)
        if type(dataBlockOperation) in lazy_operation_types or type(dataBlockOperation._internal) is String:
            self._operation = dataBlockOperation
        else:
            self._operation = dataBlockOperation._internal

        self.memoryAllocation = MemoryAllocation(self._name, self._dataBlock.get_type().get_single(), self._dataBlock._memory_range)
        self.assignment = DataBlockAssignment(self._name, self._operation)

    def print_cpp(self, indent_level = 0):
        return f"""{self.memoryAllocation.print_cpp(indent_level)}
{self.assignment.print_cpp(indent_level)}"""

    def print_omp(self, indent_level = 0):
        return f"""{self.memoryAllocation.print_omp(indent_level)}
{self.assignment.print_omp(indent_level)}"""

    def print_sycl(self, indent_level = 0):
        return f"""{self.memoryAllocation.print_sycl(indent_level)}
{self.assignment.print_sycl(indent_level)}"""

    def print_mlir(self, indent_level = 0):
        return f"""{self.memoryAllocation.print_mlir(indent_level)}
{self.assignment.print_mlir(indent_level)}"""

    def print_tree(self, indent_level = 0):
        return ' ' * indent_level * Node.spaces_per_tab + f"""DataBlockConstructionFromOperation:
{self._name.print_tree(indent_level + 1)}
{self._dataBlock.print_tree(indent_level + 1)}"""
1697
1698
class DataBlockConstructionFromFunction(Statement):
    """Allocate a data block and fill it element by element from its ``_internal`` expression.

    The constructor sets ``dataBlock.id`` to ``name.id`` and registers the block
    in ``Name.variables``. It then prepares a ``MemoryAllocation`` and builds a
    perfectly nested ``For`` loop nest (``self.initialisation``). The innermost
    statement is either a per-element call to the wrapped ``FunctionCall`` or,
    for any other internal expression, an element-wise ``Assignment``.
    """

    def __init__(self, name: Name, dataBlock: DataBlock):
        self._name = name
        self._dataBlock = dataBlock
        self._dataBlock.id = name.id
        Name.variables.update({self._name.id: self._dataBlock})

        memory_range = self._dataBlock._memory_range
        self.memoryAllocation = MemoryAllocation(self._name, self._dataBlock.get_type().get_single(), memory_range)
        # The outermost loop runs over the last memory dimension.
        self.initialisation = For(memory_range[-1])

        is_function_call = type(self._dataBlock._internal) is FunctionCall

        # Zero-based loops over dimensions len-2 .. 1, outermost first.
        inner_loops = [
            For([Integer(0), memory_range[i][1] - memory_range[i][0]])
            for i in range(len(self._dataBlock._iteration_range) - 2, 0, -1)
        ]
        # Only plain expressions get a loop of their own over dimension 0.
        # A function call instead receives a reference to the first element
        # along dimension 0 (see the multidimensional_index(..., 1) calls).
        if not is_function_call:
            inner_loops.append(For(memory_range[0]))

        # Iteration variables ordered innermost loop first, outermost last.
        index = [loop.get_iteration_variable() for loop in reversed(inner_loops)]
        index.append(self.initialisation.get_iteration_variable())

        if is_function_call:
            body = self._build_function_call(index)
        else:
            body = Assignment(self._dataBlock.multidimensional_index(index), self._dataBlock._internal)

        # Nest each loop inside its predecessor and put the body in the last one.
        # This also copes with an empty ``inner_loops``: the body then sits
        # directly in the outer loop instead of raising IndexError.
        loop_nest = [self.initialisation] + inner_loops
        for outer, inner in zip(loop_nest, loop_nest[1:]):
            outer.add_statement(inner)
        loop_nest[-1].add_statement(body)

        for loop in reversed(loop_nest):
            loop.close_scope()

        # Depth of the perfectly nested loop nest. The OpenMP collapse clause
        # must not exceed it, and the function-call variant emits one loop
        # fewer than the plain-expression variant.
        self._collapse_depth = len(loop_nest)

    def _build_function_call(self, index):
        """Build the per-element call to the wrapped ``FunctionCall``.

        The first argument and the output block are passed by ``Reference`` to
        the element addressed by ``index`` (from position 1 onward).
        Data-block arguments are indexed the same way: passed by value when the
        block has a single iteration-range entry, otherwise by reference.
        Other arguments are forwarded unchanged. Offloadable calls get a
        trailing ``Solver::Offloadable::Yes`` argument.
        """
        internal = self._dataBlock._internal
        call = FunctionCall(internal.id, [])
        call._is_offloadable = internal._is_offloadable

        call.add_argument(Reference(internal._arguments[0].multidimensional_index(index, 1)))

        for argument in internal._arguments[1:]:
            if type(argument.get_type()) is TDataBlock:
                block = Name.variables[argument.id]
                element = block.multidimensional_index(index, 1)
                if len(block._iteration_range) == 1:
                    call.add_argument(element)
                else:
                    call.add_argument(Reference(element))
            else:
                call.add_argument(argument)

        call.add_argument(Reference(self._dataBlock.multidimensional_index(index, 1)))
        if call._is_offloadable:
            call.add_argument(String("Solver::Offloadable::Yes"))
        return call

    def print_cpp(self, indent_level = 0):
        return f"""{self.memoryAllocation.print_cpp(indent_level)}
{self.initialisation.print_cpp(indent_level)}"""

    def print_omp(self, indent_level = 0):
        # Collapse exactly as many loops as were emitted (see __init__).
        return f"""{self.memoryAllocation.print_omp(indent_level)}
{' ' * indent_level * Node.spaces_per_tab}#pragma omp parallel for simd collapse({self._collapse_depth}) schedule(static, 1)
{self.initialisation.print_omp(indent_level)}"""

    def print_sycl(self, indent_level = 0):
        return f"""{self.memoryAllocation.print_sycl(indent_level)}
{self.initialisation.print_sycl(indent_level)}"""

    def print_mlir(self, indent_level = 0):
        return f"""{self.memoryAllocation.print_mlir(indent_level)}
{self.initialisation.print_mlir(indent_level)}"""

    def print_tree(self, indent_level = 0):
        return ' ' * indent_level * Node.spaces_per_tab + f"""DataBlockConstructionFromFunction:
{self._name.print_tree(indent_level + 1)}
{self._dataBlock.print_tree(indent_level + 1)}"""
1779
class Assignment(Statement):
    """A single ``lhs = rhs`` statement.

    The C++, OpenMP and SYCL back ends render ``lhs = rhs;`` with the matching
    printer of each operand. The MLIR back end drops the semicolon and appends
    ``: <lhs type>``. Every rendering is prefixed by ``indent_level`` tabs of
    ``Node.spaces_per_tab`` spaces.
    """

    def __init__(self, lhs: Expression, rhs: Expression):
        self._lhs, self._rhs = lhs, rhs

    @staticmethod
    def _indent(indent_level):
        """Return the leading whitespace for ``indent_level`` tabs."""
        return ' ' * indent_level * Node.spaces_per_tab

    def _c_like(self, lhs_text, rhs_text, indent_level):
        """Format an indented, semicolon-terminated ``lhs = rhs;`` line."""
        return f"{self._indent(indent_level)}{lhs_text} = {rhs_text};"

    def print_cpp(self, indent_level = 0):
        return self._c_like(self._lhs.print_cpp(), self._rhs.print_cpp(), indent_level)

    def print_omp(self, indent_level = 0):
        return self._c_like(self._lhs.print_omp(), self._rhs.print_omp(), indent_level)

    def print_sycl(self, indent_level = 0):
        return self._c_like(self._lhs.print_sycl(), self._rhs.print_sycl(), indent_level)

    def print_mlir(self, indent_level = 0):
        lhs_text = self._lhs.print_mlir()
        rhs_text = self._rhs.print_mlir()
        type_text = self._lhs.get_type().print_mlir()
        return f"{self._indent(indent_level)}{lhs_text} = {rhs_text} : {type_text}"

    def print_tree(self, indent_level = 0):
        header = self._indent(indent_level) + "Assignment:"
        lhs_tree = self._lhs.print_tree(indent_level + 1)
        rhs_tree = self._rhs.print_tree(indent_level + 1)
        return f"{header}\n{lhs_tree}\n{rhs_tree}"
1801
1802
class DataBlockAssignment(Statement):
    """Element-wise copy ``lhs[idx] = rhs[idx]`` over a nest of ``For`` loops.

    ``Name`` operands are resolved to the data blocks registered in
    ``Name.variables``. The constructor builds one zero-based loop per entry of
    the left-hand side's memory range, with the outermost loop over the last
    dimension, and puts a single ``Assignment`` in the innermost loop. The
    ``print_*`` methods render that loop nest.
    """

    def __init__(self, lhs: Expression, rhs: Expression):
        if type(lhs) is Name:
            self._lhs = Name.variables[lhs.id]
            # NOTE(review): this sets ``_id`` whereas other constructors set
            # ``id`` on data blocks -- confirm which attribute DataBlock reads.
            self._lhs._id = lhs.id
        else:
            self._lhs = lhs

        if type(rhs) is Name:
            self._rhs = Name.variables[rhs.id]
        else:
            self._rhs = rhs

        iteration_range = self._lhs._iteration_range
        depth = len(self._lhs._memory_range)

        # Outermost loop covers the last dimension, followed by dimensions
        # depth-2 .. 0. Every loop counts from 0 up to its range's extent.
        loop_dimensions = [-1, *range(depth - 2, -1, -1)]
        loops = [For([Integer(0), iteration_range[d][1] - iteration_range[d][0]]) for d in loop_dimensions]

        # Iteration variables ordered innermost loop first.
        index = [loop.get_iteration_variable() for loop in reversed(loops)]

        # NOTE(review): the rhs is indexed with the same zero-based index as the
        # lhs, so no alignment for differing iteration-range starts is applied
        # here. Confirm that multidimensional_index handles that case.
        loops[-1].add_statement(Assignment(self._lhs.multidimensional_index(index), self._rhs.multidimensional_index(index)))
        for outer, inner in zip(loops, loops[1:]):
            outer.add_statement(inner)

        for loop in reversed(loops):
            loop.close_scope()
        self._loop = loops[0]
        # Depth of the perfectly nested loop nest; used by the OpenMP collapse clause.
        self._numDimensions = len(loops)

    def print_cpp(self, indent_level = 0):
        return self._loop.print_cpp(indent_level)

    def print_omp(self, indent_level = 0):
        # Collapse exactly as many loops as were emitted in __init__.
        return ' ' * indent_level * Node.spaces_per_tab + f"""#pragma omp parallel for simd collapse({self._numDimensions}) schedule(static, 1)
{self._loop.print_omp(indent_level)}"""

    def print_sycl(self, indent_level = 0):
        return self._loop.print_sycl(indent_level)

    def print_mlir(self, indent_level = 0):
        return self._loop.print_mlir(indent_level)

    def print_tree(self, indent_level = 0):
        return ' ' * indent_level * Node.spaces_per_tab + f"""DataBlockAssignment:
{self._lhs.print_tree(indent_level + 1)}
{self._rhs.print_tree(indent_level + 1)}"""
multidimensional_index(self, index, start_index=0)
__init__(self, operation, lhs, rhs)
__init__(self, operation, lhs, rhs, useFunctionSyntax=False)
multidimensional_index(self, index, start_index=0)
multidimensional_index(self, index, start_index=0)
multidimensional_index(self, index, start_index=0)
linearise_index(self, indices, dimensions, offset_start_index=0)
multidimensional_index(self, index_list, start_index=0)
__init__(self, iteration_range, internal, requires_memory_allocation, id=None, underlying_type=String("double"))
multidimensional_index(self, index_list, start_index=0)
Definition SyntaxTree.py:35
print_sycl(self, indent_level=0)
print_omp(self, indent_level=0)
print_cpp(self, indent_level=0)
__init__(self, value, string=None, reference=False)
print_mlir(self, indent_level=0)
print_tree(self, indent_level=0)
__init__(self, iteration_range, iteration_variable_name=None, use_scheduler=False)
print_cpp(self, indent_level=0)
__init__(self, id, Type return_type, template=None, namespaces=[], stateless=False)
Definition SyntaxTree.py:71
print_omp(self, indent_level=0)
print_cpp(self, indent_level=0)
print_tree(self, indent_level=0)
print_mlir(self, indent_level=0)
print_sycl(self, indent_level=0)
__init__(self, value, string=None)
__init__(self, dataBlock, filename)
__init__(self, statementToLog, outputStream)
print_omp(self, indent_level=0)
print_tree(self, indent_level=0)
__init__(self, id, type=None)
print_mlir(self, indent_level=0)
print_sycl(self, indent_level=0)
print_cpp(self, indent_level=0)
multidimensional_index(self, index_list, start_index=0)
print_cpp(self, indent_level=0)
Definition SyntaxTree.py:11
print_tree(self, indent_level=0)
Definition SyntaxTree.py:27
print_omp(self, indent_level=0)
Definition SyntaxTree.py:19
print_sycl(self, indent_level=0)
Definition SyntaxTree.py:23
print_mlir(self, indent_level=0)
Definition SyntaxTree.py:15
__init__(self, Expression expression)
print_mlir(self, indent_level=0)
print_cpp(self, indent_level=0)
print_sycl(self, indent_level=0)
print_tree(self, indent_level=0)
print_omp(self, indent_level=0)
__init__(self, Expression value, Expression index)
__init__(self, dimensions, underlying_type)
__init__(self, reference=False)
__init__(self, Expression value, Expression index)