// Source: git.proxmox.com, ceph.git — ceph/src/boost/libs/yap/example/autodiff_library/autodiff.cpp
//============================================================================
// Copyright   : Your copyright notice
// Description : AutoDiff expression-tree API — node construction, function
//               evaluation, reverse-mode gradient/Hessian evaluation and
//               non-zero (sparsity) counting.
//============================================================================
12 #include <boost/foreach.hpp>
16 #include "BinaryOPNode.h"
17 #include "UaryOPNode.h"
// Global variable count. hess_forward() asserts that its nvar argument
// equals this value.
// NOTE(review): nothing in this view writes it after initialisation —
// presumably variable-node creation increments it; confirm.
26 unsigned int num_var
= 0;
// Forward-mode Hessian evaluation of the expression tree rooted at root.
//   root     - expression tree to differentiate.
//   nvar     - number of variables; must equal the global num_var (asserted).
//   hess_mat - output buffer passed straight through to Node::hess_forward;
//              its layout is defined by that method (not visible here).
// NOTE(review): this copy was extracted from a web page and has lost the
// brace lines of this function (original lines 29 and 33).
28 void hess_forward(Node
* root
, unsigned int nvar
, double** hess_mat
)
30 assert(nvar
== num_var
);
// (nvar+3)*nvar/2 == nvar + nvar*(nvar+1)/2: nvar entries plus one triangle
// (diagonal included) of an nvar x nvar matrix. The product is always even, so
// the division is exact. Presumably gradient + packed Hessian triangle; confirm
// against Node::hess_forward.
31 unsigned int len
= (nvar
+3)*nvar
/2;
32 root
->hess_forward(len
,hess_mat
);
// Creates a parameter node (PNode) holding value.
// The node is allocated with new and nothing here frees it.
// NOTE(review): ownership presumably passes to the caller / the tree the node
// is attached to — confirm who deletes it.
38 PNode
* create_param_node(double value
){
39 return new PNode(value
);
// Creates a variable node (VNode), presumably initialised with value v.
// NOTE(review): the body of this function (original lines 42-44) is missing
// from this copy of the file; restore it from upstream before building.
41 VNode
* create_var_node(double v
)
// Creates a binary operator node for opcode code with operands left and
// right, by delegating to BinaryOPNode::createBinaryOpNode.
45 OPNode
* create_binary_op_node(OPCODE code
, Node
* left
, Node
* right
)
47 return BinaryOPNode::createBinaryOpNode(code
,left
,right
);
// Creates a unary operator node for opcode code applied to left, by
// delegating to UaryOPNode::createUnaryOpNode ("uary" is this library's
// spelling of "unary").
49 OPNode
* create_uary_op_node(OPCODE code
, Node
* left
)
51 return UaryOPNode::createUnaryOpNode(code
,left
);
// Evaluates the expression tree rooted at root.
// Preconditions (asserted): the SD and SV containers are both empty
// (presumably the Stack::diff / Stack::vals globals from autodiff_setup).
// root->eval_function() must leave exactly one entry on SV (asserted); it is
// popped into val, leaving SV empty again.
// NOTE(review): the statements after original line 59 (presumably
// `return val;` and the closing brace) are missing from this copy.
53 double eval_function(Node
* root
)
55 assert(SD
->size()==0);
56 assert(SV
->size()==0);
57 root
->eval_function();
58 assert(SV
->size()==1);
59 double val
= SV
->pop_back();
// Reverse-mode gradient of the expression tree rooted at root.
//   vnodes - variable nodes; each is asserted to be VNode_Type.
//   grad   - output; the adj of each vnodes[k] is appended to it, in order.
// Every VNode's adj is set to NaN_Double before the sweeps and reset to
// NaN_Double after it is read, so the variables are left ready for the next
// call. val holds the single entry popped from SV.
// NOTE(review): the final assert requires grad.size() == vnodes.size(), so grad
// must be empty on entry unless a line missing from this copy (original line
// 65) clears it. The return statement (after original line 88) is also missing.
// NOTE(review): a vnodes entry whose adj the sweep never writes is reported as
// NaN_Double — presumably a variable not in this tree; confirm callers expect it.
63 double grad_reverse(Node
* root
,vector
<Node
*>& vnodes
, vector
<double>& grad
)
// Mark every variable's adjoint as "not yet computed".
66 BOOST_FOREACH(Node
* node
, vnodes
)
68 assert(node
->getType()==VNode_Type
);
69 static_cast<VNode
*>(node
)->adj
= NaN_Double
;
// grad_reverse_0 must leave exactly one entry on SV (asserted); the adjoint
// pass grad_reverse_1 must leave SD empty (asserted).
72 assert(SD
->size()==0);
73 root
->grad_reverse_0();
74 assert(SV
->size()==1);
75 root
->grad_reverse_1_init_adj();
76 root
->grad_reverse_1();
77 assert(SD
->size()==0);
78 double val
= SV
->pop_back();
79 assert(SV
->size()==0);
// Harvest each variable's adjoint, then reset it for the next call.
81 BOOST_FOREACH(Node
* node
, vnodes
)
83 assert(node
->getType()==VNode_Type
);
84 grad
.push_back(static_cast<VNode
*>(node
)->adj
);
85 static_cast<VNode
*>(node
)->adj
= NaN_Double
;
87 assert(grad
.size()==vnodes
.size());
88 //every node in vnodes is a VNode whose adj is back to NaN_Double -- adjoints for the tree rooted at root are reset
// Reverse-mode gradient of the expression tree rooted at root, written into
// rgrad (a col_compress_matrix_row — presumably one row of a compressed sparse
// matrix) rather than a vector<double>.
//   vnodes - variable nodes; each is asserted to be VNode_Type.
//   rgrad  - output row, presumably indexed by position in vnodes (the store
//            is on a line missing from this copy).
// Every VNode's adj is set to NaN_Double before the sweeps and reset to
// NaN_Double after it is read.
// NOTE(review): several original lines are missing from this copy. They include
// the declaration of the counter i (line 107), the lines that use diff and
// advance i (112-113, 115-117) and the return statement. The visible code alone
// does not compile; restore from upstream.
92 double grad_reverse(Node
* root
, vector
<Node
*>& vnodes
, col_compress_matrix_row
& rgrad
)
// Mark every variable's adjoint as "not yet computed".
94 BOOST_FOREACH(Node
* node
, vnodes
)
96 assert(node
->getType()==VNode_Type
);
97 static_cast<VNode
*>(node
)->adj
= NaN_Double
;
// grad_reverse_0 must leave exactly one entry on SV (asserted); the adjoint
// pass grad_reverse_1 must leave SD empty (asserted).
99 assert(SD
->size()==0);
100 root
->grad_reverse_0();
101 assert(SV
->size()==1);
102 root
->grad_reverse_1_init_adj();
103 root
->grad_reverse_1();
104 assert(SD
->size()==0);
105 double val
= SV
->pop_back();
106 assert(SV
->size()==0);
// Read each variable's adjoint into diff, then reset adj for the next call.
108 BOOST_FOREACH(Node
* node
, vnodes
)
110 assert((node
)->getType()==VNode_Type
);
111 double diff
= static_cast<VNode
*>(node
)->adj
;
114 static_cast<VNode
*>(node
)->adj
= NaN_Double
;
118 //every node in vnodes is a VNode whose adj is back to NaN_Double -- adjoints for the tree rooted at root are reset
119 assert(i
==vnodes
.size());
// Reverse-mode Hessian pass over the expression tree rooted at root.
// For each element of vnodes (each asserted to be VNode_Type), appends
// TT->get(index - 1) to dhess after the reverse sweeps. Presumably these are
// the Hessian entries for vnodes; confirm.
// Preconditions (asserted): TT->index == 0, II->index == 0 and
// root->n_in_arcs == 0. After the sweeps, root->n_in_arcs and II->index are
// asserted back to 0. hess_reverse_1_clear_index() is called at the end.
// val starts as NaN_Double and is presumably filled in by hess_reverse_get_x.
// NOTE(review): the col_compress_matrix_col overload carries the original
// comment "node->index = 0 means this VNode is not in the tree". This loop,
// however, reads TT->get(index - 1) with no visible guard. For such a node the
// argument is -1, or wraps if index is unsigned. Confirm a guard exists.
// NOTE(review): several original lines (e.g. 124-128, 163-166 and the return
// statement) are missing from this copy; restore from upstream.
123 double hess_reverse(Node
* root
,vector
<Node
*>& vnodes
,vector
<double>& dhess
)
129 assert(TT
->index
==0);
130 assert(II
->index
==0);
133 // for(vector<Node*>::iterator it=nodes.begin();it!=nodes.end();it++)
135 // assert((*it)->getType()==VNode_Type);
137 // } //this work complete in hess-reverse_0_init_index
// hess_reverse_0_init_n_in_arcs must raise root->n_in_arcs from 0 to 1.
139 assert(root
->n_in_arcs
== 0);
140 root
->hess_reverse_0_init_n_in_arcs();
141 assert(root
->n_in_arcs
== 1);
142 root
->hess_reverse_0();
143 double val
= NaN_Double
;
144 root
->hess_reverse_get_x(TT
->index
,val
);
145 // cout<<TT->toString();
147 // cout<<II->toString();
148 // cout<<"======================================= hess_reverse_0"<<endl;
// The reverse pass must consume root's single in-arc (1 -> 0, asserted).
149 root
->hess_reverse_1_init_x_bar(TT
->index
);
150 assert(root
->n_in_arcs
== 1);
151 root
->hess_reverse_1(TT
->index
);
152 assert(root
->n_in_arcs
== 0);
153 assert(II
->index
==0);
154 // cout<<TT->toString();
156 // cout<<II->toString();
157 // cout<<"======================================= hess_reverse_1"<<endl;
// Collect one tape value per variable (see the index == 0 NOTE above).
159 for(vector
<Node
*>::iterator it
=vnodes
.begin();it
!=vnodes
.end();it
++)
161 assert((*it
)->getType()==VNode_Type
);
162 dhess
.push_back(TT
->get((*it
)->index
-1));
167 root
->hess_reverse_1_clear_index();
// Reverse-mode Hessian pass over the expression tree rooted at root,
// accumulating into chess (a col_compress_matrix_col) instead of appending
// to a vector.
// For each variable node, hess = TT->get(index - 1) is added to chess(i)
// (chess(i) = chess(i) + hess). Existing contents of chess are therefore
// accumulated into, not overwritten.
// Preconditions (asserted): TT->index == 0, II->index == 0 and
// root->n_in_arcs == 0. After the sweeps, root->n_in_arcs and II->index are
// asserted back to 0. hess_reverse_1_clear_index() is called at the end.
// NOTE(review): several original lines are missing from this copy and must be
// restored from upstream. They include the declaration and increment of i, the
// check implied by the comment at original line 211 (skip VNodes with
// index == 0) and the return statement.
171 double hess_reverse(Node
* root
,vector
<Node
*>& vnodes
,col_compress_matrix_col
& chess
)
177 assert(TT
->index
==0);
178 assert(II
->index
==0);
180 // for(vector<Node*>::iterator it=nodes.begin();it!=nodes.end();it++)
182 // assert((*it)->getType()==VNode_Type);
184 // } //this work complete in hess-reverse_0_init_index
186 assert(root
->n_in_arcs
== 0);
187 //reset node index and n_in_arcs - for the Tape location
188 root
->hess_reverse_0_init_n_in_arcs();
189 assert(root
->n_in_arcs
== 1);
190 root
->hess_reverse_0();
191 double val
= NaN_Double
;
192 root
->hess_reverse_get_x(TT
->index
,val
);
193 // cout<<TT->toString();
195 // cout<<II->toString();
196 // cout<<"======================================= hess_reverse_0"<<endl;
// The reverse pass must consume root's single in-arc (1 -> 0, asserted).
197 root
->hess_reverse_1_init_x_bar(TT
->index
);
198 assert(root
->n_in_arcs
== 1);
199 root
->hess_reverse_1(TT
->index
);
200 assert(root
->n_in_arcs
== 0);
201 assert(II
->index
==0);
202 // cout<<TT->toString();
204 // cout<<II->toString();
205 // cout<<"======================================= hess_reverse_1"<<endl;
// Accumulate each variable's tape value into its slot of chess.
208 BOOST_FOREACH(Node
* node
, vnodes
)
210 assert(node
->getType() == VNode_Type
);
211 //node->index = 0 means this VNode is not in the tree
214 double hess
= TT
->get(node
->index
-1);
217 chess(i
) = chess(i
) + hess
;
222 assert(i
==vnodes
.size());
223 root
->hess_reverse_1_clear_index();
// Number of non-zero gradient entries of the tree rooted at root: the size of
// the de-duplicated set that collect_vnodes fills, presumably the tree's
// variable nodes.
// collect_vnodes also updates total, which is not otherwise used in the visible
// lines. nzgrad is declared uninitialised but assigned before use.
// NOTE(review): the return statement (presumably `return nzgrad;`) is missing
// from this copy.
229 unsigned int nzGrad(Node
* root
)
231 unsigned int nzgrad
,total
= 0;
232 boost::unordered_set
<Node
*> nodes
;
233 root
->collect_vnodes(nodes
,total
);
234 nzgrad
= nodes
.size();
// Number of non-zero gradient entries in the constraint tree rooted at root whose variable nodes also belong to vSet.
// Counts the nodes collected from the tree rooted at root (presumably its
// variable nodes) that are also members of vSet.
// NOTE(review): several original lines are missing from this copy. They
// include the declaration of n (presumably the current element *it), the
// increment of nzgrad inside the if, and the return statement.
241 unsigned int nzGrad(Node
* root
, boost::unordered_set
<Node
*>& vSet
)
243 unsigned int nzgrad
=0, total
=0;
244 boost::unordered_set
<Node
*> vnodes
;
245 root
->collect_vnodes(vnodes
,total
);
246 //cout<<"nzGrad - vnodes size["<<vnodes.size()<<"] -- total node["<<total<<"]"<<endl;
// Membership test against vSet for each collected node.
247 for(boost::unordered_set
<Node
*>::iterator it
=vnodes
.begin();it
!=vnodes
.end();it
++)
250 if(vSet
.find(n
) != vSet
.end())
// Delegates to root->nonlinearEdges(edges). Presumably this adds the
// nonlinear variable-pair edges of the tree rooted at root to edges; see
// Node::nonlinearEdges.
258 void nonlinearEdges(Node
* root
, EdgeSet
& edges
)
260 root
->nonlinearEdges(edges
);
// Hessian non-zero count restricted to edges of eSet that connect a node in
// set1 with a node in set2, in either orientation.
// WARNING: this modifies the caller's eSet — edges are erased from eSet.edges
// during the scan (the erase below). Pass a copy if the original is still needed.
// The result is 2 * eSet.size() - eSet.numSelfEdges(). Each remaining edge
// counts for both (i,j) and (j,i), self-edges once — assuming size() includes
// self-edges; confirm.
// NOTE(review): several original lines are missing from this copy. They
// include the extraction of the endpoints a and b, the `||` joining the two
// conditions (original line 272), the branch bodies and the return statement.
// So exactly which edges are erased cannot be confirmed here — presumably those
// that do not connect set1 and set2.
263 unsigned int nzHess(EdgeSet
& eSet
,boost::unordered_set
<Node
*>& set1
, boost::unordered_set
<Node
*>& set2
)
// No increment clause: i must be advanced inside the loop body
// (list::erase returns the iterator following the erased element).
265 list
<Edge
>::iterator i
= eSet
.edges
.begin();
266 for(;i
!=eSet
.edges
.end();)
271 if((set1
.find(a
)!=set1
.end() && set2
.find(b
)!=set2
.end())
273 (set1
.find(b
)!=set1
.end() && set2
.find(a
)!=set2
.end()))
275 //e is connected between set1 and set2
280 i
= eSet
.edges
.erase(i
);
283 unsigned int diag
=eSet
.numSelfEdges();
284 unsigned int nzHess
= (eSet
.size())*2 - diag
;
// Hessian non-zero count implied by edges: 2 * edges.size() - numSelfEdges().
// Each edge counts for both (i,j) and (j,i), while self-edges (diagonal) count
// once — assuming edges.size() includes self-edges; confirm.
// NOTE(review): the return statement is missing from this copy.
288 unsigned int nzHess(EdgeSet
& edges
)
290 unsigned int diag
=edges
.numSelfEdges();
291 unsigned int nzHess
= (edges
.size())*2 - diag
;
// Total node count of the tree rooted at root, as reported in total by
// collect_vnodes (labelled "total node" in nzGrad's debug output). The node
// set that collect_vnodes also fills is discarded.
// NOTE(review): the return statement (presumably `return total;`) is missing
// from this copy.
295 unsigned int numTotalNodes(Node
* root
)
297 unsigned int total
= 0;
298 boost::unordered_set
<Node
*> nodes
;
299 root
->collect_vnodes(nodes
,total
);
// Builds a textual dump of the tree rooted at root: a "visiting tree == "
// header followed by the output of root->inorder_visit.
// NOTE(review): the declarations of oss (presumably an ostringstream) and
// level, and the return statement (presumably `return oss.str();`), are
// missing from this copy.
303 string
tree_expr(Node
* root
)
306 oss
<<"visiting tree == "<<endl
;
308 root
->inorder_visit(level
,oss
);
// Writes a "visiting tree == " header and an in-order visit of the tree rooted
// at root to cout.
// NOTE(review): the declaration of level is missing from this copy.
312 void print_tree(Node
* root
)
314 cout
<<"visiting tree == "<<endl
;
316 root
->inorder_visit(level
,cout
);
// Allocates the global working storage: two Stacks (Stack::diff, Stack::vals)
// and two Tapes (Tape<unsigned int>::indexTape, Tape<double>::valueTape).
// These are presumably what the SD/SV/II/TT names used by the evaluation
// functions refer to, so this must presumably run before them; pair it with
// autodiff_cleanup().
// NOTE(review): the pointers are overwritten unconditionally, so calling this
// twice without autodiff_cleanup() leaks the previous objects.
319 void autodiff_setup()
321 Stack::diff
= new Stack();
322 Stack::vals
= new Stack();
323 Tape
<unsigned int>::indexTape
= new Tape
<unsigned int>();
324 Tape
<double>::valueTape
= new Tape
<double>();
// Frees the global tapes allocated by autodiff_setup().
// NOTE(review): only the two Tape deletes are visible. Original lines 328-330
// (presumably the opening brace and the deletes of Stack::diff / Stack::vals)
// are missing from this copy. The visible lines do not reset the pointers to
// null, so they dangle until autodiff_setup() runs again.
327 void autodiff_cleanup()
331 delete Tape
<unsigned int>::indexTape
;
332 delete Tape
<double>::valueTape
;
335 } //AutoDiff namespace end