/*
 * Copyright 2016 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
19 // Coalesce locals, in order to reduce the total number of locals. This
20 // is similar to register allocation, however, there is never any
21 // spilling, and there isn't a fixed number of locals.
27 #include <unordered_set>
32 #include "cfg/cfg-traversal.h"
33 #include "wasm-builder.h"
34 #include "support/learning.h"
35 #include "support/permutations.h"
37 #include "support/timing.h"
42 // A set of locals. This is optimized for comparisons,
43 // mergings, and iteration on elements, assuming that there
44 // may be a great many potential elements but actual sets
45 // may be fairly small. Specifically, we use a sorted
47 struct LocalSet
: std::vector
<Index
> {
50 LocalSet
merge(const LocalSet
& other
) const {
52 ret
.resize(size() + other
.size());
53 Index i
= 0, j
= 0, t
= 0;
54 while (i
< size() && j
< other
.size()) {
55 auto left
= (*this)[i
];
56 auto right
= other
[j
];
60 } else if (left
> right
) {
70 ret
[t
++] = (*this)[i
];
73 while (j
< other
.size()) {
81 void insert(Index x
) {
82 auto it
= std::lower_bound(begin(), end(), x
);
83 if (it
== end()) push_back(x
);
85 Index i
= it
- begin();
87 std::move_backward(begin() + i
, begin() + size() - 1, end());
93 auto it
= std::lower_bound(begin(), end(), x
);
94 if (it
!= end() && *it
== x
) {
95 std::move(it
+ 1, end(), it
);
103 auto it
= std::lower_bound(begin(), end(), x
);
104 return it
!= end() && *it
== x
;
107 void verify() const {
108 for (Index i
= 1; i
< size(); i
++) {
109 assert((*this)[i
- 1] < (*this)[i
]);
113 void dump(const char* str
= nullptr) const {
114 std::cout
<< "LocalSet " << (str
? str
: "") << ": ";
115 for (auto x
: *this) std::cout
<< x
<< " ";
120 // a liveness-relevant action
126 Index index
; // the local index read or written
127 Expression
** origin
; // the origin
128 bool effective
; // whether a store is actually effective, i.e., may be read
130 Action(What what
, Index index
, Expression
** origin
) : what(what
), index(index
), origin(origin
), effective(false) {}
132 bool isGet() { return what
== Get
; }
133 bool isSet() { return what
== Set
; }
136 // information about liveness in a basic block
138 LocalSet start
, end
; // live locals at the start and end
139 std::vector
<Action
> actions
; // actions occurring in this block
141 void dump(Function
* func
) {
142 if (actions
.empty()) return;
143 std::cout
<< " actions:\n";
144 for (auto& action
: actions
) {
145 std::cout
<< " " << (action
.isGet() ? "get" : "set") << " " << func
->getLocalName(action
.index
) << "\n";
150 struct CoalesceLocals
: public WalkerPass
<CFGWalker
<CoalesceLocals
, Visitor
<CoalesceLocals
>, Liveness
>> {
151 bool isFunctionParallel() override
{ return true; }
153 Pass
* create() override
{ return new CoalesceLocals
; }
157 // cfg traversal work
159 static void doVisitGetLocal(CoalesceLocals
* self
, Expression
** currp
) {
160 auto* curr
= (*currp
)->cast
<GetLocal
>();
161 // if in unreachable code, ignore
162 if (!self
->currBasicBlock
) {
163 *currp
= Builder(*self
->getModule()).replaceWithIdenticalType(curr
);
166 self
->currBasicBlock
->contents
.actions
.emplace_back(Action::Get
, curr
->index
, currp
);
169 static void doVisitSetLocal(CoalesceLocals
* self
, Expression
** currp
) {
170 auto* curr
= (*currp
)->cast
<SetLocal
>();
171 // if in unreachable code, we don't need the tee (but might need the value, if it has side effects)
172 if (!self
->currBasicBlock
) {
174 *currp
= curr
->value
;
176 *currp
= Builder(*self
->getModule()).makeDrop(curr
->value
);
180 self
->currBasicBlock
->contents
.actions
.emplace_back(Action::Set
, curr
->index
, currp
);
181 // if this is a copy, note it
182 if (auto* get
= self
->getCopy(curr
)) {
183 // add 2 units, so that backedge prioritization can decide ties, but not much more
184 self
->addCopy(curr
->index
, get
->index
);
185 self
->addCopy(curr
->index
, get
->index
);
189 // A simple copy is a set of a get. A more interesting copy
190 // is a set of an if with a value, where one side a get.
191 // That can happen when we create an if value in simplify-locals. TODO: recurse into
192 // nested ifs, and block return values? Those cases are trickier, need to
193 // count to see if worth it.
194 // TODO: an if can have two copies
195 GetLocal
* getCopy(SetLocal
* set
) {
196 if (auto* get
= set
->value
->dynCast
<GetLocal
>()) return get
;
197 if (auto* iff
= set
->value
->dynCast
<If
>()) {
198 if (auto* get
= iff
->ifTrue
->dynCast
<GetLocal
>()) return get
;
200 if (auto* get
= iff
->ifFalse
->dynCast
<GetLocal
>()) return get
;
208 void doWalkFunction(Function
* func
);
210 void increaseBackEdgePriorities();
214 void calculateInterferences();
216 void calculateInterferences(const LocalSet
& locals
);
218 // merge starts of a list of blocks, adding new interferences as necessary. return
219 // whether anything changed vs an old state (which indicates further processing is necessary).
220 bool mergeStartsAndCheckChange(std::vector
<BasicBlock
*>& blocks
, LocalSet
& old
, LocalSet
& ret
);
222 void scanLivenessThroughActions(std::vector
<Action
>& actions
, LocalSet
& live
);
224 void pickIndicesFromOrder(std::vector
<Index
>& order
, std::vector
<Index
>& indices
);
225 void pickIndicesFromOrder(std::vector
<Index
>& order
, std::vector
<Index
>& indices
, Index
& removedCopies
);
227 virtual void pickIndices(std::vector
<Index
>& indices
); // returns a vector of oldIndex => newIndex
229 void applyIndices(std::vector
<Index
>& indices
, Expression
* root
);
231 // interference state
233 std::vector
<bool> interferences
; // canonicalized - accesses should check (low, high)
234 std::unordered_set
<BasicBlock
*> liveBlocks
;
236 void interfere(Index i
, Index j
) {
238 interferences
[std::min(i
, j
) * numLocals
+ std::max(i
, j
)] = 1;
241 void interfereLowHigh(Index low
, Index high
) { // optimized version where you know that low < high
243 interferences
[low
* numLocals
+ high
] = 1;
246 bool interferes(Index i
, Index j
) {
247 return interferences
[std::min(i
, j
) * numLocals
+ std::max(i
, j
)];
252 std::vector
<uint8_t> copies
; // canonicalized - accesses should check (low, high) TODO: use a map for high N, as this tends to be sparse? or don't look at copies at all for big N?
253 std::vector
<Index
> totalCopies
; // total # of copies for each local, with all others
255 void addCopy(Index i
, Index j
) {
256 auto k
= std::min(i
, j
) * numLocals
+ std::max(i
, j
);
257 copies
[k
] = std::min(copies
[k
], uint8_t(254)) + 1;
262 uint8_t getCopies(Index i
, Index j
) {
263 return copies
[std::min(i
, j
) * numLocals
+ std::max(i
, j
)];
267 void CoalesceLocals::doWalkFunction(Function
* func
) {
268 numLocals
= func
->getNumLocals();
269 copies
.resize(numLocals
* numLocals
);
270 std::fill(copies
.begin(), copies
.end(), 0);
271 totalCopies
.resize(numLocals
);
272 std::fill(totalCopies
.begin(), totalCopies
.end(), 0);
273 // collect initial liveness info
274 super::doWalkFunction(func
);
275 // ignore links to dead blocks, so they don't confuse us and we can see their stores are all ineffective
276 liveBlocks
= findLiveBlocks();
277 unlinkDeadBlocks(liveBlocks
);
278 // increase the cost of costly backedges
279 increaseBackEdgePriorities();
283 // flow liveness across blocks
285 static Timer
timer("flow");
293 // use liveness to find interference
294 calculateInterferences();
296 std::vector
<Index
> indices
;
297 pickIndices(indices
);
299 applyIndices(indices
, func
->body
);
302 // A copy on a backedge can be especially costly, forcing us to branch just to do that copy.
303 // Add weight to such copies, so we prioritize getting rid of them.
304 void CoalesceLocals::increaseBackEdgePriorities() {
305 for (auto* loopTop
: loopTops
) {
306 // ignore the first edge, it is the initial entry, we just want backedges
307 auto& in
= loopTop
->in
;
308 for (Index i
= 1; i
< in
.size(); i
++) {
309 auto* arrivingBlock
= in
[i
];
310 if (arrivingBlock
->out
.size() > 1) continue; // we just want unconditional branches to the loop top, true phi fragments
311 for (auto& action
: arrivingBlock
->contents
.actions
) {
312 if (action
.what
== Action::Set
) {
313 auto* set
= (*action
.origin
)->cast
<SetLocal
>();
314 if (auto* get
= getCopy(set
)) {
315 // this is indeed a copy, add to the cost (default cost is 2, so this adds 50%, and can mostly break ties)
316 addCopy(set
->index
, get
->index
);
324 void CoalesceLocals::flowLiveness() {
325 interferences
.resize(numLocals
* numLocals
);
326 std::fill(interferences
.begin(), interferences
.end(), 0);
327 // keep working while stuff is flowing
328 std::unordered_set
<BasicBlock
*> queue
;
329 for (auto& curr
: basicBlocks
) {
330 if (liveBlocks
.count(curr
.get()) == 0) continue; // ignore dead blocks
331 queue
.insert(curr
.get());
332 // do the first scan through the block, starting with nothing live at the end, and updating the liveness at the start
333 scanLivenessThroughActions(curr
->contents
.actions
, curr
->contents
.start
);
335 // at every point in time, we assume we already noted interferences between things already known alive at the end, and scanned back through the block using that
336 while (queue
.size() > 0) {
337 auto iter
= queue
.begin();
341 if (!mergeStartsAndCheckChange(curr
->out
, curr
->contents
.end
, live
)) continue;
343 std::cout
<< "change noticed at end of " << debugIds
[curr
] << " from " << curr
->contents
.end
.size() << " to " << live
.size() << " (out of " << numLocals
<< ")\n";
345 assert(curr
->contents
.end
.size() < live
.size());
346 curr
->contents
.end
= live
;
347 scanLivenessThroughActions(curr
->contents
.actions
, live
);
348 // liveness is now calculated at the start. if something
349 // changed, all predecessor blocks need recomputation
350 if (curr
->contents
.start
== live
) continue;
352 std::cout
<< "change noticed at start of " << debugIds
[curr
] << " from " << curr
->contents
.start
.size() << " to " << live
.size() << ", more work to do\n";
354 assert(curr
->contents
.start
.size() < live
.size());
355 curr
->contents
.start
= live
;
356 for (auto* in
: curr
->in
) {
361 std::hash
<std::vector
<bool>> hasher
;
362 std::cout
<< getFunction()->name
<< ": interference hash: " << hasher(*(std::vector
<bool>*)&interferences
) << "\n";
363 for (Index i
= 0; i
< numLocals
; i
++) {
364 std::cout
<< "int for " << getFunction()->getLocalName(i
) << " [" << i
<< "]: ";
365 for (Index j
= 0; j
< numLocals
; j
++) {
366 if (interferes(i
, j
)) std::cout
<< getFunction()->getLocalName(j
) << " ";
373 // merge starts of a list of blocks. return
374 // whether anything changed vs an old state (which indicates further processing is necessary).
375 bool CoalesceLocals::mergeStartsAndCheckChange(std::vector
<BasicBlock
*>& blocks
, LocalSet
& old
, LocalSet
& ret
) {
376 if (blocks
.size() == 0) return false;
377 ret
= blocks
[0]->contents
.start
;
378 if (blocks
.size() > 1) {
379 // more than one, so we must merge
380 for (Index i
= 1; i
< blocks
.size(); i
++) {
381 ret
= ret
.merge(blocks
[i
]->contents
.start
);
387 void CoalesceLocals::scanLivenessThroughActions(std::vector
<Action
>& actions
, LocalSet
& live
) {
388 // move towards the front
389 for (int i
= int(actions
.size()) - 1; i
>= 0; i
--) {
390 auto& action
= actions
[i
];
391 if (action
.isGet()) {
392 live
.insert(action
.index
);
394 live
.erase(action
.index
);
399 void CoalesceLocals::calculateInterferences() {
400 for (auto& curr
: basicBlocks
) {
401 if (liveBlocks
.count(curr
.get()) == 0) continue; // ignore dead blocks
402 // everything coming in might interfere, as it might come from a different block
403 auto live
= curr
->contents
.end
;
404 calculateInterferences(live
);
405 // scan through the block itself
406 auto& actions
= curr
->contents
.actions
;
407 for (int i
= int(actions
.size()) - 1; i
>= 0; i
--) {
408 auto& action
= actions
[i
];
409 auto index
= action
.index
;
410 if (action
.isGet()) {
411 // new live local, interferes with all the rest
413 for (auto i
: live
) {
417 if (live
.erase(index
)) {
418 action
.effective
= true;
423 // Params have a value on entry, so mark them as live, as variables
424 // live at the entry expect their zero-init value.
425 LocalSet start
= entry
->contents
.start
;
426 auto numParams
= getFunction()->getNumParams();
427 for (Index i
= 0; i
< numParams
; i
++) {
430 calculateInterferences(start
);
433 void CoalesceLocals::calculateInterferences(const LocalSet
& locals
) {
434 Index size
= locals
.size();
435 for (Index i
= 0; i
< size
; i
++) {
436 for (Index j
= i
+ 1; j
< size
; j
++) {
437 interfereLowHigh(locals
[i
], locals
[j
]);
442 // Indices decision making
444 void CoalesceLocals::pickIndicesFromOrder(std::vector
<Index
>& order
, std::vector
<Index
>& indices
) {
446 pickIndicesFromOrder(order
, indices
, removedCopies
);
449 void CoalesceLocals::pickIndicesFromOrder(std::vector
<Index
>& order
, std::vector
<Index
>& indices
, Index
& removedCopies
) {
450 // mostly-simple greedy coloring
452 std::cerr
<< "\npickIndicesFromOrder on " << getFunction()->name
<< '\n';
453 std::cerr
<< getFunction()->body
<< '\n';
454 std::cerr
<< "order:\n";
455 for (auto i
: order
) std::cerr
<< i
<< ' ';
457 std::cerr
<< "interferences:\n";
458 for (Index i
= 0; i
< numLocals
; i
++) {
459 for (Index j
= 0; j
< i
+ 1; j
++) {
462 for (Index j
= i
+ 1; j
< numLocals
; j
++) {
463 std::cerr
<< int(interferes(i
, j
)) << ' ';
465 std::cerr
<< " : $" << i
<< '\n';
467 std::cerr
<< "copies:\n";
468 for (Index i
= 0; i
< numLocals
; i
++) {
469 for (Index j
= 0; j
< i
+ 1; j
++) {
472 for (Index j
= i
+ 1; j
< numLocals
; j
++) {
473 std::cerr
<< int(getCopies(i
, j
)) << ' ';
475 std::cerr
<< " : $" << i
<< '\n';
477 std::cerr
<< "total copies:\n";
478 for (Index i
= 0; i
< numLocals
; i
++) {
479 std::cerr
<< " $" << i
<< ": " << totalCopies
[i
] << '\n';
482 // TODO: take into account distribution (99-1 is better than 50-50 with two registers, for gzip)
483 std::vector
<WasmType
> types
;
484 std::vector
<bool> newInterferences
; // new index * numLocals => list of all interferences of locals merged to it
485 std::vector
<uint8_t> newCopies
; // new index * numLocals => list of all copies of locals merged to it
486 indices
.resize(numLocals
);
487 types
.resize(numLocals
);
488 newInterferences
.resize(numLocals
* numLocals
);
489 std::fill(newInterferences
.begin(), newInterferences
.end(), 0);
490 auto numParams
= getFunction()->getNumParams();
491 newCopies
.resize(numParams
* numLocals
); // start with enough room for the params
492 std::fill(newCopies
.begin(), newCopies
.end(), 0);
495 // we can't reorder parameters, they are fixed in order, and cannot coalesce
497 for (; i
< numParams
; i
++) {
498 assert(order
[i
] == i
); // order must leave the params in place
500 types
[i
] = getFunction()->getLocalType(i
);
501 for (Index j
= numParams
; j
< numLocals
; j
++) {
502 newInterferences
[numLocals
* i
+ j
] = interferes(i
, j
);
503 newCopies
[numLocals
* i
+ j
] = getCopies(i
, j
);
507 for (; i
< numLocals
; i
++) {
508 Index actual
= order
[i
];
510 uint8_t foundCopies
= -1;
511 for (Index j
= 0; j
< nextFree
; j
++) {
512 if (!newInterferences
[j
* numLocals
+ actual
] && getFunction()->getLocalType(actual
) == types
[j
]) {
513 // this does not interfere, so it might be what we want. but pick the one eliminating the most copies
514 // (we could stop looking forward when there are no more items that have copies anyhow, but it doesn't seem to help)
515 auto currCopies
= newCopies
[j
* numLocals
+ actual
];
516 if (found
== Index(-1) || currCopies
> foundCopies
) {
517 indices
[actual
] = found
= j
;
518 foundCopies
= currCopies
;
522 if (found
== Index(-1)) {
523 indices
[actual
] = found
= nextFree
;
524 types
[found
] = getFunction()->getLocalType(actual
);
526 removedCopies
+= getCopies(found
, actual
);
527 newCopies
.resize(nextFree
* numLocals
);
529 removedCopies
+= foundCopies
;
532 std::cerr
<< "set local $" << actual
<< " to $" << found
<< '\n';
534 // merge new interferences and copies for the new index
535 for (Index k
= i
+ 1; k
< numLocals
; k
++) {
536 auto j
= order
[k
]; // go in the order, we only need to update for those we will see later
537 newInterferences
[found
* numLocals
+ j
] = newInterferences
[found
* numLocals
+ j
] | interferes(actual
, j
);
538 newCopies
[found
* numLocals
+ j
] += getCopies(actual
, j
);
543 // given a baseline order, adjust it based on an important order of priorities (higher values
544 // are higher priority). The priorities take precedence, unless they are equal and then
545 // the original order should be kept.
546 std::vector
<Index
> adjustOrderByPriorities(std::vector
<Index
>& baseline
, std::vector
<Index
>& priorities
) {
547 std::vector
<Index
> ret
= baseline
;
548 std::vector
<Index
> reversed
= makeReversed(baseline
);
549 std::sort(ret
.begin(), ret
.end(), [&priorities
, &reversed
](Index x
, Index y
) {
550 return priorities
[x
] > priorities
[y
] || (priorities
[x
] == priorities
[y
] && reversed
[x
] < reversed
[y
]);
555 void CoalesceLocals::pickIndices(std::vector
<Index
>& indices
) {
556 if (numLocals
== 0) return;
557 if (numLocals
== 1) {
558 indices
.push_back(0);
561 if (getFunction()->getNumVars() <= 1) {
562 // nothing to think about here, since we can't reorder params
563 indices
= makeIdentity(numLocals
);
566 // take into account total copies. but we must keep params in place, so give them max priority
567 auto adjustedTotalCopies
= totalCopies
;
568 auto numParams
= getFunction()->getNumParams();
569 for (Index i
= 0; i
< numParams
; i
++) {
570 adjustedTotalCopies
[i
] = std::numeric_limits
<Index
>::max();
572 // first try the natural order. this is less arbitrary than it seems, as the program
573 // may have a natural order of locals inherent in it.
574 auto order
= makeIdentity(numLocals
);
575 order
= adjustOrderByPriorities(order
, adjustedTotalCopies
);
577 pickIndicesFromOrder(order
, indices
, removedCopies
);
578 auto maxIndex
= *std::max_element(indices
.begin(), indices
.end());
579 // next try the reverse order. this both gives us another chance at something good,
580 // and also the very naturalness of the simple order may be quite suboptimal
582 for (Index i
= numParams
; i
< numLocals
; i
++) {
583 order
[i
] = numParams
+ numLocals
- 1 - i
;
585 order
= adjustOrderByPriorities(order
, adjustedTotalCopies
);
586 std::vector
<Index
> reverseIndices
;
587 Index reverseRemovedCopies
;
588 pickIndicesFromOrder(order
, reverseIndices
, reverseRemovedCopies
);
589 auto reverseMaxIndex
= *std::max_element(reverseIndices
.begin(), reverseIndices
.end());
590 // prefer to remove copies foremost, as it matters more for code size (minus gzip), and
591 // improves throughput.
592 if (reverseRemovedCopies
> removedCopies
|| (reverseRemovedCopies
== removedCopies
&& reverseMaxIndex
< maxIndex
)) {
593 indices
.swap(reverseIndices
);
597 // Remove a copy from a set of an if, where one if arm is a get of the same set
598 static void removeIfCopy(Expression
** origin
, SetLocal
* set
, If
* iff
, Expression
*& copy
, Expression
*& other
, Module
* module
) {
599 // replace the origin with the if, and sink the set into the other non-copying arm
600 bool tee
= set
->isTee();
605 // if this is not a tee, then we can get rid of the copy in that arm
607 // we don't need the copy at all
610 Builder(*module
).flip(iff
);
616 void CoalesceLocals::applyIndices(std::vector
<Index
>& indices
, Expression
* root
) {
617 assert(indices
.size() == numLocals
);
618 for (auto& curr
: basicBlocks
) {
619 auto& actions
= curr
->contents
.actions
;
620 for (auto& action
: actions
) {
621 if (action
.isGet()) {
622 auto* get
= (*action
.origin
)->cast
<GetLocal
>();
623 get
->index
= indices
[get
->index
];
625 auto* set
= (*action
.origin
)->cast
<SetLocal
>();
626 set
->index
= indices
[set
->index
];
627 // in addition, we can optimize out redundant copies and ineffective sets
629 if ((get
= set
->value
->dynCast
<GetLocal
>()) && get
->index
== set
->index
) {
631 *action
.origin
= get
;
633 ExpressionManipulator::nop(set
);
637 // remove ineffective actions
638 if (!action
.effective
) {
639 *action
.origin
= set
->value
; // value may have no side effects, further optimizations can eliminate it
641 // we need to drop it
642 Drop
* drop
= ExpressionManipulator::convert
<SetLocal
, Drop
>(set
);
643 drop
->value
= *action
.origin
;
644 *action
.origin
= drop
;
648 if (auto* iff
= set
->value
->dynCast
<If
>()) {
649 if (auto* get
= iff
->ifTrue
->dynCast
<GetLocal
>()) {
650 if (get
->index
== set
->index
) {
651 removeIfCopy(action
.origin
, set
, iff
, iff
->ifTrue
, iff
->ifFalse
, getModule());
655 if (auto* get
= iff
->ifFalse
->dynCast
<GetLocal
>()) {
656 if (get
->index
== set
->index
) {
657 removeIfCopy(action
.origin
, set
, iff
, iff
->ifFalse
, iff
->ifTrue
, getModule());
666 auto numParams
= getFunction()->getNumParams();
667 Index newNumLocals
= 0;
668 for (auto index
: indices
) {
669 newNumLocals
= std::max(newNumLocals
, index
+ 1);
671 auto oldVars
= getFunction()->vars
;
672 getFunction()->vars
.resize(newNumLocals
- numParams
);
673 for (Index index
= numParams
; index
< numLocals
; index
++) {
674 Index newIndex
= indices
[index
];
675 if (newIndex
>= numParams
) {
676 getFunction()->vars
[newIndex
- numParams
] = oldVars
[index
- numParams
];
680 getFunction()->localNames
.clear();
681 getFunction()->localIndices
.clear();
684 struct CoalesceLocalsWithLearning
: public CoalesceLocals
{
685 virtual Pass
* create() override
{ return new CoalesceLocalsWithLearning
; }
687 virtual void pickIndices(std::vector
<Index
>& indices
) override
;
690 void CoalesceLocalsWithLearning::pickIndices(std::vector
<Index
>& indices
) {
691 if (getFunction()->getNumVars() <= 1) {
692 // nothing to think about here
693 CoalesceLocals::pickIndices(indices
);
697 struct Order
: public std::vector
<Index
> {
698 void setFitness(double f
) { fitness
= f
; }
699 double getFitness() { return fitness
; }
700 void dump(std::string text
) {
701 std::cout
<< text
+ ": ( ";
702 for (Index i
= 0; i
< size(); i
++) std::cout
<< (*this)[i
] << " ";
704 std::cout
<< "of quality: " << getFitness() << "\n";
711 Generator(CoalesceLocalsWithLearning
* parent
) : parent(parent
), noise(42) {}
713 void calculateFitness(Order
* order
) {
715 std::vector
<Index
> indices
; // the phenotype
717 parent
->pickIndicesFromOrder(*order
, indices
, removedCopies
);
718 auto maxIndex
= *std::max_element(indices
.begin(), indices
.end());
719 assert(maxIndex
<= parent
->numLocals
);
720 // main part of fitness is the number of locals
721 double fitness
= parent
->numLocals
- maxIndex
; // higher fitness is better
722 // secondarily, it is nice to not reorder locals unnecessarily
723 double fragment
= 1.0 / (2.0 * parent
->numLocals
);
724 for (Index i
= 0; i
< parent
->numLocals
; i
++) {
725 if ((*order
)[i
] == i
) fitness
+= fragment
; // boost for each that wasn't moved
727 fitness
= (100 * fitness
) + removedCopies
; // removing copies is a secondary concern
728 order
->setFitness(fitness
);
731 Order
* makeRandom() {
732 auto* ret
= new Order
;
733 ret
->resize(parent
->numLocals
);
734 for (Index i
= 0; i
< parent
->numLocals
; i
++) {
738 // as the first guess, use the natural order. this is not arbitrary for two reasons.
739 // first, there may be an inherent order in the input (frequent indices are lower,
740 // etc.). second, by ensuring we start with the natural order, we ensure we are at
741 // least as good as the non-learning variant.
742 // TODO: use ::pickIndices from the parent, so we literally get the simpler approach
743 // as our first option
746 // leave params alone, shuffle the rest
747 std::shuffle(ret
->begin() + parent
->getFunction()->getNumParams(), ret
->end(), noise
);
749 calculateFitness(ret
);
750 #ifdef CFG_LEARN_DEBUG
751 order
->dump("new rando");
756 Order
* makeMixture(Order
* left
, Order
* right
) {
757 // perturb left using right. this is useful since
758 // we don't care about absolute locations, relative ones matter more,
759 // and a true merge of two vectors could obscure that (e.g.
760 // a.......... and ..........a would merge a into the middle, for no
761 // reason), and cause a lot of unnecessary noise
762 Index size
= left
->size();
763 Order reverseRight
; // reverseRight[x] is the index of x in right
764 reverseRight
.resize(size
);
765 for (Index i
= 0; i
< size
; i
++) {
766 reverseRight
[(*right
)[i
]] = i
;
768 auto* ret
= new Order
;
771 for (Index i
= parent
->getFunction()->getNumParams(); i
< size
- 1; i
++) {
772 // if (i, i + 1) is in reverse order in right, flip them
773 if (reverseRight
[(*ret
)[i
]] > reverseRight
[(*ret
)[i
+ 1]]) {
774 std::swap((*ret
)[i
], (*ret
)[i
+ 1]);
775 i
++; // if we don't skip, we might end up pushing an element all the way to the end, which is not very perturbation-y
778 calculateFitness(ret
);
779 #ifdef CFG_LEARN_DEBUG
780 ret
->dump("new mixture");
786 CoalesceLocalsWithLearning
* parent
;
791 #ifdef CFG_LEARN_DEBUG
792 std::cout
<< "[learning for " << getFunction()->name
<< "]\n";
794 auto numVars
= this->getFunction()->getNumVars();
795 const int GENERATION_SIZE
= std::min(Index(numVars
* (numVars
- 1)), Index(20));
796 Generator
generator(this);
797 GeneticLearner
<Order
, double, Generator
> learner(generator
, GENERATION_SIZE
);
798 #ifdef CFG_LEARN_DEBUG
799 learner
.getBest()->dump("first best");
801 // keep working while we see improvement
802 auto oldBest
= learner
.getBest()->getFitness();
804 learner
.runGeneration();
805 auto newBest
= learner
.getBest()->getFitness();
806 if (newBest
== oldBest
) break; // unlikely we can improve
808 #ifdef CFG_LEARN_DEBUG
809 learner
.getBest()->dump("current best");
812 #ifdef CFG_LEARN_DEBUG
813 learner
.getBest()->dump("the best");
815 this->pickIndicesFromOrder(*learner
.getBest(), indices
); // TODO: cache indices in Orders, at the cost of more memory?
820 Pass
*createCoalesceLocalsPass() {
821 return new CoalesceLocals();
824 Pass
*createCoalesceLocalsWithLearningPass() {
825 return new CoalesceLocalsWithLearning();