Peano
Loading...
Searching...
No Matches
frontend.py
Go to the documentation of this file.
1from sympy import *
2from sympy.codegen.ast import integer, real, none
3from exahype import sympy
4
def viable(dim, patch_size, halo_size):
    """Return True when the solver dimensions are acceptable.

    ``dim`` must be 2 or 3, ``patch_size`` must not be below 1 and
    ``halo_size`` must not be negative.
    """
    supported_dimension = dim in (2, 3)
    return supported_dimension and not patch_size < 1 and not halo_size < 0
13
15 def __init__(self,dim,patch_size,halo_size,n_real,n_aux,n_patches=1):
16 if not viable(dim,patch_size,halo_size):
17 raise Exception('check viability of inputs')
18 self.dim = dim
19 self.patch_size = patch_size
20 self.halo_size = halo_size
21 self.n_patches = n_patches
22 self.n_real = n_real
23 self.n_aux = n_aux
24
25 self.indexes = [index for index in symbols('patch i j', cls=Idx)]
26 if dim == 3:
27 self.indexes.append(symbols('k', cls=Idx))
28 self.indexes.append(symbols('var', cls=Idx))
29
30 self.literals = [] #lines written in c++
31
32 self.parents = {} #which items are parents of which items
33 self.inputs = []
34 self.input_types = []
35 self.items = [] #stored as strings
36 self.directional_items = [] #stored as strings
37 self.directional_consts = {} #stores values of the const for each direction
38 self.functions = [] #stored as sympy functions
39 self.item_struct = {} #0 for none, 1 for n_real, 2 for n_real + n_aux, -1 for not applicable (for example a constant)
40
41 halo_range = (0,self.patch_size+2*self.halo_size)
42 default_range = halo_range
43 self.default_shape = ([self.n_patches] + [default_range for _ in range(self.dim)])
44 self.all_items = {'i':Idx('i',default_range),'j':Idx('j',default_range),'k':Idx('k',default_range),'patch':Idx('patch',(0,self.n_patches)),'var':Idx('var',(0,n_real+n_aux))} #as sympy objects
45
46 self.LHS = []
47 self.RHS = []
48 self.directions = [] #used for cutting the halo in particular directions
49 self.struct_inclusion = [] #how much of the struct to loop over, 0 for none, 1 for n_real, 2 for n_real + n_aux
50
51 self.const('dim',define=f'int dim = {dim};')
52 self.const('patch_size',define=f'int patch_size = {patch_size};')
53 self.const('halo_size',define=f'int halo_size = {halo_size};')
54 self.const('n_real',define=f'int n_real = {n_real};')
55 self.const('n_aux',define=f'int n_aux = {n_aux};')
56
57 def const(self,expr,in_type="double",parent=None,define=None):
58 self.all_items[expr] = symbols(expr)
59 if parent != None:
60 self.parents[expr] = str(parent)
61 return symbols(expr)
62 if define != None:
63 self.literals.append(define)
64 return symbols(expr)
65 self.inputs.append(expr)
66 self.input_types.append(in_type)
67 return symbols(expr, real=True)
68
69 def directional_const(self,expr,vals):
70 if len(vals) != self.dim:
71 raise Exception("directional constant must have values for each direction")
72 self.directional_consts[expr] = vals
73 self.all_items[expr] = symbols(expr, real=True)
74 return symbols(expr, real=True)
75
76 def item(self,expr,struct=True,in_type="double*",parent=None):
77 self.items.append(expr)
78 self.all_items[expr] = IndexedBase(expr, real=True)
79 if len(self.items) == 1:
80 self.inputs.append(expr)# = expr
81 self.input_types.append(in_type)
82 self.item_struct[expr] = 0 + struct*2
83 if parent != None:
84 self.parents[expr] = str(parent)
85 return IndexedBase(expr, real=True)
86
87 def directional_item(self,expr,struct=True):
88 self.directional_items.append(expr)
89 self.item_struct[expr] = 0 + struct*1
90 tmp = ''
91 extra = ['_x','_y','_z']
92 for i in range(self.dim):
93 direction = extra[i]
94 tmp = expr + direction
95 self.all_items[tmp] = IndexedBase(tmp, real=True)
96 self.item_struct[tmp] = 0 + struct*1
97 return IndexedBase(expr, real=True)
98
99 def function(self, expr, parent=None, parameter_types = [], return_type = none, ):
100 if parent != None:
101 self.parents[expr] = str(parent)
102 self.functions.append(expr)
103 func = sympy.TypedFunction(expr)
104 func.returnType(return_type)
105 func.parameterTypes(parameter_types)
106 self.all_items[expr] = func
107 return func
108
109 def single(self,LHS,RHS='',direction=-1,struct=False):
110 if struct:
111 self.struct_inclusion.append(1)
112 elif str(type(LHS)) in self.functions or str(type(RHS)) in self.functions:
113 self.struct_inclusion.append(0)
114 elif str(LHS).partition('[')[0] in self.inputs:
115 self.struct_inclusion.append(2)
116 elif self.RHS == '':
117 self.struct_inclusion.append(0)
118 else:
119 tmp = [val for key,val in self.item_struct.items() if key in (str(LHS)+str(RHS))]
120 self.struct_inclusion.append(min(tmp))
121
122 if str(LHS).partition('[')[0] in self.inputs:
123 self.directions.append(-2)
124 else:
125 self.directions.append(direction)
126
127 self.LHS.append(self.index(LHS,direction))
128 self.RHS.append(self.index(RHS,direction))
129
130 def directional(self,LHS,RHS='',struct=False):
131 for i in range(self.dim):
132 for j, key in enumerate(self.directional_consts):
133 if key in str(LHS) or key in str(RHS):
134 self.LHS.append(self.all_items[key])
135 self.RHS.append(self.directional_consts[key][i])
136 self.struct_inclusion.append(-1)
137 self.directions.append(-1)
138 self.single(LHS,RHS,i+1,struct)
139
140 def index(self,expr_in,direction=-1):
141 if expr_in == '':
142 return ''
143
144 expr = ''
145 word = ''
146 wait = False
147 for i,char in enumerate(str(expr_in)):
148 if char == ']':
149 wait = False
150
151 if direction >= 0 and word in self.directional_items and not (str(expr_in)+"1")[i+1].isalpha():
152 thing = ['_patch','_x','_y','_z']
153 expr += thing[direction]
154 word += thing[direction]
155
156 if char == '[':
157 wait = True
158 expr += char
159
160 for j,index in enumerate(self.indexes):
161 if word in self.item_struct:
162 if self.item_struct[word] == 0 and str(index) == 'var':
163 continue
164
165 if j != 0:
166 expr += ','
167 expr += str(index)
168
169 if j == direction and str(expr_in)[i+1] != '0':
170 tmp = str(expr_in)[i+1]
171 if tmp == '-':
172 expr += tmp
173 i += 1
174 tmp = str(expr_in)[i+1]
175 else:
176 expr += '+'
177
178 while tmp.isnumeric():
179 expr += tmp
180 i += 1
181 tmp = str(expr_in)[i+1]
182 elif word == self.items[1] and str(index) != 'var':
183 expr += '-1'
184 elif not wait:
185 expr += char
186
187 if char.isalpha() or char == '_':
188 word += char
189 else:
190 word = ''
191
192 return sympify(expr,locals=self.all_items)
193
194
195
item(self, expr, struct=True, in_type="double*", parent=None)
Definition frontend.py:76
directional_item(self, expr, struct=True)
Definition frontend.py:87
directional(self, LHS, RHS='', struct=False)
Definition frontend.py:130
directional_const(self, expr, vals)
Definition frontend.py:69
function(self, expr, parent=None, parameter_types=[], return_type=none)
Definition frontend.py:99
const(self, expr, in_type="double", parent=None, define=None)
Definition frontend.py:57
index(self, expr_in, direction=-1)
Definition frontend.py:140
single(self, LHS, RHS='', direction=-1, struct=False)
Definition frontend.py:109
__init__(self, dim, patch_size, halo_size, n_real, n_aux, n_patches=1)
Definition frontend.py:15
viable(dim, patch_size, halo_size)
Definition frontend.py:5