一款用于 Python 类型推断的 PyCharm 插件

推断依据

变量声明
函数声明
基本类型
引用类型
调用表达式
调用函数
函数返回类型
类实例化
类名
变量运算
for语句
入参类型
函数调用
返回值类型
return语句

推断遵循原则

1.不改变用户已有的类型声明
2.推断出错时,不添加类型
3.推断路线有多个分支时,选择第一个分支

待实现的功能

1.常见场景下变量声明类型的推断
2.常见场景下函数声明入参、出参的类型推断
3.存在依赖时递归推断
4.调用内置函数时,推断出的返回类型为None,应返回正确的类型
5.递归时,返回类型出错,应返回正确的类型
6.出错时,需要返回一个信息,显示推断出错的大致行号
7.变量类型推断适配for语句
8.变量类型推断适配下标访问语句
9.变量类型推断适配多个变量同时赋值语句
10.调用函数时指定参数名称
11.修复函数递归时推断会无限递归下去的问题

代码实现

type infer

import com.intellij.codeInspection.LocalQuickFix;
import com.intellij.codeInspection.ProblemDescriptor;
import com.intellij.codeInspection.ProblemsHolder;
import com.intellij.codeInspection.util.IntentionFamilyName;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.ui.Messages;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiElementVisitor;
import com.intellij.psi.PsiReference;
import com.intellij.psi.search.searches.ReferencesSearch;
import com.intellij.psi.tree.IElementType;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.inspections.PyInspection;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyFunctionImpl;
import com.jetbrains.python.psi.impl.PyTargetExpressionImpl;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;

import java.util.Objects;

import static com.jetbrains.python.PyElementTypes.*;
import static com.jetbrains.python.inspections.PyTypeCheckerInspection.Visitor.tryPromotingType;

/**
 * PyCharm inspection that reports Python variables and functions lacking type declarations and
 * offers {@link TypeFix} to infer and insert the missing annotations.
 */
public class TypeHits extends PyInspection {
    private final TypeFix myQuickFix = new TypeFix();

    /**
     * Builds the visitor that registers a problem on:
     * <ul>
     *   <li>an unannotated assignment target that its own reference resolves to, unless it is a
     *       for-loop target (Python cannot annotate those in place);</li>
     *   <li>the name of a function whose return type, or any named parameter other than
     *       {@code self}, has no annotation.</li>
     * </ul>
     */
    @Override
    public @NotNull PsiElementVisitor buildVisitor(@NotNull ProblemsHolder holder, boolean isOnTheFly) {
        return new PyElementVisitor() {
            @Override
            public void visitPyElement(@NotNull PyElement element) {
                if (element instanceof PyTargetExpression targetExpression) {
                    // NOTE(review): the resolve() == element check presumably limits the warning to the
                    // declaring occurrence of a variable — confirm for reassigned names.
                    if (targetExpression.getAnnotationValue() == null
                            && !(findTargetOwner(targetExpression) instanceof PyForPart)
                            && ((PyTargetExpressionImpl) element).getReference().resolve() == element) {
                        holder.registerProblem(element, "No type declare of variable " + element.getText(), myQuickFix);
                    }
                } else if (element instanceof PyFunction pyFunction) {
                    boolean isNeedFix = pyFunction.getAnnotation() == null;
                    for (PyParameter parameter : pyFunction.getParameterList().getParameters()) {
                        // Bare '*' and '/' markers are not named parameters and cannot be annotated.
                        if (parameter instanceof PyNamedParameter namedParameter
                                && !namedParameter.getText().equals("self")
                                && namedParameter.getAnnotation() == null) {
                            isNeedFix = true;
                            break;
                        }
                    }
                    // Incomplete code such as "def (" has no name node to attach the problem to.
                    var nameNode = pyFunction.getNameNode();
                    if (isNeedFix && nameNode != null) {
                        holder.registerProblem(nameNode.getPsi(),
                                "lose some type declare of function " + pyFunction.getName(), myQuickFix);
                    }
                }
            }
        };

    }

    /**
     * Returns the nearest ancestor of an assignment target that is neither a tuple nor a
     * parenthesized expression, e.g. the assignment statement or for-loop part owning {@code a}
     * in {@code a, (b, c) = ...}. May return {@code null} at the top of the tree.
     */
    private static PsiElement findTargetOwner(PsiElement target) {
        PsiElement owner = target.getParent();
        while (owner instanceof PyTupleExpression || owner instanceof PyParenthesizedExpression) {
            owner = owner.getParent();
        }
        return owner;
    }

    /**
     * Quick fix that infers types from literals, references, calls and a function's first call site
     * and writes them back as annotations. Inference failures are shown in an error dialog that
     * includes the source position. Annotations inserted before the failure are kept.
     */
    private static class TypeFix implements LocalQuickFix {
        // Per-invocation state, refreshed at the start of every applyFix call.
        PyElementGenerator pyElementGenerator = null;
        TypeEvalContext typeEvalContext = null;
        Project project = null;

        @NotNull
        @Override
        public String getName() {
            return InspectionBundle.message("inspection.checking.type.declare.use.quickfix");
        }


        @Override
        public @IntentionFamilyName
        @NotNull String getFamilyName() {
            return getName();
        }

        @Override
        public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
            pyElementGenerator = PyElementGenerator.getInstance(project);
            typeEvalContext = TypeEvalContext.userInitiated(project, descriptor.getPsiElement().getContainingFile());
            this.project = project;
            try {
                applyFixElement(descriptor.getPsiElement(), pyElementGenerator);
            } catch (Exception e) {
                // Unexpected runtime failures (e.g. NullPointerException) carry no message; fall back to
                // the exception description so the dialog is never blank.
                String message = e.getMessage() != null ? e.getMessage() : e.toString();
                Messages.showErrorDialog(message, "infer type error");
            }
        }

        /**
         * Infers and inserts the missing annotation(s) for {@code psiElement}.
         *
         * @param psiElement an assignment target, a function name identifier, or a function parameter
         * @throws Exception when a type cannot be inferred; the message includes the source position
         */
        private void applyFixElement(PsiElement psiElement, PyElementGenerator pyElementGenerator) throws Exception {
            if (psiElement instanceof PyTargetExpression) {
                if (psiElement.getParent() instanceof PyTupleExpression) {
                    annotateTupleTargets(psiElement, pyElementGenerator);
                    return;
                }
                if (!(psiElement.getParent() instanceof PyAssignmentStatement assignmentStatement)) {
                    throw errorAt(psiElement, "unsupported assignment target");
                }
                StringBuilder annotationBuilder = new StringBuilder();
                inferAnnotation(assignmentStatement.getAssignedValue(), annotationBuilder);
                // Parse a throw-away "a:<type>" declaration only to obtain an annotation PSI node.
                PyTypeDeclarationStatement templateDeclaration = pyElementGenerator.createFromText(
                        LanguageLevel.forElement(psiElement), PyTypeDeclarationStatement.class,
                        "a:" + annotationBuilder);
                psiElement.add(templateDeclaration.getAnnotation());
                refreshAssignment(psiElement);
            } else if (psiElement.getParent() instanceof PyFunction || psiElement instanceof PyNamedParameter) {
                // A function name sits directly under the function; a parameter sits under its parameter list.
                PsiElement owner = psiElement.getParent() instanceof PyFunction
                        ? psiElement.getParent()
                        : psiElement.getParent().getParent();
                if (!(owner instanceof PyFunction function)) {
                    throw errorAt(psiElement, "parameter does not belong to a function");
                }
                annotateParameters(function, psiElement, pyElementGenerator);
                annotateReturnType(function, psiElement, pyElementGenerator);
            }
        }

        /**
         * Handles {@code a, b = x, y}. It inserts one {@code name:type} declaration per unannotated
         * target before the assignment, inferring each type from the value at the same position.
         */
        private void annotateTupleTargets(PsiElement psiElement, PyElementGenerator pyElementGenerator) throws Exception {
            if (!(findTargetOwner(psiElement) instanceof PyAssignmentStatement assignmentStatement)) {
                throw errorAt(psiElement, "unsupported assignment target");
            }
            PyExpression assignedValue = assignmentStatement.getAssignedValue();
            while (assignedValue instanceof PyParenthesizedExpression parenthesized) {
                assignedValue = parenthesized.getContainedExpression();
            }
            PyExpression[] targets = assignmentStatement.getTargets();
            // Only a literal tuple of matching arity can be paired element-wise with the targets.
            if (!(assignedValue instanceof PyTupleExpression valueTuple)
                    || valueTuple.getElements().length != targets.length) {
                throw errorAt(assignmentStatement, "cannot match assigned values to targets");
            }
            PyExpression[] values = valueTuple.getElements();
            for (int i = 0; i < targets.length; i++) {
                if (Until.getAnnotationValue(targets[i]) != null) {
                    continue; // never replace a declaration the user already has
                }
                StringBuilder builder = new StringBuilder();
                inferAnnotation(values[i], builder);
                PyTypeDeclarationStatement templateDeclaration = pyElementGenerator.createFromText(
                        LanguageLevel.forElement(psiElement), PyTypeDeclarationStatement.class,
                        targets[i].getName() + ":" + builder);
                assignmentStatement.getParent().addBefore(templateDeclaration, assignmentStatement);
            }
        }

        /**
         * Annotates the function's unannotated named parameters with the types of the arguments
         * passed at the first call site found by {@link ReferencesSearch}. If no call is found, it
         * does nothing. Parameters not supplied by that call are left unannotated.
         */
        private void annotateParameters(PyFunction function, PsiElement anchor, PyElementGenerator pyElementGenerator) throws Exception {
            PyCallExpression call = findFirstCall(function);
            if (call == null) {
                return;
            }
            PyParameter[] parameters = withoutSelf(function, function.getParameterList().getParameters());
            PyArgumentList argumentList = call.getArgumentList();
            PyExpression[] referenceArguments = argumentList == null ? new PyExpression[0] : argumentList.getArguments();
            PyExpression[] orderedArguments = getOrderedReferenceArguments(parameters, referenceArguments);
            // Surplus arguments (e.g. collected by *args) have no parameter slot of their own.
            int count = Math.min(orderedArguments.length, parameters.length);
            for (int i = 0; i < count; i++) {
                if (orderedArguments[i] == null
                        || !(parameters[i] instanceof PyNamedParameter namedParameter)
                        || namedParameter.getAnnotation() != null) {
                    continue;
                }
                StringBuilder annotationBuilder = new StringBuilder();
                inferAnnotation(orderedArguments[i], annotationBuilder);
                PyTypeDeclarationStatement templateDeclaration = pyElementGenerator.createFromText(
                        LanguageLevel.forElement(anchor), PyTypeDeclarationStatement.class,
                        "a:" + annotationBuilder);
                namedParameter.add(templateDeclaration.getAnnotation());
            }
        }

        /**
         * Adds a {@code -> Type} annotation built from the function's return-statement type.
         * Does nothing if the function already has a return annotation.
         */
        private void annotateReturnType(PyFunction function, PsiElement anchor, PyElementGenerator pyElementGenerator) throws Exception {
            if (function.getAnnotation() != null) {
                return;
            }
            PyType returnStatementType = function.getReturnStatementType(typeEvalContext);
            String returnTypeName = returnStatementType == null ? null : returnStatementType.getName();
            if (returnTypeName == null) {
                throw errorAt(anchor, "infer return type fail");
            }
            // Parse a throw-away function only to obtain an annotation PSI node.
            PyFunction templateFunction = pyElementGenerator.createFromText(
                    LanguageLevel.forElement(anchor), PyFunction.class,
                    "def a()->" + returnTypeName + ":\n\tpass");
            function.addAfter(templateFunction.getAnnotation(), function.getParameterList());
        }

        /** Returns the first reference to {@code function} that is the callee of a call, or {@code null}. */
        private static PyCallExpression findFirstCall(PyFunction function) {
            for (PsiReference reference : ReferencesSearch.search(function).findAll()) {
                PsiElement referenceElement = reference.getElement();
                if (referenceElement.getParent() instanceof PyCallExpression call
                        && call.getCallee() == referenceElement) {
                    return call;
                }
            }
            return null;
        }

        /**
         * Drops the leading {@code self} parameter of a method. This lines parameters up with the
         * arguments of an instance call such as {@code obj.method(x)}.
         * NOTE(review): calls of the form {@code Cls.method(obj, x)} remain misaligned.
         */
        private static PyParameter[] withoutSelf(PyFunction function, PyParameter[] parameters) {
            if (function.getContainingClass() == null || parameters.length == 0
                    || !parameters[0].getText().equals("self")) {
                return parameters;
            }
            PyParameter[] rest = new PyParameter[parameters.length - 1];
            System.arraycopy(parameters, 1, rest, 0, rest.length);
            return rest;
        }

        /**
         * Aligns the arguments of a call site with the declared parameters.
         *
         * <p>When the call uses no keyword arguments, the arguments are returned unchanged (positional
         * alignment). Otherwise the result has one slot per parameter. Leading positional arguments
         * fill the first slots and the remaining slots are matched by keyword name. A slot stays
         * {@code null} when the call does not supply that parameter (e.g. it relies on a default).
         */
        private PyExpression[] getOrderedReferenceArguments(PyParameter[] parameters, PyExpression[] referenceArguments) {
            boolean hasKeywordArgument = false;
            for (PyExpression argument : referenceArguments) {
                if (argument instanceof PyKeywordArgument) {
                    hasKeywordArgument = true;
                    break;
                }
            }
            if (!hasKeywordArgument) {
                return referenceArguments;
            }
            PyExpression[] result = new PyExpression[parameters.length];
            for (int i = 0; i < result.length; i++) {
                if (i < referenceArguments.length && !(referenceArguments[i] instanceof PyKeywordArgument)) {
                    result[i] = referenceArguments[i];
                    continue;
                }
                String targetName = parameters[i].getName();
                for (PyExpression argument : referenceArguments) {
                    if (argument instanceof PyKeywordArgument keywordArgument
                            && Objects.equals(keywordArgument.getKeyword(), targetName)) {
                        result[i] = keywordArgument.getValueExpression();
                        break;
                    }
                }
            }
            return result;
        }

        /**
         * Replaces the assignment statement enclosing {@code psiElement} with a freshly parsed copy
         * of its own text. NOTE(review): presumably done so the PSI reflects the just-inserted
         * annotation consistently.
         */
        private void refreshAssignment(PsiElement psiElement) {
            PsiElement temp = psiElement;
            while (temp != null && !(temp instanceof PyAssignmentStatement)) {
                temp = temp.getParent();
            }
            if (temp == null) {
                return;
            }
            PyAssignmentStatement newPyAssignmentStatement = pyElementGenerator.createFromText(
                    LanguageLevel.forElement(temp), PyAssignmentStatement.class,
                    temp.getText());
            temp.replace(newPyAssignmentStatement);
        }


        /**
         * Appends a type annotation inferred from {@code element} to {@code stringBuilder}.
         *
         * <p>Supported inputs:
         * <ul>
         *   <li>int/float/complex/str/None/bool literals;</li>
         *   <li>dict/list/set literals, typed from their first entry;</li>
         *   <li>references;</li>
         *   <li>calls of classes and functions;</li>
         *   <li>binary expressions, typed from their left operand;</li>
         *   <li>subscriptions of annotated containers.</li>
         * </ul>
         *
         * @return always {@code true}; failures are reported by throwing
         * @throws Exception when {@code element} is null or its type cannot be inferred
         */
        private boolean inferAnnotation(PsiElement element, StringBuilder stringBuilder) throws Exception {
            if (element == null) {
                throw new Exception("PsiElement is null");
            }
            IElementType elementType = element.getNode().getElementType();
            if (elementType.equals(INTEGER_LITERAL_EXPRESSION)) {
                stringBuilder.append("int");
            } else if (elementType.equals(FLOAT_LITERAL_EXPRESSION)) {
                stringBuilder.append("float");
            } else if (elementType.equals(IMAGINARY_LITERAL_EXPRESSION)) {
                // Python has no "imaginary" type: literals such as 2j are instances of complex.
                stringBuilder.append("complex");
            } else if (elementType.equals(STRING_LITERAL_EXPRESSION)) {
                stringBuilder.append("str");
            } else if (elementType.equals(NONE_LITERAL_EXPRESSION)) {
                stringBuilder.append("None");
            } else if (elementType.equals(BOOL_LITERAL_EXPRESSION)) {
                stringBuilder.append("bool");
            } else if (element instanceof PyDictLiteralExpression dictLiteral) {
                PyKeyValueExpression[] elements = dictLiteral.getElements();
                if (elements.length == 0) {
                    stringBuilder.append("dict[None]");
                    return true;
                }
                // dict takes two comma-separated type arguments: dict[K, V].
                stringBuilder.append("dict[");
                inferAnnotation(elements[0].getKey(), stringBuilder);
                stringBuilder.append(", ");
                inferAnnotation(elements[0].getValue(), stringBuilder);
                stringBuilder.append("]");
            } else if (element instanceof PyListLiteralExpression listLiteral) {
                PyExpression[] elements = listLiteral.getElements();
                if (elements.length == 0) {
                    stringBuilder.append("list[None]");
                    return true;
                }
                stringBuilder.append("list[");
                inferAnnotation(elements[0], stringBuilder);
                stringBuilder.append("]");
            } else if (element instanceof PySetLiteralExpression setLiteral) {
                PyExpression[] elements = setLiteral.getElements();
                if (elements.length == 0) {
                    stringBuilder.append("set[None]");
                    return true;
                }
                stringBuilder.append("set[");
                inferAnnotation(elements[0], stringBuilder);
                stringBuilder.append("]");
            } else if (elementType.equals(REFERENCE_EXPRESSION)) {
                inferReferenceAnnotation(element, stringBuilder);
            } else if (element instanceof PyCallExpression pyCallExpression) {
                inferCallAnnotation(pyCallExpression, stringBuilder);
            } else if (element instanceof PyBinaryExpression pyBinaryExpression) {
                // Approximation: the result is assumed to have the type of the left operand.
                inferAnnotation(pyBinaryExpression.getChildren()[0], stringBuilder);
            } else if (element instanceof PySubscriptionExpression pySubscriptionExpression) {
                PyExpression operand = pySubscriptionExpression.getOperand();
                StringBuilder containerBuilder = new StringBuilder();
                inferReferenceAnnotation(operand, containerBuilder);
                String containerType = containerBuilder.toString();
                int open = containerType.indexOf('[');
                if (open < 0 || containerType.lastIndexOf(']') < open) {
                    PsiElement declaration = resolveOrNull(operand);
                    throw errorAt(declaration != null ? declaration : operand, "no sub type");
                }
                stringBuilder.append(getSubscriptResultType(containerType));
            } else {
                throw errorAt(element, "unexcepted psiElementType");
            }
            return true;
        }

        /**
         * Appends the type produced by {@code call}: the class name for an instantiation, otherwise
         * the callee's return type. For functions declared in a builtins file, the type of the call
         * expression itself is used, because the return-statement type otherwise comes out as None.
         */
        private void inferCallAnnotation(PyCallExpression call, StringBuilder stringBuilder) throws Exception {
            PsiElement callee = resolveOrNull(call.getCallee());
            if (callee instanceof PyClass pyClass) {
                stringBuilder.append(pyClass.getName());
                return;
            }
            if (!(callee instanceof PyFunction pyFunction)) {
                throw errorAt(call, "cannot resolve called function");
            }
            PyType returnType = pyFunction.getContainingFile().getName().contains("builtin")
                    ? typeEvalContext.getType(call)
                    : pyFunction.getReturnStatementType(typeEvalContext);
            String typeName = returnType == null ? null : returnType.getName();
            if (typeName == null) {
                throw errorAt(call, "infer call return type fail");
            }
            stringBuilder.append(typeName);
        }

        /**
         * Returns the type produced by indexing a value annotated as {@code containerType}. For
         * {@code dict[K, V]} (or the legacy {@code dict[K:V]} form) this is the value type; otherwise
         * it is the text between the outermost brackets, e.g. {@code int} for {@code list[int]}.
         * The caller guarantees that {@code containerType} contains a '[' before its last ']'.
         */
        private static String getSubscriptResultType(String containerType) {
            int open = containerType.indexOf('[');
            String arguments = containerType.substring(open + 1, containerType.lastIndexOf(']'));
            String base = containerType.substring(0, open).strip();
            if (base.equals("dict") || base.equals("Dict") || base.equals("typing.Dict")) {
                int depth = 0;
                for (int i = 0; i < arguments.length(); i++) {
                    char c = arguments.charAt(i);
                    if (c == '[') {
                        depth++;
                    } else if (c == ']') {
                        depth--;
                    } else if (depth == 0 && (c == ',' || c == ':')) {
                        return arguments.substring(i + 1).strip();
                    }
                }
            }
            return arguments;
        }

        /**
         * Appends the declared type of the variable or parameter that {@code element} refers to.
         * An unannotated declaration is first annotated through {@link #applyFixElement}. A for-loop
         * variable takes the type computed by the type evaluation context.
         */
        private boolean inferReferenceAnnotation(PsiElement element, StringBuilder stringBuilder) throws Exception {
            if (element == null) {
                throw new Exception("PsiElement is null");
            }
            PsiElement declaration = resolveOrNull(element);
            if (declaration == null) {
                throw errorAt(element, "unresolved reference");
            }
            if (findTargetOwner(declaration) instanceof PyForPart) {
                PyType type = typeEvalContext.getType((PyTypedElement) element);
                String typeName = type == null ? null : type.getName();
                if (typeName == null) {
                    throw errorAt(element, "infer for-loop variable type fail");
                }
                stringBuilder.append(typeName);
                return true;
            }
            // The annotation lives on the declaration (target or parameter), not on the reference.
            if (Until.getAnnotationValue(declaration) == null) {
                applyFixElement(declaration, pyElementGenerator);
                // The fix may replace the declaring statement (refreshAssignment), so resolve again.
                declaration = resolveOrNull(element);
            }
            String annotation = Until.getAnnotationValue(declaration);
            if (annotation == null) {
                throw errorAt(declaration != null ? declaration : element, "infer reference type fail");
            }
            stringBuilder.append(annotation);
            return true;
        }

        /** Resolves {@code element}'s reference, returning {@code null} when there is nothing to resolve. */
        private static PsiElement resolveOrNull(PsiElement element) {
            PsiReference reference = element == null ? null : element.getReference();
            return reference == null ? null : reference.resolve();
        }

        /**
         * Returns the exception that {@link Until#throwErrorWithPosition} raises for {@code element}.
         * Call sites can then write {@code throw errorAt(...)}, and the compiler sees the branch end.
         */
        private static Exception errorAt(PsiElement element, String message) {
            try {
                Until.throwErrorWithPosition(element, message);
            } catch (Exception e) {
                return e;
            }
            return new Exception(message); // not reached: throwErrorWithPosition always throws
        }
    }
}


until

import com.intellij.psi.PsiElement;
import com.jetbrains.python.psi.PyNamedParameter;
import com.jetbrains.python.psi.PyTargetExpression;

/**
 * Static helpers used by the type-inference inspection: indentation and string-insertion utilities,
 * error reporting with a source position, and annotation lookup on PSI elements.
 */
public class Until {
    /**
     * Returns the indentation for nesting {@code level}: two spaces per level, or an empty string
     * when {@code level} is not positive.
     */
    public static String getLevelBlanks(int level) {
        return " ".repeat(Math.max(0, 2 * level));
    }

    /**
     * Inserts {@code needInsertString} right after the first occurrence of
     * {@code insertBehindString} found at or after {@code bufferIndex[0]}.
     *
     * @param bufferIndex single-element in/out cursor; on success it is moved to just after the
     *                    inserted text, so repeated calls continue after the previous insertion
     * @return the new string, or {@code null} when the cursor is past the last character or
     *         {@code insertBehindString} does not occur from the cursor on
     */
    public static String getInsertedString(String oldString, String insertBehindString, String needInsertString, int[] bufferIndex) {
        if (bufferIndex[0] > oldString.length() - 1) {
            return null;
        }

        int insertIndex = oldString.indexOf(insertBehindString, bufferIndex[0]);
        if (insertIndex == -1) {
            return null;
        }
        insertIndex += insertBehindString.length();
        bufferIndex[0] = insertIndex + needInsertString.length();
        return oldString.substring(0, insertIndex) + needInsertString + oldString.substring(insertIndex);
    }

    /**
     * Always throws an {@link Exception}. Its message is {@code message} followed by the element's
     * text and its 1-based line and column in the containing file.
     *
     * @param element the element to point at; when {@code null} only {@code message} is reported
     * @throws Exception always
     */
    public static void throwErrorWithPosition(PsiElement element, String message) throws Exception {
        if (element == null) {
            // Callers pass results of resolve(), which may be null; report the message rather than
            // failing with a NullPointerException.
            throw new Exception(message);
        }
        String text = element.getContainingFile() == null ? "" : element.getContainingFile().getText();
        // Clamp so a stale offset cannot index past the end of the text.
        int startOffset = Math.min(element.getNode().getStartOffset(), text.length());
        int line = 1;
        int lineStart = 0;
        for (int i = 0; i < startOffset; i++) {
            if (text.charAt(i) == '\n') {
                line++;
                lineStart = i + 1;
            }
        }
        int column = startOffset - lineStart + 1;
        throw new Exception(message + "\nnear element: " + element.getText() + "\nline: " + line + "\ncolumn: " + column);
    }

    /**
     * Returns the annotation text of a target expression or named parameter. Returns {@code null}
     * when the element is of another kind (including {@code null}) or has no annotation.
     */
    public static String getAnnotationValue(PsiElement element) {
        if (element instanceof PyTargetExpression) {
            return ((PyTargetExpression) element).getAnnotationValue();
        }
        if (element instanceof PyNamedParameter) {
            return ((PyNamedParameter) element).getAnnotationValue();
        }
        return null;
    }
}

推断用例

https://www.cnblogs.com/wwho/p/17364367.html
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值