/*
 * Copyright 2021 the original author or authors.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * https://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.openrewrite.java.security;

import org.openrewrite.*;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesType;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaCoordinates;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.java.tree.TypeUtils;

/**
 * Recipe that hardens {@code javax.xml.stream.XMLInputFactory} usages against XML external entity (XXE)
 * attacks. It adds {@code setProperty(..., false)} calls for {@code IS_SUPPORTING_EXTERNAL_ENTITIES} and
 * {@code SUPPORT_DTD} when a class does not already call {@code setProperty} for them.
 * <p>
 * The main visitor records the following per class declaration, using cursor messages:
 * <ul>
 *   <li>the block where the factory is created;</li>
 *   <li>the block where each of the two properties is already set;</li>
 *   <li>the name of a variable of type {@code XMLInputFactory}.</li>
 * </ul>
 * After the class has been visited, it schedules {@link XmlFactoryInsertPropertyStatementVisitor} to
 * insert whichever calls are missing.
 */
public class XmlParserXXEVulnerability extends Recipe {
    /** Matches the no-argument {@code XMLInputFactory.new*()} factory methods. */
    private static final MethodMatcher XML_PARSER_FACTORY_INSTANCE = new MethodMatcher("javax.xml.stream.XMLInputFactory new*()");
    /** Matches {@code XMLInputFactory.setProperty(String, ...)}. */
    private static final MethodMatcher XML_PARSER_FACTORY_SET_PROPERTY = new MethodMatcher("javax.xml.stream.XMLInputFactory setProperty(java.lang.String, ..)");

    private static final String XML_FACTORY_FQN = "javax.xml.stream.XMLInputFactory";
    // Simple names of the XMLInputFactory constants; they double as cursor-message keys.
    private static final String SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME = "IS_SUPPORTING_EXTERNAL_ENTITIES";
    private static final String SUPPORT_DTD_PROPERTY_NAME = "SUPPORT_DTD";
    // Cursor-message keys used to pass findings up to the enclosing class declaration.
    private static final String XML_PARSER_INITIALIZATION_METHOD = "xml-parser-initialization-method";
    private static final String XML_FACTORY_VARIABLE_NAME = "xml-factory-variable-name";

    @Override
    public String getDisplayName() {
        return "XML parser XXE vulnerability";
    }

    @Override
    public String getDescription() {
        return "Avoid exposing dangerous features of the XML parser by setting XMLInputFactory `IS_SUPPORTING_EXTERNAL_ENTITIES` and `SUPPORT_DTD` properties to `false`.";
    }

    /**
     * Returns a visitor that runs only on sources using {@code XMLInputFactory}. The visitor collects
     * factory creation and {@code setProperty} sites per class, then schedules the insertion of the
     * missing property calls.
     */
    @Override
    public TreeVisitor<?, ExecutionContext> getVisitor() {
        return Preconditions.check(new UsesType<>(XML_FACTORY_FQN, true), new JavaIsoVisitor<ExecutionContext>() {
            /**
             * After the class body has been visited, decides which properties are missing and which block
             * should receive them, then schedules the insertion.
             */
            @Override
            public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) {
                J.ClassDeclaration cd = super.visitClassDeclaration(classDecl, ctx);
                // These messages were put on this class's cursor by visitVariable / visitMethodInvocation.
                Cursor supportsExternalCursor = getCursor().getMessage(SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME);
                Cursor supportsDTDCursor = getCursor().getMessage(SUPPORT_DTD_PROPERTY_NAME);
                Cursor initializationCursor = getCursor().getMessage(XML_PARSER_INITIALIZATION_METHOD);
                String xmlFactoryVariableName = getCursor().getMessage(XML_FACTORY_VARIABLE_NAME);

                boolean needsExternalEntitiesDisabled = supportsExternalCursor == null;
                boolean needsSupportDtdDisabled = supportsDTDCursor == null;

                Cursor setPropertyBlockCursor;
                if (needsExternalEntitiesDisabled && needsSupportDtdDisabled) {
                    // Neither property is set: add both where the factory is created (null if not seen).
                    setPropertyBlockCursor = initializationCursor;
                } else if (needsExternalEntitiesDisabled || needsSupportDtdDisabled) {
                    // Exactly one property is set: add the missing one in the block of the existing call.
                    setPropertyBlockCursor = needsExternalEntitiesDisabled ? supportsDTDCursor : supportsExternalCursor;
                } else {
                    // Both properties are already set: nothing to insert.
                    setPropertyBlockCursor = null;
                }

                if (setPropertyBlockCursor != null && xmlFactoryVariableName != null) {
                    doAfterVisit(new XmlFactoryInsertPropertyStatementVisitor(
                            setPropertyBlockCursor.getValue(),
                            xmlFactoryVariableName,
                            needsExternalEntitiesDisabled,
                            needsSupportDtdDisabled));
                }
                return cd;
            }

            /** Records, on the enclosing class, the name of any variable typed as {@code XMLInputFactory}. */
            @Override
            public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, ExecutionContext ctx) {
                J.VariableDeclarations.NamedVariable v = super.visitVariable(variable, ctx);
                if (TypeUtils.isOfClassType(v.getType(), XML_FACTORY_FQN)) {
                    // NOTE(review): with several such variables, presumably the last one visited is used.
                    getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, XML_FACTORY_VARIABLE_NAME, v.getSimpleName());
                }
                return v;
            }

            /**
             * Records the block containing either a factory creation or a {@code setProperty} call for one
             * of the two hardened properties.
             */
            @Override
            public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
                J.MethodInvocation m = super.visitMethodInvocation(method, ctx);
                if (XML_PARSER_FACTORY_INSTANCE.matches(m)) {
                    recordEnclosingBlock(XML_PARSER_INITIALIZATION_METHOD);
                } else if (XML_PARSER_FACTORY_SET_PROPERTY.matches(m) && m.getArguments().get(0) instanceof J.FieldAccess) {
                    // NOTE(review): only `Qualifier.PROPERTY` field accesses are recognized here. Static-imported
                    // identifiers and string literals are not recognized. The value argument is not inspected,
                    // so setProperty(..., true) is treated as already handled.
                    String propertyName = ((J.FieldAccess) m.getArguments().get(0)).getSimpleName();
                    if (SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME.equals(propertyName)) {
                        recordEnclosingBlock(SUPPORTING_EXTERNAL_ENTITIES_PROPERTY_NAME);
                    } else if (SUPPORT_DTD_PROPERTY_NAME.equals(propertyName)) {
                        recordEnclosingBlock(SUPPORT_DTD_PROPERTY_NAME);
                    }
                }
                return m;
            }

            /** Stores the cursor of the nearest enclosing block on the enclosing class under {@code key}. */
            private void recordEnclosingBlock(String key) {
                getCursor().putMessageOnFirstEnclosing(J.ClassDeclaration.class, key, getCursor().dropParentUntil(J.Block.class::isInstance));
            }
        });
    }

    /**
     * Inserts the missing {@code setProperty(..., false)} statements into one target block, identified
     * by {@code isScope}.
     */
    private static final class XmlFactoryInsertPropertyStatementVisitor extends JavaIsoVisitor<ExecutionContext> {
        /** The block that receives the new statements. */
        private final J.Block scope;
        /** JavaTemplate source for the statements to insert; built once and never mutated. */
        private final String propertyTemplate;

        /**
         * @param scope                         block into which the statements are inserted
         * @param factoryVariableName           name of the {@code XMLInputFactory} variable to configure
         * @param needsExternalEntitiesDisabled whether to emit the {@code IS_SUPPORTING_EXTERNAL_ENTITIES} call
         * @param needsSupportsDtdDisabled      whether to emit the {@code SUPPORT_DTD} call
         */
        public XmlFactoryInsertPropertyStatementVisitor(J.Block scope, String factoryVariableName, boolean needsExternalEntitiesDisabled, boolean needsSupportsDtdDisabled) {
            this.scope = scope;
            StringBuilder template = new StringBuilder();
            if (needsExternalEntitiesDisabled) {
                template.append(factoryVariableName).append(".setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, false);");
            }
            if (needsSupportsDtdDisabled) {
                template.append(factoryVariableName).append(".setProperty(XMLInputFactory.SUPPORT_DTD, false);");
            }
            this.propertyTemplate = template.toString();
        }

        @Override
        public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
            J.Block b = super.visitBlock(block, ctx);
            if (!b.isScope(scope)) {
                return b;
            }

            // Insert right after the first statement that creates the factory or calls setProperty on it.
            // If there is no such statement, or it is the last one, fall back to lastStatement() coordinates.
            Statement beforeStatement = null;
            for (int i = 0; i < b.getStatements().size() - 1; i++) {
                if (isFactoryStatement(b.getStatements().get(i))) {
                    beforeStatement = b.getStatements().get(i + 1);
                    break;
                }
            }
            JavaCoordinates insertCoordinates = beforeStatement != null ?
                    beforeStatement.getCoordinates().before() :
                    b.getCoordinates().lastStatement();

            // A block whose parent is a class declaration is the class body, which cannot hold bare
            // statements. Wrap them in an initializer block; the template text is computed locally so
            // the field is never mutated.
            Cursor parent = getCursor().getParent();
            String template = parent != null && parent.getValue() instanceof J.ClassDeclaration ?
                    "{\n" + propertyTemplate + "}" :
                    propertyTemplate;

            b = JavaTemplate.builder(template).contextSensitive().build().apply(
                    new Cursor(parent, b), insertCoordinates);
            return b;
        }

        /**
         * Returns whether {@code statement} is one of the following:
         * <ul>
         *   <li>a factory creation or {@code setProperty} call used as a statement;</li>
         *   <li>a variable declaration in which any variable is initialized by a factory creation.</li>
         * </ul>
         */
        private static boolean isFactoryStatement(Statement statement) {
            if (statement instanceof J.MethodInvocation) {
                J.MethodInvocation m = (J.MethodInvocation) statement;
                return XML_PARSER_FACTORY_INSTANCE.matches(m) || XML_PARSER_FACTORY_SET_PROPERTY.matches(m);
            }
            if (statement instanceof J.VariableDeclarations) {
                for (J.VariableDeclarations.NamedVariable variable : ((J.VariableDeclarations) statement).getVariables()) {
                    if (variable.getInitializer() instanceof J.MethodInvocation &&
                        XML_PARSER_FACTORY_INSTANCE.matches((J.MethodInvocation) variable.getInitializer())) {
                        return true;
                    }
                }
            }
            return false;
        }
    }
}
