/*
 * Copyright 2022 the original author or authors.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * https://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.openrewrite.java.security;

import lombok.AllArgsConstructor;
import lombok.Value;
import org.openrewrite.*;
import org.openrewrite.analysis.InvocationMatcher;
import org.openrewrite.analysis.dataflow.Dataflow;
import org.openrewrite.analysis.dataflow.LocalFlowSpec;
import org.openrewrite.internal.lang.Nullable;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesAllMethods;
import org.openrewrite.java.security.internal.CursorUtil;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.staticanalysis.RemoveUnusedLocalVariables;

import java.time.Duration;
import java.util.*;

import static org.openrewrite.java.security.internal.FileSeparatorUtil.isFileSeparatorExpression;

/**
 * Recipe that rewrites {@code String#startsWith(String)} checks performed on the result of
 * {@link java.io.File#getCanonicalPath()}. These checks are vulnerable to partial path traversal (CWE-22).
 * <p>
 * The string comparison is replaced with a {@link java.nio.file.Path#startsWith} comparison. The receiver is
 * rebuilt either as {@code getCanonicalFile().toPath()} or, when the origin of the string cannot be tied to a
 * single {@code getCanonicalPath()} call in the same {@code try} scope, as {@code Paths.get(...).normalize()}.
 */
public class PartialPathTraversalVulnerability extends Recipe {
    private static final MethodMatcher getCanonicalPathMatcher = new MethodMatcher("java.io.File getCanonicalPath()");
    private static final MethodMatcher startsWithMatcher = new MethodMatcher("java.lang.String startsWith(java.lang.String)");

    @Override
    public String getDisplayName() {
        return "Partial path traversal vulnerability";
    }

    @Override
    public String getDescription() {
        return "Replaces `dir.getCanonicalPath().startsWith(parent.getCanonicalPath())`, which is vulnerable to partial path traversal attacks, " +
               "with the more secure `dir.getCanonicalFile().toPath().startsWith(parent.getCanonicalFile().toPath())`.\n\n" +
               "To demonstrate this vulnerability, consider `\"/usr/outnot\".startsWith(\"/usr/out\")`. " +
               "The check is bypassed although `/usr/outnot` is not under the `/usr/out` directory. " +
               "It's important to understand that the terminating slash may be removed when using various `String` representations of the `File` object. " +
               "For example, on Linux, `println(new File(\"/var\"))` will print `/var`, but `println(new File(\"/var\", \"/\"))` will print `/var/`; " +
               "however, `println(new File(\"/var\", \"/\").getCanonicalPath())` will print `/var`.";
    }

    @Override
    public Set<String> getTags() {
        return Collections.singleton("CWE-22");
    }

    @Override
    public Duration getEstimatedEffortPerOccurrence() {
        return Duration.ofMinutes(10);
    }

    /**
     * Only visits compilation units that call both {@code File#getCanonicalPath()} and {@code String#startsWith(String)}.
     * If the Zip Slip fix would modify the unit, this recipe leaves it untouched, because Zip Slip is the root cause.
     */
    @Override
    public TreeVisitor<?, ExecutionContext> getVisitor() {
        return Preconditions.check(new UsesAllMethods<>(getCanonicalPathMatcher, startsWithMatcher), new JavaIsoVisitor<ExecutionContext>() {
            @Override
            public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
                // Dry-run the Zip Slip fix; its result is only used to detect whether it would change anything.
                J.CompilationUnit compilationUnit = (J.CompilationUnit) new ZipSlip.ZipSlipComplete<>(false, false).visitNonNull(cu, ctx, getCursor().getParentOrThrow());
                if (cu != compilationUnit) {
                    // The root cause of this vulnerability is Zip Slip, so don't run partial Path
                    return cu;
                }
                return (J.CompilationUnit) new PartialPathTraversalVulnerabilityVisitor<ExecutionContext>().visitNonNull(cu, ctx, getCursor().getParentOrThrow());
            }
        });
    }

    /**
     * Visitor that performs the actual {@code String#startsWith} to {@code Path#startsWith} rewrite.
     *
     * @param <P> visitor parameter. {@link #visitBlock} casts it to {@link ExecutionContext}, so in practice it must be one.
     */
    static class PartialPathTraversalVulnerabilityVisitor<P> extends JavaIsoVisitor<P> {
        private final JavaTemplate toPathGetCanonicalFileTemplate = JavaTemplate.builder("#{any(java.io.File)}.getCanonicalFile().toPath()").build();
        private final JavaTemplate pathStartsWithPathTemplate = JavaTemplate.builder("#{any(java.nio.file.Path)}.startsWith(#{any(java.nio.file.Path)})").build();

        private final JavaTemplate pathStartsWithStringTemplate = JavaTemplate.builder("#{any(java.nio.file.Path)}.startsWith(#{any(String)})").build();

        private final JavaTemplate pathCreationNormalizeTemplate = JavaTemplate.builder("Paths.get(#{any(String)}).normalize()").imports("java.nio.file.Paths").build();

        /**
         * After a block's children have been visited, runs {@link RemoveUnusedLocalVariables} on the block if
         * anything in it changed. A rewrite can leave a local unused, for example one that only held the old
         * {@code String} path.
         */
        @Override
        public J.Block visitBlock(J.Block block, P p) {
            J.Block b = super.visitBlock(block, p);
            if (b == block) {
                return b;
            }
            // NOTE(review): unchecked cast. This assumes the visitor is always driven with an ExecutionContext.
            b = (J.Block) new RemoveUnusedLocalVariables(new String[0]).getVisitor().visitNonNull(b, (ExecutionContext) p, getCursor().getParentOrThrow());
            return b;
        }

        @Override
        public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P p) {
            if (startsWithMatcher.matches(method)) {
                // CASE: ...startsWith(...);
                J.MethodInvocation newStartsWithMethod = visitStartsWithMethodInvocation(getCursor());
                if (newStartsWithMethod != null) {
                    return newStartsWithMethod;
                }
            }
            return super.visitMethodInvocation(method, p);
        }

        /**
         * Attempts to rewrite a single {@code String#startsWith(String)} invocation.
         *
         * @param methodCursor cursor pointing at the {@code startsWith} invocation
         * @return the replacement tree, or {@code null} when the invocation is considered safe or cannot be rewritten
         */
        private J.@Nullable MethodInvocation visitStartsWithMethodInvocation(Cursor methodCursor) {
            J.MethodInvocation method = methodCursor.getValue();
            assert method.getSelect() != null : "Select is null for `startsWith`";
            Expression select = Expression.unwrap(method.getSelect());
            Expression argument = Expression.unwrap(method.getArguments().get(0));
            Cursor argumentCursor = new Cursor(methodCursor, argument);

            if (isSafePartialPathExpression(argument) || argument instanceof J.Identifier) {
                // `computeUnsafeArguments()` is potentially expensive, only compute it if needed
                if (!computeUnsafeArguments().contains(argument)) {
                    return null;
                }
            }

            // CASE: startsWith is passed an argument or variable represented by an argument
            // that is not terminated by a `/`
            if (getCanonicalPathMatcher.matches(select)) {
                // CASE: getCanonicalPath().startsWith(...)
                J.MethodInvocation getCanonicalPathSelectReplacement = replaceGetCanonicalPath(new Cursor(methodCursor, select));
                return replaceWithPathStartsWithMethodInvocation(methodCursor, argumentCursor, getCanonicalPathSelectReplacement);
            } else {
                // Compute a set of potential alternative select statements
                Set<ExpressionWithTry> alternateSelects = computeAlternateSelects(select);
                if (alternateSelects.size() == 1) {
                    ExpressionWithTry expressionWithTry = alternateSelects.iterator().next();
                    // If both the select statements share the same outer `try` block, then we can make a more
                    // intelligent replacement
                    if (expressionWithTry.maybeTryStatement == findNearestRelevantTry(methodCursor)) {
                        Expression alternateSelect = expressionWithTry.expression;
                        // `select` is not necessarily a J.Identifier (e.g. a ternary); every Expression has replace coordinates.
                        J.MethodInvocation getCanonicalPathSelectReplacement = toPathGetCanonicalFileTemplate.apply(new Cursor(methodCursor, select), select.getCoordinates().replace(), alternateSelect);
                        return replaceWithPathStartsWithMethodInvocation(methodCursor, argumentCursor, getCanonicalPathSelectReplacement);
                    }
                    // Otherwise, we can't make a more intelligent replacement, fall back to the default
                }
                if (!alternateSelects.isEmpty()) {
                    // There are multiple possible alternative selects,
                    // or the alternative select does not share the same outer `try` block.
                    // This is a super complicated case.
                    // The best solution is to simply wrap the subject in `Paths.get(...).normalize()` and use that.
                    maybeAddImport("java.nio.file.Paths");
                    // With several candidate sources the select is often a ternary or other non-identifier
                    // expression, so don't cast it to J.Identifier.
                    J.MethodInvocation newSelect = pathCreationNormalizeTemplate.apply(new Cursor(methodCursor, select), select.getCoordinates().replace(), method.getSelect());
                    return replaceWithPathStartsWithMethodInvocation(methodCursor, argumentCursor, newSelect);
                }
            }
            return null;
        }

        /**
         * Dataflow spec from any {@code File#getCanonicalPath()} call to one specific {@code startsWith} receiver.
         */
        @AllArgsConstructor
        private static final class GetCanonicalPathToStartsWithLocalFlow extends LocalFlowSpec<J.MethodInvocation, Expression> {
            Expression currentStartsWithSelect;

            @Override
            public boolean isSource(J.MethodInvocation source, Cursor cursor) {
                // SOURCE: Any call to `File#getCanonicalPath()`
                return getCanonicalPathMatcher.matches(source);
            }

            @Override
            public boolean isSink(Expression sink, Cursor cursor) {
                // SINK: the exact (identity-compared) select (CodeQL: 'qualifier') of the `String#startsWith(String)` call being rewritten
                return currentStartsWithSelect == sink;
            }
        }

        /**
         * A candidate {@code File} expression paired with the nearest relevant enclosing {@code try}, if any.
         */
        @Value
        static class ExpressionWithTry {
            Expression expression;
            @Nullable J.Try maybeTryStatement;
        }

        /**
         * Walks up from {@code startCursor} and returns the first enclosing {@link J.Try}.
         * The walk stops (returning {@code null}) at a method declaration or a static/initializer block,
         * so a {@code try} outside the current executable scope is never returned.
         */
        @Nullable
        private static J.Try findNearestRelevantTry(Cursor startCursor) {
            for (Cursor cursor : (Iterable<Cursor>) startCursor::getPathAsCursors) {
                Object cursorValue = cursor.getValue();
                if (cursorValue instanceof J.Try) {
                    return (J.Try) cursorValue;
                }
                if (cursorValue instanceof J.MethodDeclaration) {
                    return null;
                }
                if (cursorValue instanceof J.Block && J.Block.isStaticOrInitBlock(cursor)) {
                    return null;
                }
            }
            return null;
        }

        /**
         * Finds every {@code getCanonicalPath()} call in the outer executable block whose value flows into
         * {@code currentSelect}. Each call is recorded as its {@code File} receiver together with its nearest
         * relevant {@code try}.
         *
         * @return the candidate receivers, or an empty set when no outer executable block is found
         */
        private Set<ExpressionWithTry> computeAlternateSelects(Expression currentSelect) {
            // Start visiting as high as possible.
            return CursorUtil.findOuterExecutableBlock(getCursor()).map(outerExecutable -> {
                Set<ExpressionWithTry> alternateSelects = new HashSet<>();
                new JavaIsoVisitor<Set<ExpressionWithTry>>() {
                    @Override
                    public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Set<ExpressionWithTry> alternateSelectsInner) {
                        if (getCanonicalPathMatcher.matches(method)) {
                            Dataflow.startingAt(getCursor()).findSinks(new GetCanonicalPathToStartsWithLocalFlow(currentSelect)).ifPresent(sinkFlow -> {
                                J.Try maybeOuterTryStatement = findNearestRelevantTry(getCursor());
                                alternateSelectsInner.add(new ExpressionWithTry(sinkFlow.getSource().getSelect(), maybeOuterTryStatement));
                            });
                        }
                        return super.visitMethodInvocation(method, alternateSelectsInner);
                    }
                }.visit(outerExecutable.getValue(), alternateSelects, outerExecutable.getParentOrThrow());
                return alternateSelects;
            }).orElse(Collections.emptySet());
        }

        /**
         * Dataflow spec from any expression that is not known to end in a file separator to any
         * {@code String#startsWith(String)} argument.
         */
        private static final class NotSafePartialPathTraversalLocalFlow extends LocalFlowSpec<Expression, Expression> {

            @Override
            public boolean isSource(Expression source, Cursor cursor) {
                // SOURCE: Any expression that isn't terminated by the `/` character
                return isSourceFilter(source, cursor);
            }

            @Override
            public boolean isSink(Expression sink, Cursor cursor) {
                // SINK: method argument for the method `String#startsWith`
                return InvocationMatcher.from(startsWithMatcher).advanced().isAnyArgument(cursor);
            }

            /**
             * Source predicate, shared with {@link #computeUnsafeArguments()}. It excludes:
             * <ul>
             *   <li>anything inside an import;</li>
             *   <li>the {@code null} literal;</li>
             *   <li>concatenations ending in a file separator;</li>
             *   <li>identifiers, assignments, primitives and empty nodes.</li>
             * </ul>
             */
            static boolean isSourceFilter(Expression source, Cursor cursor) {
                if (cursor.firstEnclosing(J.Import.class) != null) {
                    return false;
                } else if (source instanceof J.Literal) {
                    // Ignore the literal 'null' as a source as the value will probably be reassigned
                    return source.getType() != JavaType.Primitive.Null;
                }
                // A source is any add expression that does not have a `/` appended at the end of it
                return !isSafePartialPathExpression(source) && !(source instanceof J.Identifier || source instanceof J.Assignment || source instanceof J.AssignmentOperation || source instanceof J.Primitive || source instanceof J.Empty);
            }
        }

        /**
         * Collects every {@code startsWith} argument reached by an unsafe source within the outer executable block.
         * The result is cached as a message on that block's cursor, so the scan runs at most once per cursor.
         * <p>
         * Warning: This method is potentially expensive on first call.
         */
        private List<Expression> computeUnsafeArguments() {
            // Start visiting as high as possible.
            return CursorUtil.findOuterExecutableBlock(getCursor())
                    .map(outerExecutable -> outerExecutable.computeMessageIfAbsent("EXPENSIVE_COMPUTE_UNSAFE_ARGUMENTS", k -> {
                        List<Expression> unsafeArguments = new ArrayList<>();
                        new JavaIsoVisitor<List<Expression>>() {
                            @Override
                            public Expression visitExpression(Expression expression, List<Expression> unsafeArgumentsInner) {
                                if (NotSafePartialPathTraversalLocalFlow.isSourceFilter(expression, getCursor())) {
                                    // CASE: Find any expression that is considered an 'unsafe' source
                                    Dataflow.startingAt(getCursor()).findSinks(new NotSafePartialPathTraversalLocalFlow()).ifPresent(sinks -> {
                                        if (!sinks.isEmpty()) {
                                            // Add this set of sinks to it
                                            unsafeArgumentsInner.addAll(sinks.getSinks());
                                        }
                                    });
                                }
                                return super.visitExpression(expression, unsafeArgumentsInner);
                            }
                        }.visit(outerExecutable.getValue(), unsafeArguments, outerExecutable.getParentOrThrow());
                        return unsafeArguments;
                    }))
                    .orElse(Collections.emptyList());
        }

        /**
         * Replaces the {@link String#startsWith(String)} call with a call to
         * {@link java.nio.file.Path#startsWith(java.nio.file.Path)} or
         * {@link java.nio.file.Path#startsWith(String)}.
         * <p>
         * When the argument is a {@code File.separator}-style expression, the original invocation is returned unchanged.
         *
         * @param methodCursor                       cursor at the {@code startsWith} invocation
         * @param argumentCursor                     cursor at the unwrapped {@code startsWith} argument
         * @param getCanonicalPathSubjectReplacement already-built {@code Path} expression to use as the new receiver
         */
        private J.MethodInvocation replaceWithPathStartsWithMethodInvocation(Cursor methodCursor, Cursor argumentCursor, J.MethodInvocation getCanonicalPathSubjectReplacement) {
            J.MethodInvocation method = methodCursor.getValue();
            Expression argument = argumentCursor.getValue();
            if (getCanonicalPathMatcher.matches(argument)) {
                // CASE: ...startsWith(...getCanonicalPath())
                J.MethodInvocation getCanonicalPathArgumentReplacement = replaceGetCanonicalPath(argumentCursor);
                return pathStartsWithPathTemplate.apply(methodCursor, method.getCoordinates().replace(), getCanonicalPathSubjectReplacement, getCanonicalPathArgumentReplacement);
            } else {
                // CASE: ...startsWith(...)
                if (isFileSeparatorExpression(argument)) {
                    // CASE: ...startsWith(File.separator)
                    return method;
                }
                return pathStartsWithStringTemplate.apply(methodCursor, method.getCoordinates().replace(), getCanonicalPathSubjectReplacement, argument);
            }
        }

        /**
         * Rewrites {@code file.getCanonicalPath()} into {@code file.getCanonicalFile().toPath()}.
         */
        private J.MethodInvocation replaceGetCanonicalPath(Cursor getCanonicalPathCursor) {
            J.MethodInvocation getCanonicalPath = getCanonicalPathCursor.getValue();
            return toPathGetCanonicalFileTemplate.apply(getCanonicalPathCursor, getCanonicalPath.getCoordinates().replace(), getCanonicalPath.getSelect());
        }
    }

    /**
     * Returns {@code true} when {@code expression} is a string concatenation ({@code a + b}) whose right-hand
     * operand is a file-separator expression, i.e. a prefix that already ends in a separator.
     */
    static boolean isSafePartialPathExpression(@Nullable Expression expression) {
        if (expression instanceof J.Binary) {
            J.Binary concatArgument = (J.Binary) expression;
            if (J.Binary.Type.Addition.equals(concatArgument.getOperator())) {
                // CASE: ...startsWith(... + ...);
                return isFileSeparatorExpression(concatArgument.getRight());
            }
        }
        return false;
    }
}
