2from sympy.codegen.ast
import integer, real, none
3from exahype
import sympy
5def viable(dim,patch_size,halo_size):
15 def __init__(self,dim,patch_size,halo_size,n_real,n_aux,n_patches=1):
16 if not viable(dim,patch_size,halo_size):
17 raise Exception(
'check viability of inputs')
25 self.
indexes = [index
for index
in symbols(
'patch i j', cls=Idx)]
27 self.
indexes.append(symbols(
'k', cls=Idx))
28 self.
indexes.append(symbols(
'var', cls=Idx))
42 default_range = halo_range
44 self.
all_items = {
'i':Idx(
'i',default_range),
'j':Idx(
'j',default_range),
'k':Idx(
'k',default_range),
'patch':Idx(
'patch',(0,self.
n_patches)),
'var':Idx(
'var',(0,n_real+n_aux))}
51 self.
const(
'dim',define=f
'int dim = {dim};')
52 self.
const(
'patch_size',define=f
'int patch_size = {patch_size};')
53 self.
const(
'halo_size',define=f
'int halo_size = {halo_size};')
54 self.
const(
'n_real',define=f
'int n_real = {n_real};')
55 self.
const(
'n_aux',define=f
'int n_aux = {n_aux};')
57 def const(self,expr,in_type="double",parent=None,define=None):
60 self.
parents[expr] = str(parent)
67 return symbols(expr, real=
True)
70 if len(vals) != self.
dim:
71 raise Exception(
"directional constant must have values for each direction")
73 self.
all_items[expr] = symbols(expr, real=
True)
74 return symbols(expr, real=
True)
76 def item(self,expr,struct=True,in_type="double*",parent=None):
77 self.
items.append(expr)
78 self.
all_items[expr] = IndexedBase(expr, real=
True)
79 if len(self.
items) == 1:
84 self.
parents[expr] = str(parent)
85 return IndexedBase(expr, real=
True)
91 extra = [
'_x',
'_y',
'_z']
92 for i
in range(self.
dim):
94 tmp = expr + direction
95 self.
all_items[tmp] = IndexedBase(tmp, real=
True)
97 return IndexedBase(expr, real=
True)
99 def function(self, expr, parent=None, parameter_types = [], return_type = none, ):
101 self.
parents[expr] = str(parent)
103 func = sympy.TypedFunction(expr)
104 func.returnType(return_type)
105 func.parameterTypes(parameter_types)
109 def single(self,LHS,RHS='',direction=-1,struct=False):
114 elif str(LHS).partition(
'[')[0]
in self.
inputs:
119 tmp = [val
for key,val
in self.
item_struct.
items()
if key
in (str(LHS)+str(RHS))]
122 if str(LHS).partition(
'[')[0]
in self.
inputs:
127 self.
LHS.append(self.
index(LHS,direction))
128 self.
RHS.append(self.
index(RHS,direction))
131 for i
in range(self.
dim):
133 if key
in str(LHS)
or key
in str(RHS):
138 self.
single(LHS,RHS,i+1,struct)
140 def index(self,expr_in,direction=-1):
147 for i,char
in enumerate(str(expr_in)):
151 if direction >= 0
and word
in self.
directional_items and not (str(expr_in)+
"1")[i+1].isalpha():
152 thing = [
'_patch',
'_x',
'_y',
'_z']
153 expr += thing[direction]
154 word += thing[direction]
160 for j,index
in enumerate(self.
indexes):
162 if self.
item_struct[word] == 0
and str(index) ==
'var':
169 if j == direction
and str(expr_in)[i+1] !=
'0':
170 tmp = str(expr_in)[i+1]
174 tmp = str(expr_in)[i+1]
178 while tmp.isnumeric():
181 tmp = str(expr_in)[i+1]
182 elif word == self.
items[1]
and str(index) !=
'var':
187 if char.isalpha()
or char ==
'_':
192 return sympify(expr,locals=self.
all_items)
item(self, expr, struct=True, in_type="double*", parent=None)
directional_item(self, expr, struct=True)
directional(self, LHS, RHS='', struct=False)
directional_const(self, expr, vals)
function(self, expr, parent=None, parameter_types=[], return_type=none)
const(self, expr, in_type="double", parent=None, define=None)
index(self, expr_in, direction=-1)
single(self, LHS, RHS='', direction=-1, struct=False)
__init__(self, dim, patch_size, halo_size, n_real, n_aux, n_patches=1)
viable(dim, patch_size, halo_size)