/*
 * Decompiled with CFR 0.152.
 */
package org.gorpipe.spark;

import breeze.linalg.DenseMatrix;
import com.google.common.collect.Iterators;
import gorsat.process.GenericSessionFactory;
import gorsat.process.PipeInstance;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Spliterators;
import java.util.function.DoubleFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.api.java.function.MapPartitionsFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.feature.PCA;
import org.apache.spark.mllib.feature.PCAModel;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.BlockMatrix;
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix;
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.gorpipe.gor.session.GorSession;
import org.gorpipe.spark.GorSparkSession;
import org.gorpipe.spark.SparkGOR;
import scala.Tuple2;

/**
 * Spark driver that runs a principal component analysis (PCA) over genotype strings produced by GOR
 * queries and writes one tab-separated line per sample ({@code pn<TAB>pc1<TAB>pc2...}) to an
 * output file.
 *
 * <p>Required options: {@code --projectroot}, {@code --freeze}, {@code --variants},
 * {@code --pnlist}, {@code --outfile}. Optional: {@code --appname} (default {@code pca}),
 * {@code --partsize} (default 10), {@code --pcacomponents} (default 3) and the flag
 * {@code --sparse}.
 */
public class SparkPCA {
    /** Example argument vector for manual runs; not referenced elsewhere in this class. */
    static String[] testargs = new String[]{"--projectroot", "/gorproject", "--freeze", "plink_wes", "--variants", "testvars2.gorz", "--pnlist", "testpns.txt", "--partsize", "10", "--pcacomponents", "3", "--outfile", "out.txt"};

    /**
     * Parses the command line, counts the variants in {@code --variants} and runs the PCA.
     *
     * <p>Relative {@code --outfile}, {@code --freeze}, {@code --pnlist} and {@code --variants}
     * paths are resolved against {@code --projectroot}. A freeze value wrapped in single quotes has
     * the quotes removed.
     *
     * @param args command-line arguments
     * @throws IOException kept for signature compatibility
     * @throws IllegalArgumentException if a required option is missing or an option has no value
     */
    public static void main(String[] args) throws IOException {
        List<String> argList = Arrays.asList(args);
        String appName = optionValue(argList, "--appname", "pca");
        String freeze = requiredOption(argList, "--freeze");
        // Only strip when the value is actually wrapped in a matching pair of quotes.
        if (freeze.length() >= 2 && freeze.startsWith("'") && freeze.endsWith("'")) {
            freeze = freeze.substring(1, freeze.length() - 1);
        }
        String projectRoot = requiredOption(argList, "--projectroot");
        String variants = requiredOption(argList, "--variants");
        String pnlist = requiredOption(argList, "--pnlist");
        int partsize = Integer.parseInt(optionValue(argList, "--partsize", "10"));
        int pcacomponents = Integer.parseInt(optionValue(argList, "--pcacomponents", "3"));
        String outfile = requiredOption(argList, "--outfile");
        boolean sparse = argList.contains("--sparse");

        Path root = Paths.get(projectRoot);
        Path outpath = resolveAgainst(root, outfile);
        Path freezepath = resolveAgainst(root, freeze);
        Path pnpath = resolveAgainst(root, pnlist);
        Path varpath = resolveAgainst(root, variants);

        try (SparkSession spark = SparkSession.builder().appName(appName).getOrCreate()) {
            long varcount = countVariants(varpath);
            System.err.println("starting pca " + varcount);
            pca(spark, projectRoot, freeze, pnlist, variants, partsize, pcacomponents, pnpath, varpath, freezepath, Math.toIntExact(varcount), outpath, sparse);
            System.err.println("pca done");
        } catch (Exception e) {
            // Top-level boundary of the driver: report the failure rather than propagate it.
            e.printStackTrace();
        }
    }

    /**
     * Returns the argument that follows {@code name}, or {@code defaultValue} when {@code name} is
     * absent.
     *
     * @throws IllegalArgumentException if {@code name} is the last argument (no value follows it)
     */
    private static String optionValue(List<String> argList, String name, String defaultValue) {
        int i = argList.indexOf(name);
        if (i == -1) {
            return defaultValue;
        }
        if (i + 1 >= argList.size()) {
            throw new IllegalArgumentException("Missing value for option " + name);
        }
        return argList.get(i + 1);
    }

    /**
     * Returns the argument that follows {@code name}.
     *
     * @throws IllegalArgumentException if {@code name} is absent or has no value
     */
    private static String requiredOption(List<String> argList, String name) {
        String value = optionValue(argList, name, null);
        if (value == null) {
            throw new IllegalArgumentException("Missing required option " + name);
        }
        return value;
    }

    /** Returns {@code path} as-is when absolute, otherwise resolved against {@code root}. */
    private static Path resolveAgainst(Path root, String path) {
        Path p = Paths.get(path);
        return p.isAbsolute() ? p : root.resolve(p);
    }

    /**
     * Counts the variants in {@code varpath}.
     *
     * <p>A {@code .gorz} file is read through a GOR pipe and every row is counted. Any other file is
     * read as text and all lines except the first (header) line are counted.
     */
    private static long countVariants(Path varpath) throws Exception {
        if (varpath.getFileName().toString().endsWith(".gorz")) {
            GenericSessionFactory gsf = new GenericSessionFactory(".", "result_cache");
            try (GorSession gs = gsf.create();
                 PipeInstance pi = new PipeInstance(gs.getGorContext())) {
                pi.init("gor " + varpath, false, "");
                return StreamSupport.stream(Spliterators.spliteratorUnknownSize(pi.theInputSource(), 0), false).count();
            }
        }
        // Files.lines holds an open file handle until the stream is closed.
        try (Stream<String> lines = Files.lines(varpath)) {
            return lines.skip(1L).count();
        }
    }

    /**
     * Builds a samples x variants {@link RowMatrix} from a block matrix.
     *
     * <p>Each Spark partition becomes one {@code (partitionIndex, 0)} block. The block's rows are the
     * samples (characters of the {@code values} string after its first character), and its columns
     * are the partition's rows. NOTE(review): this assumes every partition holds at most
     * {@code varcount} rows and that all rows in a partition have equal-length value strings. It
     * prints the collected matrix to stderr.
     */
    private static RowMatrix blockMatrixToRowMatrix(Dataset<? extends Row> ds, int varcount, int partsize) {
        JavaRDD<Tuple2<Tuple2<Object, Object>, Matrix>> blocks = ds.select("chrom", "pos", "values").javaRDD()
                .mapPartitionsWithIndex((Function2<Integer, Iterator<Row>, Iterator<Tuple2<Tuple2<Object, Object>, Matrix>>>) (partIdx, rows) -> {
                    double[] values = null;
                    int sampleCount = 0;
                    int variantIdx = 0;
                    while (rows.hasNext()) {
                        String genotypes = rows.next().getString(2).substring(1);
                        if (values == null) {
                            sampleCount = genotypes.length();
                            values = new double[varcount * sampleCount];
                        }
                        if (variantIdx >= varcount) {
                            throw new IllegalStateException("partition " + partIdx + " has more than " + varcount + " variant rows");
                        }
                        if (genotypes.length() != sampleCount) {
                            throw new IllegalStateException("partition " + partIdx + ": expected " + sampleCount + " samples but row has " + genotypes.length());
                        }
                        // Matrices.dense is column-major: element (sample s, variant v) lives at s + rows * v.
                        for (int s = 0; s < sampleCount; ++s) {
                            values[s + sampleCount * variantIdx] = genotypes.charAt(s) - '0';
                        }
                        ++variantIdx;
                    }
                    if (values == null) {
                        return Collections.emptyIterator();
                    }
                    Matrix block = Matrices.dense(sampleCount, varcount, values);
                    Tuple2<Object, Object> blockIndex = new Tuple2<>(partIdx, 0);
                    Tuple2<Tuple2<Object, Object>, Matrix> indexedBlock = new Tuple2<>(blockIndex, block);
                    return Iterators.singletonIterator(indexedBlock);
                }, true);
        BlockMatrix mat = new BlockMatrix(blocks.rdd(), partsize, varcount);
        IndexedRowMatrix irm = mat.toIndexedRowMatrix();
        // Debug output: collects the whole matrix to the driver.
        System.err.println(irm.toBreeze());
        return irm.toRowMatrix();
    }

    /**
     * Builds a samples x variants {@link RowMatrix} from sparse coordinate entries.
     *
     * <p>Non-zero genotype characters become {@link MatrixEntry}s. The global row index of an input
     * row is split into a partition number ({@code index / varcount}) and a variant column
     * ({@code index % varcount}). NOTE(review): this assumes every partition contributes exactly
     * {@code varcount} rows, in order. It prints the collected matrix to stderr.
     */
    private static RowMatrix coordMatrixToRowMatrix(Dataset<? extends Row> ds, int varcount, int samplecount, int partsize) {
        JavaRDD<MatrixEntry> entries = ds.select("chrom", "pos", "values").javaRDD().zipWithIndex()
                .flatMap((FlatMapFunction<Tuple2<Row, Long>, MatrixEntry>) tup -> {
                    long idx = tup._2();
                    long partIdx = idx / (long) varcount;
                    long variantIdx = idx % (long) varcount;
                    String genotypes = tup._1().getString(2).substring(1);
                    return IntStream.range(0, genotypes.length())
                            .filter(s -> genotypes.charAt(s) != '0')
                            .mapToObj(s -> new MatrixEntry(partIdx * (long) partsize + (long) s, variantIdx, (double) (genotypes.charAt(s) - '0')))
                            .iterator();
                });
        CoordinateMatrix mat = new CoordinateMatrix(entries.rdd(), (long) samplecount, (long) varcount);
        // Debug output: collects the whole matrix to the driver (printed once, not twice).
        System.err.println(mat.toBreeze());
        return mat.toIndexedRowMatrix().toRowMatrix();
    }

    /**
     * Runs the two GOR/Spark queries (sample names per partition, genotype strings per variant) and
     * hands them to {@link #labelPoint}.
     *
     * <p>{@code freeze}, {@code pnlist}, {@code variants} and {@code sparse} are not read here. The
     * resolved paths are used instead, and {@code sparse} is currently ignored.
     */
    private static void pca(SparkSession spark, String projectRoot, String freeze, String pnlist, String variants, int partsize, int pcacomponents, Path pnpath, Path varpath, Path freezepath, int varcount, Path outpath, boolean sparse) throws IOException {
        GorSparkSession gorSparkSession = SparkGOR.createSession(spark, projectRoot, "result_cache", null, null);
        String freezevariants = freezepath.resolve("variants.gord").toString();
        Dataset<? extends Row> pnidx = gorSparkSession.spark("spark <(partgor -ff " + pnpath.toString() + " -partsize " + partsize + " -dict " + freezevariants + " <(gorrow 1,1 | calc pn '#{tags}' | split pn))", null);
        // The trailing "replace values 'u'+values" prefixes each genotype string with 'u';
        // consumers strip it with substring(1).
        Dataset<? extends Row> ds = gorSparkSession.spark("spark -tag <(partgor -ff " + pnpath.toString() + " -partsize " + partsize + " -dict " + freezevariants + " <(gor " + varpath.toString() + "| varjoin -r -l -e '?' <(gor " + freezevariants + " -nf -f #{tags})| rename Chrom CHROM | rename ref REF | rename alt ALT | calc ID chrom+'_'+pos+'_'+ref+'_'+alt | csvsel " + freezepath.resolve("buckets.tsv").toString() + " <(nor <(gorrow 1,1 | calc pn '#{tags}' | split pn) | select pn) -u 3 -gc id,ref,alt -vs 1 | replace values 'u'+values))", null);
        labelPoint(spark, ds, pnidx, varcount, pcacomponents, outpath);
    }

    /**
     * Transposes each partition's genotype rows into one dense vector per sample, fits a PCA with
     * {@code pcacomponents} components, projects every sample and writes
     * {@code pn<TAB>pc1<TAB>...<NEWLINE>} lines to {@code outpath}.
     *
     * <p>Samples are matched to vectors by their global position (via {@code zipWithIndex}).
     * NOTE(review): this assumes {@code pnidx} lists the sample names in the same partition and
     * column order as the value strings in {@code ds}. Confirm against the queries in
     * {@link #pca}.
     */
    private static void labelPoint(SparkSession spark, Dataset<? extends Row> ds, Dataset<? extends Row> pnidx, int varcount, int pcacomponents, Path outpath) throws IOException {
        Dataset<Vector> sampleVectors = ds.select("values").mapPartitions((MapPartitionsFunction<Row, Vector>) rows -> {
            double[][] perSample = null;
            int variantIdx = 0;
            while (rows.hasNext()) {
                // Drop the one-character prefix in front of the genotype characters.
                String genotypes = rows.next().getString(0).substring(1);
                int sampleCount = genotypes.length();
                if (perSample == null) {
                    perSample = new double[sampleCount][varcount];
                }
                if (variantIdx >= varcount) {
                    throw new IllegalStateException("partition has more than " + varcount + " variant rows");
                }
                for (int s = 0; s < sampleCount; ++s) {
                    perSample[s][variantIdx] = genotypes.charAt(s) - '0';
                }
                ++variantIdx;
            }
            if (perSample == null) {
                return Collections.emptyIterator();
            }
            List<Vector> vectors = new ArrayList<>(perSample.length);
            for (double[] sample : perSample) {
                vectors.add(Vectors.dense(sample));
            }
            return vectors.iterator();
        }, Encoders.kryo(Vector.class));

        JavaPairRDD<Long, String> namesByIndex = pnidx.select("pn")
                .map((MapFunction<Row, String>) r -> r.get(0).toString(), Encoders.STRING())
                .javaRDD().zipWithIndex().mapToPair(Tuple2::swap);
        JavaPairRDD<Long, Vector> vectorsByIndex = sampleVectors.javaRDD().zipWithIndex().mapToPair(Tuple2::swap);

        PCAModel pcamodel = new PCA(pcacomponents).fit(vectorsByIndex.values());
        JavaPairRDD<Long, Vector> projectedByIndex = vectorsByIndex.mapToPair(f -> new Tuple2<>(f._1(), pcamodel.transform(f._2())));
        JavaPairRDD<String, Vector> projected = namesByIndex.join(projectedByIndex).mapToPair(f -> f._2());
        Map<String, Vector> result = projected.collectAsMap();

        try (BufferedWriter bw = Files.newBufferedWriter(outpath)) {
            for (Map.Entry<String, Vector> entry : result.entrySet()) {
                bw.write(entry.getKey());
                Vector pcacomp = entry.getValue();
                for (int c = 0; c < pcacomp.size(); ++c) {
                    bw.write('\t');
                    bw.write(Double.toString(pcacomp.apply(c)));
                }
                bw.write('\n');
            }
        }
    }

    /**
     * Experimental PCA over a hard-coded BRCA variant selection using a coordinate matrix.
     * Prints the first 3 principal-component projections to stderr.
     *
     * <p>Each tag's rows get consecutive row indices, starting at the running total of the counts of
     * the tags collected before it.
     */
    private static void coordpca(String[] args, SparkSession spark) {
        GorSparkSession gorSparkSession = SparkGOR.createSession(spark);
        JavaSparkContext javaSparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
        Dataset<? extends Row> dsmap = gorSparkSession.spark("spark -tag <(pgor -split <(gor /gorproject/brca.gor) /gorproject/plink_wes/metadata/AF.gorz| varjoin -r -l -e '?' /gorproject/plink_wes/vep_single.gorz| where max_consequence in ('frameshift_variant','splice_acceptor_variant','splice_donor_variant','start_lost','stop_gained','stop_lost','incomplete_terminal_codon_variant','inframe_deletion','inframe_insertion','missense_variant','protein_altering_variant','splice_region_variant')| group chrom -count)", null);
        // Starting row offset per tag: running sum of the preceding tags' counts (sequential stream).
        int[] runningTotal = {0};
        Map<String, Integer> rangeSum = dsmap.collectAsList().stream().collect(Collectors.toMap(r -> r.getString(4), r -> {
            int start = runningTotal[0];
            runningTotal[0] += r.getInt(3);
            return start;
        }));
        Broadcast<Map<String, Integer>> bcsum = javaSparkContext.broadcast(rangeSum);
        Dataset<? extends Row> ds = gorSparkSession.spark("spark -tag <(partgor -ff <(nor -h /gorproject/plink_wes/buckets.tsv | select 1 | top 50) -partsize 10 -dict /gorproject/plink_wes/variants.gord <(pgor -split <(gor /gorproject/brca.gor) /gorproject/plink_wes/variants.gord -nf -f #{tags} | varjoin -r -l -e '?' /gorproject/plink_wes/vep_single.gorz| where max_consequence in ('frameshift_variant','splice_acceptor_variant','splice_donor_variant','start_lost','stop_gained','stop_lost','incomplete_terminal_codon_variant','inframe_deletion','inframe_insertion','missense_variant','protein_altering_variant','splice_region_variant')| rename Chrom CHROM | rename ref REF | rename alt ALT | calc ID chrom+'_'+pos+'_'+ref+'_'+alt | csvsel /gorproject/plink_wes/buckets.tsv <(nor -h /gorproject/plink_wes/buckets.tsv | select 1 | top 50) -u 3 -gc id,ref,alt -vs 1))", null);
        // MatrixEntry is a Scala case class, not a Java bean, so a bean encoder cannot describe it.
        Encoder<MatrixEntry> menc = Encoders.kryo(MatrixEntry.class);
        Dataset<MatrixEntry> dsm = ds.select("values", "tag").mapPartitions((MapPartitionsFunction<Row, MatrixEntry>) input -> {
            if (!input.hasNext()) {
                return Collections.emptyIterator();
            }
            Row first = input.next();
            String tag = first.getString(1);
            int sum = bcsum.getValue().get(tag);
            long[] nextRow = {sum};
            Stream<Row> rows = Stream.concat(Stream.of(first), StreamSupport.stream(Spliterators.spliteratorUnknownSize(input, 0), false));
            return rows.flatMap(row -> {
                assert (row.getString(1).equals(tag));
                // Claim this row's index now; the inner stream is consumed lazily after we return.
                long rowIdx = nextRow[0]++;
                String values = row.getString(0);
                return IntStream.range(0, values.length())
                        .mapToObj(col -> new MatrixEntry(rowIdx, (long) col, (double) (values.charAt(col) - '0')));
            }).iterator();
        }, menc);
        CoordinateMatrix cm = new CoordinateMatrix(dsm.rdd());
        RowMatrix rowMatrix = cm.transpose().toRowMatrix();
        Matrix pc = rowMatrix.computePrincipalComponents(3);
        RowMatrix projected = rowMatrix.multiply(pc);
        DenseMatrix<?> dm = projected.toBreeze();
        System.err.println(dm.toString(20, 20));
    }

    /**
     * Experimental: loads a dbSNP TSV into a temp view, selects {@code rs22} and saves the result.
     * Closes {@code spark} when done.
     */
    private static void test1(String[] args, SparkSession spark) {
        Dataset<Row> ds = spark.read().format("csv").option("header", "true").option("delimiter", "\t").option("inferSchema", "true").load("/gorproject/ref/dbsnp/dbsnp.gor");
        ds.createOrReplaceTempView("dbsnp");
        Dataset<Row> sqlds = spark.sql("select * from dbsnp where rsids = 'rs22'");
        sqlds.write().save("/gorproject/mu.parquet");
        spark.close();
    }
}

