001package org.nasdanika.ai;
002
003import java.util.ArrayList;
004import java.util.Arrays;
005import java.util.Collection;
006import java.util.List;
007import java.util.function.BinaryOperator;
008import java.util.function.Function;
009
010import org.nasdanika.ai.FittedPredictor.ErrorComputer;
011import org.nasdanika.ai.FittedPredictor.Fitter;
012
013import reactor.core.publisher.Flux;
014import reactor.core.publisher.Mono;
015
016/**
017 * Collects all samples to double[][] features and double[][] labels, 
018 * then calls fit(double[][] features, double[][] labels) to fit {@link Function}<double[][],double[][]> 
019 * which makes predictions for multiple features at once. 
020 * This allows to use a single base class for multiple types of predictors
021 */
022public abstract class AbstractDoubleFitter implements FittedPredictor.Fitter<double[], double[], Double> {
023
024        private record FeaturesMapped<S>(double[] features, S sample) {}        
025        private record Data(double[] features, double[] labels) {}
026
027        @Override
028        public <S> Mono<FittedPredictor<double[], double[], Double>> fitAsync(
029                        Flux<S> samples,
030                        Function<S, Mono<double[]>> featureMapper, 
031                        Function<S, Mono<double[]>> labelMapper) {
032                
033                return samples
034                        .flatMap(s -> featureMapper.apply(s).map(f -> new FeaturesMapped<>(f,s)))
035                        .flatMap(fm -> featureMapper.apply(fm.sample()).map(l -> new Data(fm.features(),l)))
036                        .collectList()
037                        .map(dataList -> {
038                                double[][] features = new double[dataList.size()][];
039                                double[][] labels = new double[dataList.size()][];
040                                int idx = 0;
041                                for (Data data: dataList) {
042                                        features[idx] = data.features();
043                                        labels[idx++] = data.labels();
044                                }
045                                
046                                Function<double[][],double[][]> predictor = fit(features, labels);                              
047                                return createPredictor(predictor, features, labels);                            
048                        });
049        }
050        
051        @Override
052        public <S> FittedPredictor<double[], double[], Double> fit(
053                        Collection<S> samples, 
054                        Function<S, double[]> featureMapper,
055                        Function<S, double[]> labelMapper) {
056                
057                double[][] features = new double[samples.size()][];
058                double[][] labels = new double[samples.size()][];
059                int idx = 0;
060                for (S sample: samples) {
061                        features[idx] = featureMapper.apply(sample);
062                        labels[idx++] = labelMapper.apply(sample);
063                }
064                
065                Function<double[][],double[][]> predictor = fit(features, labels);
066                return createPredictor(predictor, features, labels); 
067        }
068        
069        protected abstract Function<double[][],double[][]> fit(double[][] features, double[][] labels);
070        
071        protected FittedPredictor<double[], double[], Double> createPredictor(
072                        Function<double[][],double[][]> predictor, 
073                        double[][] features,
074                        double[][] labels) {
075                
076                Double error = computeError(predictor, features, labels);
077                return new FittedPredictor<double[], double[], Double>() {
078                                                
079                        @Override
080                        public double[] predict(double[] feature) {
081                                return predictor.apply(new double[][] { feature })[0];
082                        }
083                        
084                        @Override
085                        public List<Sample<double[], double[]>> predict(Collection<double[]> input) {
086                                double[][] inputArray = input.toArray(double[][]::new);
087                                double[][] results = predictor.apply(inputArray);
088                                List<Sample<double[], double[]>> ret = new ArrayList<>();
089                                for (int i = 0; i < inputArray.length; ++i) {
090                                        ret.add(new Sample<>(inputArray[i], results[i]));
091                                }
092                                return ret;
093                        }
094
095                        @Override
096                        public Mono<double[]> predictAsync(double[] input) {
097                                return Mono.fromSupplier(() -> predict(input));
098                        }
099
100                        @Override
101                        public Double getError() {
102                                return error; 
103                        }
104                        
105                };
106                
107        }
108        
109        protected static Double computeError(
110                        Function<double[][],double[][]> predictor, 
111                        double[][] features,
112                        double[][] labels) {
113                
114                double[][] prediction = predictor.apply(features);
115                if (prediction == null) {
116                        return null;
117                }
118                double total = 0.0;
119                int count = 0;
120                for (int i = 0; i < labels.length; ++i) {
121                        double[] pe = prediction[i];
122                        if (pe != null) {
123                                for (int j = 0; j < labels[j].length; ++j) {
124                                        double delta = pe[j] - labels[i][j];
125                                        total += delta * delta;
126                                        ++count;
127                                }
128                        }
129                }
130                
131                if (count == 0) {
132                        return null;
133                }
134                
135                return total / count;
136        }       
137        
138        public static Function<double[][], double[]> wrap(Function<double[], Double> predictor) {
139                return input -> {
140                        double[] output = new double[input.length];
141                        for (int i = 0; i < input.length; ++i) {
142                                output[i] = predictor.apply(input[i]);
143                        }
144                        return output;
145                };              
146        }
147        
148        public Fitter<double[], double[], Double> compose(Fitter<double[], double[], Double> other) {
149                BinaryOperator<double[]> add = (a,b) -> {
150                        double[] aCopy = Arrays.copyOf(a, a.length);
151                        for (int i = 0; i < a.length; ++i) {
152                                aCopy[i] += b[i];
153                        }
154                        return aCopy;
155                };
156                
157                BinaryOperator<double[]> subtract = (a,b) -> {
158                        double[] aCopy = Arrays.copyOf(a, a.length);
159                        for (int i = 0; i < a.length; ++i) {
160                                aCopy[i] -= b[i];
161                        }
162                        return aCopy;                   
163                };
164                
165                ErrorComputer<double[], double[], Double> errorComputer = new ErrorComputer<double[], double[], Double>() {
166
167                        @Override
168                        public <S> Double computeError(
169                                        Predictor<double[], double[]> predictor, 
170                                        Collection<S> samples,
171                                        Function<S, double[]> featureMapper, 
172                                        Function<S, double[]> labelMapper) {
173                                                                
174                                double[][] features = new double[samples.size()][];                             
175                                double[][] labels = new double[samples.size()][];
176                                
177                                int idx = 0;
178                                for (S sample: samples) {
179                                        features[idx] = featureMapper.apply(sample);
180                                        labels[idx++] = labelMapper.apply(sample);
181                                }                               
182                                
183                                Function<double[][], double[][]> errorPredictor = input -> {
184                                        double[][] output = new double[input.length][];
185                                        for (int i = 0; i < input.length; ++i) {
186                                                output[i] = predictor.predict(input[i]);
187                                        }
188                                        return output;
189                                };
190                                return AbstractDoubleFitter.computeError(
191                                                errorPredictor, 
192                                                features, 
193                                                labels);
194                        }
195                        
196                };
197                
198                return compose(other, add, subtract, errorComputer);
199        }
200        
201                
202        // TODO - stacking/composition binary operators to add/subtract labels. 
203        // Fit this one, fit the next one on label and prediction difference. 
204        // Predict by adding this prediciton to the next prediction.
205        // Adapters to a single double result with support of stacking/composition too?
206        
207}