ASM 字节码插桩:隐私合规方法检测

1.前言

近两年来工信部对于应用的隐私合规安全问题愈加重视,对 Android 平台的管控程度也要比 IOS 平台严格很多,很多不合规的应用也先后被下架要求整改。笔者就曾遇到过加班整改隐私合规的问题,隐私合规问题主要针对两个方面。

  • 在用户同意隐私协议之前不能收集用户隐私数据,例如 IMEI、AndroidId、MAC 等

  • 在用户同意隐私协议之后,收集用户数据行为在对应场景不能超频。比如一分钟不能超过 3 次获取 IMEI

针对上述两个方面,有以下措施来针对

  • 通过静态扫描,收集项目中(自有代码 + 三方 sdk)使用隐私合规相关 api 的相关代码

  • 通过 ASM 插桩,在调用隐私合规 api 之前插入代码,记录运行时的方法调用链和当前时间

  • hook 隐私合规 api,替换字节码指令将调用链指向工具类,在未同意隐私协议之前,不调用相关的 api

2.实现

2.1 注解和工具类

通过定义注解和工具类,用来定义要处理哪些隐私合规相关的方法,目前笔者已经处理了大部分,请放心食用

2.1.1 注解

/**
 * 收集和注解匹配的方法
 * visitMethodInsn(int opcode, String owner, String name,String desc)
 *
 * ======如果 originName 和 originDesc 传"",逻辑会在插件中处理=====
 */@Retention(RetentionPolicy.CLASS)
@Target({ElementType.METHOD})
public @interface AsmMethodReplace {    /**
     * 指令操作码
     */
    int targetMethodOpcode();    /**
     * 方法所有者类
     */
    String targetClass();    /**
     * 方法名称
     */
    String targetName() default "";    /**
     * 方法描述符
     */
    String targetDesc() default "";    /**
     * 是否进行 hook
     */
    boolean hook() default false;


}

该注解用来匹配调用隐私合规 api 的字节码指令,例如通过 ASM 调用 getImei 的字节码指令为

methodVisitor.visitMethodInsn(INVOKEVIRTUAL, "android/telephony/TelephonyManager", "getImei", "()Ljava/lang/String;", false);

targetName 属性和 targetDesc 可以不赋值,这里做了取巧的处理,会根据工具类的 method namemethod descriptor 推断出调用隐私合规方法的字节码指令,这块后面插件会处理。最后的 hook 属性表示是否 hook 掉原始的调用链,将调用指向工具类中的方法。

2.1.2 工具类

上述注解可以用在任何地方,笔者将常用的隐私合规的方法聚合起来,方法工具类中统一处理,例如处理 IMEI 的逻辑如下:

@RequiresApi(api = Build.VERSION_CODES.O)
@AsmMethodReplace(targetMethodOpcode = OPCODE_INVOKEVIRTUAL
        , targetClass = CLASS_NAME_TELEPHONYMANAGER,hook = true)
public static String getImei(TelephonyManager telephonyManager) {    if (!checkAgreePrivacy("getImei")) {        Log.e(TAG, TIP);        return "";
    }    if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {        Log.i(TAG, "getImei-SDK_INT above android Q");        return "";
    }    return telephonyManager.getImei();
}

如果还没有同意隐私协议,直接 return “”,否者走正常的调用方法。同时,通过工具类setStoreDirectory设置存储调用堆栈和时间的文件路径。

2.2 插件处理

gradle 插件的基本使用在这里就不赘述了,这里主要是两方面的处理

  • 编译时扫描代码,处理有定义特定注解的方法,分析字节码指令,收集所有需要处理的隐私合规相关 api 相关的信息

  • 再次扫描,根据第一次扫描收集到的信息,判断当前类是否含有调用隐私合规 api 的字节码指令,如果有,在该类中注入一个写文件方法及在隐私合规 api 调用指令之前插入写文件的字节码指令,用来记录调用堆栈和频次。

2.2.1 模版代码

下面这块代码是我们在自定义 gradle 插件时常用的模版代码,供大家使用

package com.zhangyue.ireaderimport com.android.build.api.transform.*import com.android.build.gradle.internal.pipeline.TransformManagerimport com.zhangyue.ireader.plugin_privacy.PrivacyGlobalConfigimport com.zhangyue.ireader.util.CommonUtilimport com.zhangyue.ireader.util.Loggerimport org.apache.commons.io.FileUtilsimport org.apache.commons.io.IOUtilsimport org.gradle.api.Projectimport java.util.concurrent.AbstractExecutorServiceimport java.util.concurrent.Callableimport java.util.concurrent.ForkJoinPoolimport java.util.jar.JarEntryimport java.util.jar.JarFileimport java.util.jar.JarOutputStreamabstract class BaseTransform extends Transform {    AbstractExecutorService executorService = ForkJoinPool.commonPool()


    private List<Callable<Void>> taskList = new ArrayList<>()


    protected Project project    BaseTransform(Project project) {        this.project = project
    }


    @Override
    String getName() {        return getClass().simpleName
    }


    @Override
    Set<QualifiedContent.ContentType> getInputTypes() {        return TransformManager.CONTENT_CLASS
    }


    @Override
    Set<? super QualifiedContent.Scope> getScopes() {        return TransformManager.SCOPE_FULL_PROJECT
    }


    @Override
    boolean isIncremental() {        return true
    }


    @Override
    void transform(TransformInvocation transformInvocation) throws TransformException, InterruptedException, IOException {        super.transform(transformInvocation)        println("transform start--------------->")        if (firstTransform()) {            printCopyRight()
        }        onTransformStart(transformInvocation)
        def startTime = System.currentTimeMillis()
        def inputs = transformInvocation.inputs
        def outputProvider = transformInvocation.outputProvider
        def context = transformInvocation.context
        def isIncremental = transformInvocation.isIncremental()        if (!isIncremental) {
            outputProvider.deleteAll()
        }        //1//        inputs.each { input ->//            input.jarInputs.each { JarInput jarInput ->//                forEachJar(jarInput, outputProvider, context, isIncremental)//            }            input.directoryInputs.each { DirectoryInput dirInput ->//                forEachDir(dirInput, outputProvider, context, isIncremental)//            }//        }


        //3
        inputs.each { input ->
            input.jarInputs.each { jarInput ->                submitTask(new Runnable() {
                    @Override
                    void run() {                        forEachJar(jarInput, outputProvider, context, isIncremental)
                    }
                })
            }
            input.directoryInputs.each { DirectoryInput dirInput ->                submitTask(new Runnable() {
                    @Override
                    void run() {                        forEachDir(dirInput, outputProvider, context, isIncremental)
                    }
                })
            }
        }
        def futures = executorService.invokeAll(taskList)
        futures.each { it ->
            it.get()
        }        onTransformEnd(transformInvocation)        println(getName() + "transform end--------------->" + "duration : " + (System.currentTimeMillis() - startTime) + " ms")
    }    void submitTask(Runnable runnable) {
        taskList.add(new Callable<Void>() {
            @Override
            Void call() throws Exception {
                runnable.run()                return null
            }
        })
    }    void forEachDir(DirectoryInput directoryInput, TransformOutputProvider outputProvider, Context context, boolean isIncremental) {
        def inputDir = directoryInput.file
        File dest = outputProvider.getContentLocation(
                directoryInput.name,
                directoryInput.contentTypes,
                directoryInput.scopes,                Format.DIRECTORY
        )
        println "directoryInputPath:" + directoryInput.file.absolutePath
        println "destPath:" + dest.absolutePath
        def srcDirPath = inputDir.absolutePath
        def destDirPath = dest.absolutePath
        def temporaryDir = context.temporaryDir
        FileUtils.forceMkdir(dest)        Logger.info("srcDirPath:${srcDirPath}, destDirPath:${destDirPath}")        if (isIncremental) {
            directoryInput.getChangedFiles().each { entry ->
                def classFile = entry.key
                switch (entry.value) {                    case Status.NOTCHANGED:                        Logger.info("处理 class:" + classFile.absoluteFile + " NOTCHANGED")                        break
                    case Status.REMOVED:                        Logger.info("处理 class:" + classFile.absoluteFile + " REMOVED")                        //最终文件应该存放的路径
                        def destFilePath = classFile.absolutePath.replace(srcDirPath, destDirPath)
                        def destFile = File(destFilePath)                        if (destFile.exists()) {
                            destFile.delete()
                        }                        break
                    case Status.ADDED:                    case Status.CHANGED:                        Logger.info("处理 class:" + classFile.absoluteFile + " ADDED or CHANGED")                        modifyClassFile(classFile, srcDirPath, destDirPath, temporaryDir)                        break
                    default:                        break
                }
            }
        } else {
            com.android.utils.FileUtils.getAllFiles(inputDir).each { File file ->                modifyClassFile(file, srcDirPath, destDirPath, temporaryDir)
            }
        }
    }    void modifyClassFile(classFile, srcDirPath, destDirPath, temporaryDir) {        Logger.info("处理 class:" + classFile.absoluteFile)        //目标路径
        def destFilePath = classFile.absolutePath.replace(srcDirPath, destDirPath)
        def destFile = new File(destFilePath)        if (destFile.exists()) {
            destFile.delete()
        }        Logger.info("处理 class:destFile" + destFile.absoluteFile)        String className = CommonUtil.path2ClassName(classFile.absolutePath.replace(srcDirPath + File.separator, ""))        Logger.info("处理 className:" + className)        File modifyFile = null
        if (CommonUtil.isLegalClass(classFile) && shouldHookClass(className)) {
            modifyFile = getModifyFile(classFile, temporaryDir, className)
        }        if (modifyFile == null) {
            modifyFile = classFile
        }        FileUtils.copyFile(modifyFile, destFile)
    }    File getModifyFile(File classFile, File temporaryDir, String className) {
        byte[] sourceBytes = IOUtils.toByteArray(new FileInputStream(classFile))
        def tempFile = new File(temporaryDir, CommonUtil.generateClassFileName(classFile))        if (tempFile.exists()) {            FileUtils.forceDelete(tempFile)
        }
        def modifyBytes = modifyClass(className, sourceBytes)        if (modifyBytes == null) {
            modifyBytes = sourceBytes
        }
        tempFile.createNewFile()
        def fos = new FileOutputStream(tempFile)
        fos.write(modifyBytes)
        fos.flush()        IOUtils.closeQuietly(fos)        return tempFile
    }    void forEachJar(JarInput jarInput, TransformOutputProvider outputProvider, Context context, boolean isIncremental) {        Logger.info("jarInput:" + jarInput.file)        File destFile = outputProvider.getContentLocation(                //防止同名被覆盖
                CommonUtil.generateJarFileName(jarInput.file), jarInput.contentTypes, jarInput.scopes, Format.JAR)        //增量编译处理
        if (isIncremental) {            Status status = jarInput.status
            switch (status) {                case Status.NOTCHANGED:                    Logger.info("处理 jar:" + jarInput.file.absoluteFile + " NotChanged")                    //Do nothing
                    return
                case Status.REMOVED:                    Logger.info("处理 jar:" + jarInput.file.absoluteFile + " REMOVED")                    if (destFile.exists()) {                        FileUtils.forceDelete(destFile)
                    }                    return
                case Status.ADDED:                case Status.CHANGED:                    Logger.info("处理 jar:" + jarInput.file.absoluteFile + " ADDED or CHANGED")                    break
            }
        }        if (destFile.exists()) {            FileUtils.forceDelete(destFile)
        }        CommonUtil.isLegalJar(jarInput.file) ? transformJar(jarInput.file, context.getTemporaryDir(), destFile)
                : FileUtils.copyFile(jarInput.file, destFile)
    }


    def transformJar(File jarFile, File temporaryDir, File destFile) {        Logger.info("处理 jar:" + jarFile.absoluteFile)        File tempOutputJarFile = new File(temporaryDir, CommonUtil.generateJarFileName(jarFile))        if (tempOutputJarFile.exists()) {            FileUtils.forceDelete(tempOutputJarFile)
        }        JarOutputStream jarOutputStream = new JarOutputStream(new FileOutputStream(tempOutputJarFile))        JarFile inputJarFile = new JarFile(jarFile, false)        try {
            def entries = inputJarFile.entries()            while (entries.hasMoreElements()) {
                def jarEntry = entries.nextElement()
                def entryName = jarEntry.getName()
                def inputStream = inputJarFile.getInputStream(jarEntry)                try {
                    byte[] sourceByteArray = IOUtils.toByteArray(inputStream)
                    def modifiedByteArray = null
                    if (!jarEntry.isDirectory() && CommonUtil.isLegalClass(entryName)) {                        String className = CommonUtil.path2ClassName(entryName)                        if (shouldHookClass(className)) {
                            modifiedByteArray = modifyClass(className, sourceByteArray)
                        }
                    }                    if (modifiedByteArray == null) {
                        modifiedByteArray = sourceByteArray
                    }
                    jarOutputStream.putNextEntry(new JarEntry(entryName))
                    jarOutputStream.write(modifiedByteArray)
                    jarOutputStream.closeEntry()
                } finally {                    IOUtils.closeQuietly(inputStream)
                }
            }
        } finally {
            jarOutputStream.flush()            IOUtils.closeQuietly(jarOutputStream)            IOUtils.closeQuietly(inputJarFile)
        }        FileUtils.copyFile(tempOutputJarFile, destFile)
    }


    private byte[] modifyClass(String className, byte[] sourceBytes) {
        byte[] classBytesCode        try {
            classBytesCode = hookClassInner(className, sourceBytes)
        } catch (Throwable e) {
            e.printStackTrace()
            classBytesCode = null
            println "throw exception when modify class ${className}"
        }        return classBytesCode
    }    /**
     * 打印日志信息
     */
    static void printCopyRight() {        println()
        println '#######################################################################'
        println '##########                                                    '
        println '##########                欢迎使用隐私合规处理插件'
        println '##########                                                    '
        println '#######################################################################'
        println '##########                                                    '
        println '##########                 插件配置参数                         '
        println '##########                                                    '
        println '##########                -isDebug: ' + PrivacyGlobalConfig.isDebug
        println '##########                -handleAnnotationName: ' + PrivacyGlobalConfig.handleAnnotationName
        println '##########                -exclude: ' + PrivacyGlobalConfig.exclude
        println '##########                                                    '
        println '##########                                                    '
        println '##########                                                    '
        println '#######################################################################'
        println()
    }


    protected boolean firstTransform() {        return false
    }


    boolean shouldHookClass(String className) {
        def excludes = PrivacyGlobalConfig.exclude
        if (excludes != null) {            for (String string : excludes) {                if (className.startsWith(string)) {                    return false
                }
            }
        }        return shouldHookClassInner(className)
    }


    protected abstract boolean shouldHookClassInner(String className)


    protected abstract byte[] hookClassInner(String className, byte[] bytes)


    protected abstract void onTransformStart(TransformInvocation transformInvocation)


    protected abstract void onTransformEnd(TransformInvocation transformInvocation)
}

2.2.2 注解处理 transform

收集和处理具有特定注解的字节码指令,给下一个 transform 使用

@Overridebyte[] hookClassInner(String className, byte[] bytes) {    ClassReader cr = new ClassReader(bytes)    ClassNode classNode = new ClassNode()
    cr.accept(classNode, 0)
    classNode.methods.each { methodNode ->        //编译期注解
        methodNode.invisibleAnnotations.
                each { annotationNode ->                    if (PrivacyGlobalConfig.getHandleAnnotationName() == annotationNode.desc) {                        collectPrivacyMethod(annotationNode, methodNode, cr.className)
                    }
                }
    }    return bytes
}/**
 * 收集注解和注解关联的方法
 * @param annotationNode 注解信息
 * @param methodNode 方法信息
 */static collectPrivacyMethod(AnnotationNode annotationNode, MethodNode methodNode, String className) {    List<Object> values = annotationNode.values
    Logger.info("annotation values : ${values}")    MethodReplaceItem item = new MethodReplaceItem(values, methodNode, CommonUtil.getClassInternalName(className))    PrivacyGlobalConfig.methodReplaceItemList.offer(item)    Logger.info("collectPrivacyMethod success: ${item}")    println("collectPrivacyMethod success: ${item}")
}

MethodReplaceItem中封装了收集到的字节码属性,同时会根据注解关联方法的字节码指令推断出想要处理的隐私合规 api 的字节码指令。

MethodReplaceItem(List<Object> annotationPair, MethodNode methodNode, String owner) {
    replaceOpcode = Opcodes.INVOKESTATIC
    replaceClass = owner
    replaceMethod = methodNode.name
    replaceDesc = methodNode.desc


    for (int i = 0; i < annotationPair.size(); i = i + 2) {
        def key = annotationPair[i]
        def value = annotationPair[i + 1]        if (key == "targetMethodOpcode") {
            targetOpcode = value
        } else if (key == "targetClass") {
            targetOwner = value
        } else if (key == "targetName") {
            targetMethod = value
        } else if (key == "targetDesc") {
            targetDesc = value
        }else if(key == "hook"){
            willHook = value
        }
    }    Logger.info("=====targetOpcode:${targetOpcode},targetOwner:${targetOwner} , replaceDesc${replaceDesc}")    if (isEmpty(targetMethod)) {
        targetMethod = replaceMethod
    }    if (isEmpty(targetDesc)) {        //静态方法,oriDesc 跟 targetDesc 一样
        if (targetOpcode == Opcodes.INVOKESTATIC) {
            targetDesc = replaceDesc
        } else {            //非静态方法,约定第一个参数是实例类名,oriDesc 比 targetDesc 少一个参数,处理一下
            // (Landroid/telephony/TelephonyManager;)Ljava/lang/String ->  ()Ljava/lang/String
            def param = replaceDesc.split('\)')[0] + ")"
            def result = replaceDesc.split('\)')[1]
            def index = replaceDesc.indexOf(targetOwner)            if (index != -1) {
                param = "(" + param.substring(index + targetOwner.length() + 1)
            }            Logger.info("index::: ${index}")
            targetDesc = param + result
        }
    }


}

2.2.3 合规方法处理 transform

再次扫描整个项目,根据在上一个 transform 中收集到的要处理的隐私合规的 api,遍历字节码指令,当匹配上时,在当前的类中注入写文件的方法,同时在调用隐私合规的字节码指令前插入写文件的字节码指令,用来记录。

@Overridebyte[] hookClassInner(String className, byte[] bytes) {    Logger.info("${getName()} modifyClassInner--------------->")
    def findHookPoint = false
    Map<MethodNode, InsertInsnPoint> collectMap = new HashMap<>()    ClassReader cr = new ClassReader(bytes)    ClassNode classNode = new ClassNode()
    cr.accept(classNode, ClassReader.EXPAND_FRAMES)
    classNode.methods.each { methodNode ->        //过滤掉含有特定注解的方法
        if (isNotHookMethod(cr.className, methodNode)) {
            methodNode.instructions.each { insnNode ->                //判断字节码能否匹配
                def methodReplaceItem = searchHookPoint(insnNode)                if (methodReplaceItem != null) {                    //判断是否需要 hook 掉当前指令
                    def inject = methodReplaceItem.willHook
                    //记录隐私合规 api 所在的类及方法
                    logHookPoint(classNode.name, methodReplaceItem, methodNode, insnNode.opcode, insnNode.owner, insnNode.name, insnNode.desc, inject)                    if (inject) {                        //hook
                        injectInsn(insnNode, methodReplaceItem)
                    }                    //插入写文件方法指令,收集调用隐私方法的堆栈
                    collectInsertInsn(insnNode, methodNode, classNode, collectMap, inject)
                    findHookPoint = true
                }
            }
        }
    }    if (!collectMap.isEmpty() && findHookPoint) {        //插入写文件指令,用来展示堆栈信息
        collectMap.each { key, value ->
            key.instructions.insert(value.hookInsnNode, value.instList)
        }        //插入 writeToFile 方法
        ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS)
        classNode.accept(cw)        insertWriteToFileMethod(cw)        return cw.toByteArray()
    }    return bytes
}

collectInsertInsn 方法中,通过 throwable 来收集当前的堆栈

/**
 * 收集 在调用特定的方法前插入调用写入文件的方法的指令
 * @param insnNode
 * @param methodNode
 * @param classNode
 * @param collectMap
 */static void collectInsertInsn(insnNode, methodNode, classNode, collectMap, Inject) {
    def className = classNode.name
    def methodName = methodNode.name
    def methodDesc = methodNode.desc
    def owner = null
    def name = null
    def desc = null
    if (insnNode instanceof MethodInsnNode) {
        owner = insnNode.owner
        name = insnNode.name
        desc = insnNode.desc
    }    //------log
    StringBuilder lintLog = new StringBuilder()
    lintLog.append(className)
    lintLog.append("  ->  ")
    lintLog.append(methodName)
    lintLog.append("  ->  ")
    lintLog.append(methodDesc)
    lintLog.append("\r\n")
    lintLog.append(owner)
    lintLog.append("  ->  ")
    lintLog.append(name)
    lintLog.append("  ->  ")
    lintLog.append(desc)    //------要插入字节码指令
    lintLog.append("\r\n")    InsnList insnList = new InsnList()
    insnList.add(new LdcInsnNode(lintLog.toString()))
    insnList.add(new TypeInsnNode(Opcodes.NEW, "java/lang/Throwable"))
    insnList.add(new InsnNode(Opcodes.DUP))
    insnList.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, "java/lang/Throwable", "<init>", "()V", false))
    insnList.add(new MethodInsnNode(Opcodes.INVOKESTATIC, className, writeToFileMethodName, writeToFileMethodDesc))
    println "插入指令完成 =---------->"
    collectMap.put(methodNode, new InsertInsnPoint(insnList, insnNode))
}

最终,在项目编译完成之后,会在项目的根目录下生成 replaceInsn.txt 文件,记录包含隐私合规 api 的类和相关方法。

outside_default.png

当项目运行起来之后,会在设置的路径中(笔者设置在 getExternalCacheDir 中)生成 privacy_log.txt 文件,里面会记录隐私合规 api 的调用堆栈和时间,根据该调用链,我们就可以快速定位是哪一块业务执行了敏感操作。

outside_default.png

总结

通过 ASM + gradle plugin ,能够排查出大部分的隐私合规问题。有什么不足之处,也请读者多多提意见和建议。

源码

改项目已经在 github 上开源,希望大家多多围观      https://github.com/season-max/asm_hook.git


作者:season_y
链接:https://juejin.cn/post/7128724162645852168

关注我获取更多知识或者投稿

678fdc2ddf38f3d75e2f82e77d1c7804.jpeg

a0cb4acbaced70554ab4b42afdd8a43c.jpeg

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值