/*
 * Copyright 2022 the original author or authors.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * https://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.openrewrite.java.security;

import org.jetbrains.annotations.NotNull;
import org.openrewrite.*;
import org.openrewrite.internal.lang.Nullable;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.dataflow.LocalFlowSpec;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.*;

import java.util.*;

/**
 * Recipe that rewrites {@code String}-based canonical-path prefix checks such as
 * {@code dir.getCanonicalPath().startsWith(parent.getCanonicalPath())} into
 * {@code java.nio.file.Path}-based {@code startsWith} checks
 * (e.g. {@code dir.getCanonicalFile().toPath().startsWith(parent.getCanonicalFile().toPath())}).
 * Tagged CWE-22.
 */
public class PartialPathTraversalVulnerability extends Recipe {
    private static final MethodMatcher getCanonicalPathMatcher =
            new MethodMatcher("java.io.File getCanonicalPath()");
    private static final MethodMatcher startsWithMatcher =
            new MethodMatcher("java.lang.String startsWith(java.lang.String)");

    @Override
    public String getDisplayName() {
        return "Partial path traversal vulnerability";
    }

    @Override
    public String getDescription() {
        return "Replaces `dir.getCanonicalPath().startsWith(parent.getCanonicalPath())`, which is vulnerable to partial path traversal attacks, with the more secure `dir.getCanonicalFile().toPath().startsWith(parent.getCanonicalFile().toPath())`.\n\n" +
                "To demonstrate this vulnerability, consider `\"/usr/outnot\".startsWith(\"/usr/out\")`. The check is bypassed although `/usr/outnot` is not under the `/usr/out` directory. " +
                "It's important to understand that the terminating slash may be removed when using various `String` representations of the `File` object. " +
                "For example, on Linux, `println(new File(\"/var\"))` will print `/var`, but `println(new File(\"/var\", \"/\"))` will print `/var/`; " +
                "however, `println(new File(\"/var\", \"/\").getCanonicalPath())` will print `/var`.";
    }

    @Override
    public Set<String> getTags() {
        return Collections.singleton("CWE-22");
    }

    /**
     * Applicability test: queues a {@link UsesMethod} search for {@code File#getCanonicalPath()}
     * and one for {@code String#startsWith(String)} after visiting each Java source file.
     */
    @Override
    protected @Nullable TreeVisitor<?, ExecutionContext> getSingleSourceApplicableTest() {
        return new JavaVisitor<ExecutionContext>() {
            @Override
            public J visitJavaSourceFile(JavaSourceFile cu, ExecutionContext executionContext) {
                doAfterVisit(new UsesMethod<>(getCanonicalPathMatcher));
                doAfterVisit(new UsesMethod<>(startsWithMatcher));
                return cu;
            }
        };
    }

    @Override
    protected TreeVisitor<?, ExecutionContext> getVisitor() {
        return new PartialPathTraversalVulnerabilityVisitor<>();
    }

    /**
     * Visitor that finds {@code String#startsWith(String)} calls whose subject is (or flows from)
     * {@code File#getCanonicalPath()} and rewrites them into {@code Path#startsWith} calls.
     * <p>
     * State is shared between visits through messages on the enclosing {@link SourceFile} cursor:
     * <ul>
     *     <li>{@code "unsafeArgument"}: {@code List<Expression>} of {@code startsWith} arguments reached
     *     by an "unsafe" (not {@code /}-terminated) source.</li>
     *     <li>{@code "replacement"}: {@code Map<Expression, Set<Expression>>} from an identifier used as a
     *     {@code startsWith} select to the {@code File} expressions whose {@code getCanonicalPath()} flows
     *     into it.</li>
     * </ul>
     */
    private static class PartialPathTraversalVulnerabilityVisitor<P> extends JavaIsoVisitor<P> {
        /** Template: {@code <file>.getCanonicalFile().toPath()}. */
        private final JavaTemplate toPathGetCanonicalFileTemplate =
                JavaTemplate
                        .builder(this::getCursor, "#{any(java.io.File)}.getCanonicalFile().toPath()")
                        .build();

        /** Template: {@code <path>.startsWith(<path>)}. */
        private final JavaTemplate pathStartsWithPathTemplate =
                JavaTemplate
                        .builder(this::getCursor, "#{any(java.nio.file.Path)}.startsWith(#{any(java.nio.file.Path)})")
                        .build();

        /** Template: {@code <path>.startsWith(<string>)}. */
        private final JavaTemplate pathStartsWithStringTemplate =
                JavaTemplate
                        .builder(this::getCursor, "#{any(java.nio.file.Path)}.startsWith(#{any(String)})")
                        .build();

        /** Template: {@code Paths.get(<string>).normalize()}. */
        private final JavaTemplate pathCreationNormalizeTemplate
                = JavaTemplate
                .builder(this::getCursor, "Paths.get(#{any(String)}).normalize()")
                .imports("java.nio.file.Paths")
                .build();

        /**
         * Local flow from a {@code File#getCanonicalPath()} call to an identifier that is used as the
         * select of a {@code String#startsWith(String)} call.
         */
        private static final class GetCanonicalPathToStartsWithLocalFlow extends LocalFlowSpec<J.MethodInvocation, J.Identifier> {
            @Override
            public boolean isSource(J.MethodInvocation source, Cursor cursor) {
                // SOURCE: Any call to `File#getCanonicalPath()`
                return getCanonicalPathMatcher.matches(source);
            }

            @Override
            public boolean isSink(J.Identifier sink, Cursor cursor) {
                // SINK: Any J.Identifier that is the select (CodeQL: 'qualifier') of a call to `String#startsWith(String)`
                J.MethodInvocation maybeStartsWith = cursor.firstEnclosing(J.MethodInvocation.class);
                if (maybeStartsWith != null) {
                    return startsWithMatcher.matches(maybeStartsWith) && maybeStartsWith.getSelect() == sink;
                }
                return false;
            }
        }

        /**
         * Local flow from an expression that is not a {@code /}-terminated concatenation to an argument
         * of {@code String#startsWith(String)}.
         */
        private static final class NotSafePartialPathTraversalLocalFlow extends LocalFlowSpec<Expression, Expression> {

            @Override
            public boolean isSource(Expression source, Cursor cursor) {
                // SOURCE: Any expression that isn't terminated by the `/` character
                return isSourceFilter(source, cursor);
            }

            @Override
            public boolean isSink(Expression sink, Cursor cursor) {
                // SINK: method argument for the method `String#startsWith`
                J.MethodInvocation maybeStartsWith = cursor.firstEnclosing(J.MethodInvocation.class);
                if (maybeStartsWith != null) {
                    return startsWithMatcher.matches(maybeStartsWith) && maybeStartsWith.getArguments().contains(sink);
                }
                return false;
            }

            /**
             * Decides whether {@code source} counts as an unsafe source.
             *
             * @return {@code false} inside imports, for the {@code null} literal, for {@code /}-terminated
             * concatenations, and for identifiers, assignments, primitives and empty expressions;
             * {@code true} otherwise (including every non-null literal).
             */
            static boolean isSourceFilter(Expression source, Cursor cursor) {
                if (cursor.firstEnclosing(J.Import.class) != null) {
                    return false;
                } else if (source instanceof J.Literal) {
                    // Ignore the literal 'null' as a source as the value will probably be reassigned
                    return source.getType() != JavaType.Primitive.Null;
                }
                // Any other expression is a source unless it is a `/`-terminated concatenation
                // or one of the expression kinds excluded below.
                return !isSafePartialPathExpression(source) && !(
                        source instanceof J.Identifier ||
                                source instanceof J.Assignment ||
                                source instanceof J.AssignmentOperation ||
                                source instanceof J.Primitive ||
                                source instanceof J.Empty
                );
            }
        }

        /**
         * For every expression that qualifies as an unsafe source, records the {@code startsWith}
         * arguments it flows into under the {@code "unsafeArgument"} message of the source-file cursor.
         */
        @Override
        public Expression visitExpression(Expression expression, P p) {
            if (NotSafePartialPathTraversalLocalFlow.isSourceFilter(expression, getCursor())) {
                // CASE: Find any expression that is considered an 'unsafe' source
                dataflow().findSinks(new NotSafePartialPathTraversalLocalFlow()).ifPresent(sinks -> {
                    if (!sinks.isEmpty()) {
                        Cursor parentCursor = getCursor().dropParentUntil(SourceFile.class::isInstance);
                        // Get (or create) the unsafeArgument list shared across this source file
                        List<Expression> unsafeArgument =
                                parentCursor
                                        .computeMessageIfAbsent(
                                                "unsafeArgument",
                                                v -> new ArrayList<>(sinks.getSinks().size())
                                        );
                        // Add this set of sinks to it
                        unsafeArgument.addAll(sinks.getSinks());
                        parentCursor
                                .putMessage("unsafeArgument", unsafeArgument);
                    }
                });
            }
            return super.visitExpression(expression, p);
        }

        /**
         * Records {@code getCanonicalPath()} flows and rewrites qualifying {@code startsWith} calls.
         * A rewritten {@code startsWith} call is returned directly without visiting its children.
         */
        @Override
        public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P p) {
            Cursor parentCursor = getCursor().dropParentUntil(SourceFile.class::isInstance);
            if (getCanonicalPathMatcher.matches(method)) {
                // CASE: ...getCanonicalPath();
                visitGetCanonicalPathMethodInvocation(parentCursor);
            } else if (startsWithMatcher.matches(method)) {
                // CASE: ...startsWith(...);
                J.MethodInvocation newStartsWithMethod =
                        visitStartsWithMethodInvocation(method, parentCursor);
                if (newStartsWithMethod != null) {
                    return newStartsWithMethod;
                }
            }
            return super.visitMethodInvocation(method, p);
        }

        /**
         * Records, under the {@code "replacement"} message, each identifier that the current
         * {@code getCanonicalPath()} call flows into as a {@code startsWith} select, mapped to the
         * select (the {@code File} expression) of that {@code getCanonicalPath()} call.
         */
        private void visitGetCanonicalPathMethodInvocation(Cursor parentCursor) {
            dataflow().findSinks(new GetCanonicalPathToStartsWithLocalFlow()).ifPresent(sinkFlow -> {
                Map<Expression, Set<Expression>> replacement =
                        parentCursor.computeMessageIfAbsent("replacement", k -> new HashMap<>());
                sinkFlow.getSinks().forEach(sink ->
                        // MAP of:
                        // String k = VALUE /** VALUE **/.getCanonicalPath();
                        // k /** KEY **/.startsWith(...);
                        replacement
                                .computeIfAbsent(sink, k -> new HashSet<>())
                                .add(sinkFlow.getSource().getSelect())
                );
                parentCursor.putMessage("replacement", replacement);
            });
        }

        /**
         * Builds the {@code Path}-based replacement for a {@code String#startsWith(String)} call, if
         * its argument is not known to be safe and its select can be traced back to a {@code File}.
         *
         * @return the replacement invocation, or {@code null} when the call should be left unchanged
         */
        @Nullable
        private J.MethodInvocation visitStartsWithMethodInvocation(J.MethodInvocation method, Cursor parentCursor) {
            assert method.getSelect() != null : "Select is null for `startsWith`";
            final Expression select = unwrap(method.getSelect());
            final Expression rawArgument = method.getArguments().get(0);
            final Expression argument = unwrap(rawArgument);
            List<Expression> unsafeArguments = parentCursor.getMessage("unsafeArgument");
            if (unsafeArguments == null) {
                unsafeArguments = Collections.emptyList();
            }

            // Sinks are recorded exactly as they appear in the argument list (possibly parenthesized),
            // so look up both the raw argument and its unwrapped form.
            if (unsafeArguments.contains(rawArgument) ||
                    unsafeArguments.contains(argument) ||
                    !(isSafePartialPathExpression(argument) || argument instanceof J.Identifier)) {
                // CASE: startsWith is passed an argument or variable represented by an argument
                // that is not terminated by a `/`
                if (getCanonicalPathMatcher.matches(select)) {
                    // CASE: getCanonicalPath().startsWith(...)
                    final J.MethodInvocation getCanonicalPathSelect = (J.MethodInvocation) select;
                    final J.MethodInvocation getCanonicalPathSelectReplacement =
                            replaceGetCanonicalPath(getCanonicalPathSelect);
                    return replaceWithPathStartsWithMethodInvocation(
                            method,
                            argument,
                            getCanonicalPathSelectReplacement
                    );
                } else {
                    @SuppressWarnings("unchecked")
                    Map<Expression, Set<Expression>> replacementMap =
                            parentCursor.getMessage("replacement");
                    if (replacementMap != null) {
                        // CASE: canonicalPath.startsWith(...)
                        final Set<Expression> alternateSelects = replacementMap.get(select);
                        if (alternateSelects != null) {
                            if (alternateSelects.size() == 1) {
                                Expression alternateSelect = alternateSelects.iterator().next();
                                // Template the unwrapped identifier itself so the returned tree is the
                                // replacement invocation even when the select is parenthesized.
                                J.MethodInvocation getCanonicalPathSelectReplacement = select.withTemplate(
                                        toPathGetCanonicalFileTemplate,
                                        ((J.Identifier) select).getCoordinates().replace(),
                                        alternateSelect
                                );
                                return replaceWithPathStartsWithMethodInvocation(
                                        method,
                                        argument,
                                        getCanonicalPathSelectReplacement
                                );
                            } else if (!alternateSelects.isEmpty()) {
                                // There are multiple possible alternative selects, so no single `File`
                                // can be substituted. Instead, convert the canonical path string itself
                                // with `Paths.get(...).normalize()` and compare on that `Path`.
                                maybeAddImport("java.nio.file.Paths");
                                J.MethodInvocation newSelect = select.withTemplate(
                                        pathCreationNormalizeTemplate,
                                        ((J.Identifier) select).getCoordinates().replace(),
                                        select
                                );
                                return replaceWithPathStartsWithMethodInvocation(
                                        method,
                                        argument,
                                        newSelect
                                );
                            }
                        }
                    }
                }
            }
            return null;
        }

        /**
         * Returns {@code true} when {@code expression} is a string concatenation ({@code ... + x})
         * whose right-hand operand is {@code java.io.File.separator}/{@code separatorChar}
         * (qualified or statically imported), the literal {@code "/"}, or the literal {@code '/'}.
         */
        private static boolean isSafePartialPathExpression(Expression expression) {
            if (expression instanceof J.Binary) {
                J.Binary concatArgument = (J.Binary) expression;
                if (J.Binary.Type.Addition.equals(concatArgument.getOperator())) {
                    // CASE: ...startsWith(... + ...);
                    Expression right = unwrap(concatArgument.getRight());
                    if (right instanceof J.FieldAccess || right instanceof J.Identifier) {
                        // CASE:
                        // - ...startsWith(... + File.separator);
                        // - ...startsWith(... + File.separatorChar);
                        // - ...startsWith(... + separator);
                        // - ...startsWith(... + separatorChar);
                        J.Identifier nameIdentifier;
                        JavaType type;
                        if (right instanceof J.FieldAccess) {
                            // CASE:
                            // - ...startsWith(... + File.separator);
                            // - ...startsWith(... + File.separatorChar);
                            nameIdentifier = ((J.FieldAccess) right).getName();
                            type = ((J.FieldAccess) right).getTarget().getType();
                        } else {
                            // CASE:
                            // - ...startsWith(... + separator); statically imported from java.io.File
                            // - ...startsWith(... + separatorChar); statically imported from java.io.File
                            if (((J.Identifier) right).getFieldType() == null) {
                                return false;
                            }
                            nameIdentifier = (J.Identifier) right;
                            type = ((J.Identifier) right).getFieldType().getOwner();
                        }
                        final String name = nameIdentifier.getSimpleName();
                        return TypeUtils.isOfClassType(type, "java.io.File")
                                && (name.equals("separator") || name.equals("separatorChar"));
                    } else if (right instanceof J.Literal) {
                        // CASE:
                        // - ...startsWith(... + "/");
                        // - ...startsWith(... + '/');
                        J.Literal literal = (J.Literal) right;
                        if (literal.getValue() instanceof String) {
                            String value = (String) literal.getValue();
                            // CASE: ...startsWith(... + "/");
                            return value.equals("/");
                        } else if (literal.getValue() instanceof Character) {
                            Character value = (Character) literal.getValue();
                            // CASE: ...startsWith(... + '/');
                            return value.equals('/');
                        }
                    }
                }
            }
            return false;
        }

        /**
         * Replaces the {@link String#startsWith(String)} call with a call to
         * {@link java.nio.file.Path#startsWith(java.nio.file.Path)} or
         * {@link java.nio.file.Path#startsWith(String)}.
         *
         * @param method                             the original {@code String#startsWith} call
         * @param argument                           its (unwrapped) argument
         * @param getCanonicalPathSubjectReplacement the {@code Path}-typed expression to use as the new subject
         */
        @NotNull
        private J.MethodInvocation replaceWithPathStartsWithMethodInvocation(
                J.MethodInvocation method,
                Expression argument,
                J.MethodInvocation getCanonicalPathSubjectReplacement
        ) {
            if (getCanonicalPathMatcher.matches(argument)) {
                // CASE: ...startsWith(...getCanonicalPath())
                final J.MethodInvocation getCanonicalPathArgument = (J.MethodInvocation) argument;
                final J.MethodInvocation getCanonicalPathArgumentReplacement =
                        replaceGetCanonicalPath(getCanonicalPathArgument);
                return method
                        .withTemplate(
                                pathStartsWithPathTemplate,
                                method.getCoordinates().replace(),
                                getCanonicalPathSubjectReplacement,
                                getCanonicalPathArgumentReplacement
                        );
            } else {
                // CASE: ...startsWith(...)
                return method
                        .withTemplate(
                                pathStartsWithStringTemplate,
                                method.getCoordinates().replace(),
                                getCanonicalPathSubjectReplacement,
                                argument
                        );
            }
        }

        /** Strips any number of enclosing parentheses from {@code expression}. */
        private static Expression unwrap(Expression expression) {
            if (expression instanceof J.Parentheses) {
                //noinspection unchecked
                return unwrap(((J.Parentheses<Expression>) expression).getTree());
            } else {
                return expression;
            }
        }

        /** Rewrites {@code f.getCanonicalPath()} into {@code f.getCanonicalFile().toPath()}. */
        private J.MethodInvocation replaceGetCanonicalPath(J.MethodInvocation getCanonicalPath) {
            return getCanonicalPath
                    .withTemplate(
                            toPathGetCanonicalFileTemplate,
                            getCanonicalPath.getCoordinates().replace(),
                            getCanonicalPath.getSelect()
                    );
        }
    }
}
