]> git.proxmox.com Git - rustc.git/blame - src/binaryen/src/passes/DuplicateFunctionElimination.cpp
New upstream version 1.25.0+dfsg1
[rustc.git] / src / binaryen / src / passes / DuplicateFunctionElimination.cpp
CommitLineData
abe05a73
XL
1/*
2 * Copyright 2016 WebAssembly Community Group participants
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//
18// Removes duplicate functions. That can happen due to C++ templates,
19// and also due to types being different at the source level, but
20// identical when finally lowered into concrete wasm code.
21//
22
23#include "wasm.h"
24#include "pass.h"
25#include "ir/utils.h"
26#include "support/hash.h"
27
28namespace wasm {
29
30struct FunctionHasher : public WalkerPass<PostWalker<FunctionHasher>> {
31 bool isFunctionParallel() override { return true; }
32
33 FunctionHasher(std::map<Function*, uint32_t>* output) : output(output) {}
34
35 FunctionHasher* create() override {
36 return new FunctionHasher(output);
37 }
38
39 void doWalkFunction(Function* func) {
40 assert(digest == 0);
41 hash(func->getNumParams());
42 for (auto type : func->params) hash(type);
43 hash(func->getNumVars());
44 for (auto type : func->vars) hash(type);
45 hash(func->result);
46 hash64(func->type.is() ? uint64_t(func->type.str) : uint64_t(0));
47 hash(ExpressionAnalyzer::hash(func->body));
48 output->at(func) = digest;
49 }
50
51private:
52 std::map<Function*, uint32_t>* output;
53 uint32_t digest = 0;
54
55 void hash(uint32_t hash) {
56 digest = rehash(digest, hash);
57 }
58 void hash64(uint64_t hash) {
59 digest = rehash(rehash(digest, uint32_t(hash >> 32)), uint32_t(hash));
60 };
61};
62
63struct FunctionReplacer : public WalkerPass<PostWalker<FunctionReplacer>> {
64 bool isFunctionParallel() override { return true; }
65
66 FunctionReplacer(std::map<Name, Name>* replacements) : replacements(replacements) {}
67
68 FunctionReplacer* create() override {
69 return new FunctionReplacer(replacements);
70 }
71
72 void visitCall(Call* curr) {
73 auto iter = replacements->find(curr->target);
74 if (iter != replacements->end()) {
75 curr->target = iter->second;
76 }
77 }
78
79private:
80 std::map<Name, Name>* replacements;
81};
82
83struct DuplicateFunctionElimination : public Pass {
84 void run(PassRunner* runner, Module* module) override {
85 while (1) {
86 // Hash all the functions
87 hashes.clear();
88 for (auto& func : module->functions) {
89 hashes[func.get()] = 0; // ensure an entry for each function - we must not modify the map shape in parallel, just the values
90 }
91 PassRunner hasherRunner(module);
92 hasherRunner.setIsNested(true);
93 hasherRunner.add<FunctionHasher>(&hashes);
94 hasherRunner.run();
95 // Find hash-equal groups
96 std::map<uint32_t, std::vector<Function*>> hashGroups;
97 for (auto& func : module->functions) {
98 hashGroups[hashes[func.get()]].push_back(func.get());
99 }
100 // Find actually equal functions and prepare to replace them
101 std::map<Name, Name> replacements;
102 std::set<Name> duplicates;
103 for (auto& pair : hashGroups) {
104 auto& group = pair.second;
105 if (group.size() == 1) continue;
106 // pick a base for each group, and try to replace everyone else to it. TODO: multiple bases per hash group, for collisions
107#if 0
108 // for comparison purposes, pick in a deterministic way based on the names
109 Function* base = nullptr;
110 for (auto* func : group) {
111 if (!base || strcmp(func->name.str, base->name.str) < 0) {
112 base = func;
113 }
114 }
115#else
116 Function* base = group[0];
117#endif
118 for (auto* func : group) {
119 if (func != base && equal(func, base)) {
120 replacements[func->name] = base->name;
121 duplicates.insert(func->name);
122 }
123 }
124 }
125 // perform replacements
126 if (replacements.size() > 0) {
127 // remove the duplicates
128 auto& v = module->functions;
129 v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Function>& curr) {
130 return duplicates.count(curr->name) > 0;
131 }), v.end());
132 module->updateMaps();
133 // replace direct calls
134 PassRunner replacerRunner(module);
135 replacerRunner.setIsNested(true);
136 replacerRunner.add<FunctionReplacer>(&replacements);
137 replacerRunner.run();
138 // replace in table
139 for (auto& segment : module->table.segments) {
140 for (auto& name : segment.data) {
141 auto iter = replacements.find(name);
142 if (iter != replacements.end()) {
143 name = iter->second;
144 }
145 }
146 }
147 // replace in start
148 if (module->start.is()) {
149 auto iter = replacements.find(module->start);
150 if (iter != replacements.end()) {
151 module->start = iter->second;
152 }
153 }
154 // replace in exports
155 for (auto& exp : module->exports) {
156 auto iter = replacements.find(exp->value);
157 if (iter != replacements.end()) {
158 exp->value = iter->second;
159 }
160 }
161 } else {
162 break;
163 }
164 }
165 }
166
167private:
168 std::map<Function*, uint32_t> hashes;
169
170 bool equal(Function* left, Function* right) {
171 if (left->getNumParams() != right->getNumParams()) return false;
172 if (left->getNumVars() != right->getNumVars()) return false;
173 for (Index i = 0; i < left->getNumLocals(); i++) {
174 if (left->getLocalType(i) != right->getLocalType(i)) return false;
175 }
176 if (left->result != right->result) return false;
177 if (left->type != right->type) return false;
178 return ExpressionAnalyzer::equal(left->body, right->body);
179 }
180};
181
182Pass *createDuplicateFunctionEliminationPass() {
183 return new DuplicateFunctionElimination();
184}
185
186} // namespace wasm