diff options
Diffstat (limited to 'llvm_mode/compare-transform-pass.so.cc')
-rw-r--r-- | llvm_mode/compare-transform-pass.so.cc | 306 |
1 files changed, 306 insertions, 0 deletions
diff --git a/llvm_mode/compare-transform-pass.so.cc b/llvm_mode/compare-transform-pass.so.cc new file mode 100644 index 00000000..acca3ff0 --- /dev/null +++ b/llvm_mode/compare-transform-pass.so.cc @@ -0,0 +1,306 @@ +/* + * Copyright 2016 laf-intel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <stdio.h> +#include <stdlib.h> +#include <unistd.h> + +#include "llvm/ADT/Statistic.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/ValueTracking.h" + +#include <set> + +using namespace llvm; + +namespace { + + class CompareTransform : public ModulePass { + + public: + static char ID; + CompareTransform() : ModulePass(ID) { + } + + bool runOnModule(Module &M) override; + +#if __clang_major__ < 4 + const char * getPassName() const override { +#else + StringRef getPassName() const override { +#endif + return "transforms compare functions"; + } + private: + bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp + ,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp); + }; +} + + +char CompareTransform::ID = 0; + +bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp + , const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) { + + std::vector<CallInst*> calls; + LLVMContext &C = M.getContext(); + IntegerType *Int8Ty = IntegerType::getInt8Ty(C); + IntegerType *Int32Ty = IntegerType::getInt32Ty(C); + IntegerType *Int64Ty = IntegerType::getInt64Ty(C); + Constant* c = M.getOrInsertFunction("tolower", + Int32Ty, + Int32Ty +#if __clang_major__ < 7 + , nullptr +#endif + ); + Function* tolowerFn = cast<Function>(c); + + /* iterate over all functions, bbs and instruction and add suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */ + for (auto &F : M) { + for (auto &BB : F) { + for(auto &IN: BB) { + CallInst* callInst = nullptr; + + if ((callInst = dyn_cast<CallInst>(&IN))) { + + bool isStrcmp = processStrcmp; + bool isMemcmp = processMemcmp; + bool isStrncmp = processStrncmp; + bool isStrcasecmp = processStrcasecmp; + bool isStrncasecmp = processStrncasecmp; + + Function *Callee = callInst->getCalledFunction(); + if (!Callee) + continue; + if (callInst->getCallingConv() != llvm::CallingConv::C) + continue; + StringRef FuncName = Callee->getName(); + isStrcmp &= !FuncName.compare(StringRef("strcmp")); + isMemcmp &= !FuncName.compare(StringRef("memcmp")); + isStrncmp &= !FuncName.compare(StringRef("strncmp")); + isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp")); + isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp")); + + if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp) + continue; + + /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function prototype */ + FunctionType *FT = Callee->getFunctionType(); + + + isStrcmp &= FT->getNumParams() == 2 && + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); + isStrcasecmp &= FT->getNumParams() == 2 && + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); + isMemcmp &= FT->getNumParams() == 3 && + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0)->isPointerTy() && + FT->getParamType(1)->isPointerTy() && + FT->getParamType(2)->isIntegerTy(); + isStrncmp &= FT->getNumParams() == 3 && + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) && + FT->getParamType(2)->isIntegerTy(); + isStrncasecmp &= FT->getNumParams() == 3 && + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) && + FT->getParamType(2)->isIntegerTy(); + + if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp) + continue; + + /* is a str{n,}{case,}cmp/memcmp, check is we have + * str{case,}cmp(x, "const") or str{case,}cmp("const", x) + * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..) + * memcmp(x, "const", ..) or memcmp("const", x, ..) */ + Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1); + StringRef Str1, Str2; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + bool HasStr2 = getConstantStringInfo(Str2P, Str2); + + /* handle cases of one string is const, one string is variable */ + if (!(HasStr1 ^ HasStr2)) + continue; + + if (isMemcmp || isStrncmp || isStrncasecmp) { + /* check if third operand is a constant integer + * strlen("constStr") and sizeof() are treated as constant */ + Value *op2 = callInst->getArgOperand(2); + ConstantInt* ilen = dyn_cast<ConstantInt>(op2); + if (!ilen) + continue; + /* final precaution: if size of compare is larger than constant string skip it*/ + uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P); + if (literalLength < ilen->getZExtValue()) + continue; + } + + calls.push_back(callInst); + } + } + } + } + + if (!calls.size()) + return false; + errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n"; + + for (auto &callInst: calls) { + + Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1); + StringRef Str1, Str2, ConstStr; + Value *VarStr; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + getConstantStringInfo(Str2P, Str2); + uint64_t constLen, sizedLen; + bool isMemcmp = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp")); + bool isSizedcmp = isMemcmp + || !callInst->getCalledFunction()->getName().compare(StringRef("strncmp")) + || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp")); + bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(StringRef("strcasecmp")) + || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp")); + + if (isSizedcmp) { + Value *op2 = callInst->getArgOperand(2); + ConstantInt* ilen = dyn_cast<ConstantInt>(op2); + sizedLen = ilen->getZExtValue(); + } + + if (HasStr1) { + ConstStr = Str1; + VarStr = Str2P; + constLen = isMemcmp ? sizedLen : GetStringLength(Str1P); + } + else { + ConstStr = Str2; + VarStr = Str1P; + constLen = isMemcmp ? sizedLen : GetStringLength(Str2P); + } + if (isSizedcmp && constLen > sizedLen) { + constLen = sizedLen; + } + + errs() << callInst->getCalledFunction()->getName() << ": len " << constLen << ": " << ConstStr << "\n"; + + /* split before the call instruction */ + BasicBlock *bb = callInst->getParent(); + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst)); + BasicBlock *next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); + BranchInst::Create(end_bb, next_bb); + PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi"); + + TerminatorInst *term = bb->getTerminator(); + BranchInst::Create(next_bb, bb); + term->eraseFromParent(); + + for (uint64_t i = 0; i < constLen; i++) { + + BasicBlock *cur_bb = next_bb; + + char c = isCaseInsensitive ? tolower(ConstStr[i]) : ConstStr[i]; + + + BasicBlock::iterator IP = next_bb->getFirstInsertionPt(); + IRBuilder<> IRB(&*IP); + + Value* v = ConstantInt::get(Int64Ty, i); + Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty"); + Value *load = IRB.CreateLoad(ele); + if (isCaseInsensitive) { + // load >= 'A' && load <= 'Z' ? load | 0x020 : load + std::vector<Value *> args; + args.push_back(load); + load = IRB.CreateCall(tolowerFn, args, "tmp"); + } + Value *isub; + if (HasStr1) + isub = IRB.CreateSub(ConstantInt::get(Int8Ty, c), load); + else + isub = IRB.CreateSub(load, ConstantInt::get(Int8Ty, c)); + + Value *sext = IRB.CreateSExt(isub, Int32Ty); + PN->addIncoming(sext, cur_bb); + + + if (i < constLen - 1) { + next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); + BranchInst::Create(end_bb, next_bb); + + TerminatorInst *term = cur_bb->getTerminator(); + Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0)); + IRB.CreateCondBr(icmp, next_bb, end_bb); + term->eraseFromParent(); + } else { + //IRB.CreateBr(end_bb); + } + + //add offset to varstr + //create load + //create signed isub + //create icmp + //create jcc + //create next_bb + } + + /* since the call is the first instruction of the bb it is safe to + * replace it with a phi instruction */ + BasicBlock::iterator ii(callInst); + ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN); + } + + + return true; +} + +bool CompareTransform::runOnModule(Module &M) { + + llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n"; + transformCmps(M, true, true, true, true, true); + verifyModule(M); + + return true; +} + +static void registerCompTransPass(const PassManagerBuilder &, + legacy::PassManagerBase &PM) { + + auto p = new CompareTransform(); + PM.add(p); + +} + +static RegisterStandardPasses RegisterCompTransPass( + PassManagerBuilder::EP_OptimizerLast, registerCompTransPass); + +static RegisterStandardPasses RegisterCompTransPass0( + PassManagerBuilder::EP_EnabledOnOptLevel0, registerCompTransPass); + |