// affine.cc
00002 #include <set>
00003 #include <vector>
00004
00005 #include "Symbol/Symbol.h"
00006 #include "Expression/Expression.h"
00007 #include "Statement/DoStmt.h"
00008 #include "ProgramUnit.h"
00009 #include "ip_ssa/trans_util.h"
00010 #include "eg_utils.h"
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 bool is_affine(Expression& f,
00023 set<Symbol*>& args,
00024 vector<Symbol*>& vargs,
00025 vector<Expression*>& vcoeffs){
00026
00027
00028
00029
00030 List<Symbol> test_syms;
00031 vector<pair<Symbol*, Symbol*> > xy;
00032 for(set<Symbol*>::iterator sit=args.begin(); sit!=args.end(); ++sit){
00033 Symbol* y=(*sit)->clone();
00034 String newname=y->name_ref();
00035 newname+="_AT__";
00036 y->name(newname);
00037 xy.push_back(make_pair(*sit, y));
00038 test_syms.ins_last(y);
00039 }
00040
00041
00042 Expression* fx=f.clone();
00043 Expression* fy=f.clone();
00044 Expression* fx_plus_y=f.clone();
00045 Expression* f0=f.clone();
00046 Expression* zero=constant(0);
00047 for(vector<pair<Symbol*, Symbol*> >::iterator vpit=xy.begin();
00048 vpit!=xy.end(); ++vpit){
00049 substitute_var(*fy, *(*vpit).first, *(*vpit).second);
00050 Expression* x_plus_y=
00051 simplify(add(id(*(*vpit).first), id(*(*vpit).second)));
00052 fx_plus_y=simplify(substitute_var(fx_plus_y, *(*vpit).first, *x_plus_y));
00053 delete x_plus_y;
00054 f0=simplify(substitute_var(f0, *(*vpit).first, *zero));
00055 }
00056 delete zero;
00057
00058
00059
00060
00061
00062
00063
00064 Expression* left=simplify(fx_plus_y);
00065 Expression* right=simplify(add(add(fx, fy), mul(constant(-1), f0->clone())));
00066 if (left->op()==OMEGA_OP || right->op()==OMEGA_OP){
00067 delete left;
00068 delete right;
00069 return false;
00070 }
00071
00072
00073 Expression* must_be_zero=simplify(sub(left, right));
00074
00075
00076
00077
00078 bool toret=false;
00079 if (is_integer_zero(*must_be_zero)){
00080 toret=true;
00081
00082 for(set<Symbol*>::iterator sit=args.begin();
00083 sit!=args.end(); ++sit){
00084 vargs.push_back(*sit);
00085 vcoeffs.push_back(f.clone());
00086 }
00087 Expression* zero=constant(0);
00088 Expression* one=constant(1);
00089 for(vector<Symbol*>::iterator sit=vargs.begin();
00090 sit!=vargs.end(); ++sit){
00091 for(vector<Expression*>::iterator cit=vcoeffs.begin();
00092 cit!=vcoeffs.end(); ++cit){
00093 if (sit-vargs.begin() != cit-vcoeffs.begin()){
00094 *cit=substitute_var(*cit, **sit, *zero);
00095 } else {
00096 *cit=substitute_var(*cit, **sit, *one);
00097 }
00098 }
00099 }
00100 f0=simplify(f0);
00101 for(vector<Expression*>::iterator cit=vcoeffs.begin();
00102 cit!=vcoeffs.end(); ++cit){
00103 *cit=sub(*cit, f0->clone());
00104 *cit=simplify(*cit);
00105 }
00106 vcoeffs.push_back(f0);
00107 delete zero;
00108 delete one;
00109 } else {
00110 delete f0;
00111 }
00112 delete must_be_zero;
00113
00114
00115 return toret;
00116 }
00117
00118
00119
00120
00121
00122
00123
00124 bool check_scalars_affine(List<Expression>& elist,
00125 ProgramUnit& pgm, DoStmt& loop,
00126 vector<vector<Symbol*> >& vargs,
00127 vector<vector<Expression*> >& vcoeffs){
00128
00129
00130 for(Iterator<Expression> eit1=elist; eit1.valid(); ++eit1){
00131 Expression& expr=eit1.current();
00132 RefList<Symbol> syms;
00133 set<Symbol*> loop_variants;
00134 referred_symbols(expr, syms);
00135 vargs.push_back(vector<Symbol*>());
00136 vcoeffs.push_back(vector<Expression*>());
00137 for(Iterator<Symbol> sit1=syms; sit1.valid(); ++sit1){
00138 if (sit1.current().intrinsic()){
00139 continue;
00140 }
00141 if (sit1.current().type().data_type()!=INTEGER_TYPE ||
00142 !sit1.current().is_scalar()){
00143 return false;
00144 }
00145 if (loop_variant(sit1.current(), &loop, pgm)){
00146 loop_variants.insert(&sit1.current());
00147 }
00148 }
00149 if (!is_affine(eit1.current(), loop_variants,
00150 vargs[vargs.size()-1], vcoeffs[vcoeffs.size()-1])){
00151 return false;
00152 }
00153 }
00154 return true;
00155 }
|