001package org.nasdanika.rag.core;
002
003import java.text.BreakIterator;
004import java.util.ArrayList;
005import java.util.Arrays;
006import java.util.Collections;
007import java.util.HashSet;
008import java.util.LinkedList;
009import java.util.List;
010import java.util.Locale;
011import java.util.Objects;
012import java.util.Set;
013import java.util.function.Function;
014
015import org.eclipse.emf.ecore.EObject;
016import org.nasdanika.models.pdf.Document;
017import org.nasdanika.models.pdf.Paragraph;
018
019/**
020 * Extracts text from PDF and splits into chunks.
021 * This class tries to keep paragraphs together and split them into sentences if keeping together is not possible.
022 */
023public class PdfTextSplitter {
024        
025        public interface Chunk {
026                
027                String getText();
028                
029                List<EObject> getSources();
030                
031                int size();
032                
033                int overlap();
034                
035        }
036        
037        private int size;
038        private int overlap;
039        private int tolerance;
040        private Function<String, List<String>> tokenizer;
041
042        /**
043         * 
044         * @param size Chunk size in tokens
045         * @param overlap Chunk overlap in tokens.
046         * @param tolerance Size tolerance to allow keep paragraphs and sentences together if possible
047         * @param tokenCounter
048         */
049        public PdfTextSplitter(
050                        int size, 
051                        int overlap,
052                        int tolerance,
053                        Function<String, List<String>> tokenizer) {
054                this.size = size;
055                this.overlap = overlap;
056                this.tolerance = tolerance;
057                this.tokenizer = tokenizer;
058        }
059        
060        protected List<String> splitIntoSentences(String text) {
061                BreakIterator iterator = BreakIterator.getSentenceInstance(Locale.US);
062                iterator.setText(text);
063                int start = iterator.first();
064                List<String> ret = new ArrayList<>();
065                for (int end = iterator.next(); end != BreakIterator.DONE; start = end, end = iterator.next()) {
066                        ret.add(text.substring(start,end));
067                }                               
068                return ret;
069        }
070        
071        protected List<String> splitIntoWords(String text) {
072                String[] words = text.split("\\s+");
073                return Arrays.asList(words);
074        }
075        
076        protected String getWordSeparator() {
077                return " ";     
078        }
079        
080        protected String getLineSeparator() {
081                return System.lineSeparator();
082        }
083        
084        protected String getParagraphSeparator() {
085                return getLineSeparator() + getLineSeparator();
086        }
087        
088        private static record WordRecord(int id, String text, List<String> tokens, Paragraph paragraph) {} 
089        private static record SentenceRecord(int id, String text, int size, Paragraph paragraph, List<WordRecord> words) {}
090        private static record ParagraphRecord(int id, String text, int size, Paragraph paragraph, List<SentenceRecord> sentences) {}    
091        
092        private class ChunkImpl implements Chunk {
093                
094                /**
095                 * Creates a new chunk and adds overlap
096                 * @param paragraphs A list of paragraphs
097                 * @param paragraph current paragraph index
098                 * @param sentence current sentence index
099                 */
100                ChunkImpl(
101                                List<ParagraphRecord> paragraphs, 
102                                int paragraph, 
103                                int sentence, 
104                                int word,
105                                int token) {
106                        
107                        int remaining = overlap;
108                        List<ChunkImpl> chunks = new ArrayList<>();
109                        P: for (; paragraph >= 0; --paragraph) {
110                                ParagraphRecord p = paragraphs.get(paragraph);
111                                if (sentence == -1) {
112                                        // Entire paragraph was added
113                                        if (p.size() < remaining) {
114                                                ChunkImpl pChunk = new ChunkImpl(null, -1, -1, -1, -1);
115                                                pChunk.add(p);
116                                                pChunk.add(getParagraphSeparator(), null);
117                                                chunks.add(pChunk);
118                                                remaining -= pChunk.size();
119                                                if (remaining <= tolerance) {
120                                                        break;
121                                                }
122                                                continue;
123                                        }
124                                        
125                                        // Doesn't fit, setting the sentence index to the last sentence
126                                        sentence = p.sentences().size() - 1;                                    
127                                }
128                                
129                                for (; sentence >= 0; --sentence) {
130                                        SentenceRecord s = p.sentences().get(sentence);
131                                        if (word == -1) {
132                                                // Entire sentence was added
133                                                if (s.size() < remaining) {
134                                                        ChunkImpl sChunk = new ChunkImpl(null, -1, -1, -1, -1);
135                                                        sChunk.add(s);
136                                                        chunks.add(sChunk);
137                                                        remaining -= sChunk.size();
138                                                        if (remaining <= tolerance) {
139                                                                break P;
140                                                        }
141                                                        continue;
142                                                }
143                                                
144                                                // Doesn't fit, setting the word index to the last word
145                                                word = s.words().size() - 1;                                    
146                                        }
147                                        
148                                        for (; word >= 0; --word) {
149                                                WordRecord w = s.words().get(word);
150                                                if (token == -1) {
151                                                        // Entire word was added
152                                                        if (w.tokens().size() < remaining) {
153                                                                ChunkImpl wChunk = new ChunkImpl(null, -1, -1, -1, -1);
154                                                                wChunk.add(w);
155                                                                wChunk.add(getWordSeparator(), w.paragraph());
156                                                                chunks.add(wChunk);
157                                                                remaining -= wChunk.size();
158                                                                if (remaining <= tolerance) {
159                                                                        break P;
160                                                                }
161                                                                continue;
162                                                        }
163                                                        
164                                                        // Doesn't fit, setting the token index to the last token
165                                                        token = w.tokens().size() - 1;                                  
166                                                }
167                                                for (; token >= 0; --token) {
168                                                        ChunkImpl tChunk = new ChunkImpl(null, -1, -1, -1, -1);
169                                                        tChunk.add(w.tokens().get(token), w.paragraph());
170                                                        chunks.add(tChunk);
171                                                        remaining -= tChunk.size();
172                                                        if (remaining <= tolerance) {
173                                                                break P;
174                                                        }
175                                                }
176                                        }
177                                }
178                                
179                        }
180                        
181                        Collections.reverse(chunks);
182                        for (ChunkImpl oc: chunks) {
183                                add(oc);
184                        }
185                        chunkOverlap = size;
186                }
187                
188                private StringBuilder textBuilder = new StringBuilder();
189                private int size;
190                
191                public int size() {
192                        return size;
193                }
194                
195                public String getText() {
196                        return textBuilder.toString();
197                }
198                
199                void add(String text, int size, EObject source) {
200                        textBuilder.append(text);
201                        this.size += size;
202                        if (this.size > PdfTextSplitter.this.size) {
203                                throw new IllegalStateException("Chunk size exceeded: " + this.size);
204                        }
205                        sources.add(source);
206                }
207                
208                void add(String text, EObject source) {
209                        add(text, tokenizer.apply(text).size(), source);
210                }
211                
212                void add(ParagraphRecord paragraph) {
213                        if (!sourceRecords.add(paragraph.id())) {
214                                throw new IllegalStateException("Duplicate source paragraph: " + paragraph);
215                        }
216                        if (size > 0) {
217                                add(getParagraphSeparator(), paragraph.paragraph());
218                        }
219                        add(paragraph.text(), paragraph.size(), paragraph.paragraph());
220                        add(getParagraphSeparator(), tokenizer.apply(getParagraphSeparator()).size(), paragraph.paragraph());
221                }
222                
223                void add(SentenceRecord sentence) {
224                        if (!sourceRecords.add(sentence.id())) {
225                                throw new IllegalStateException("Duplicate source sentence: " + sentence);
226                        }
227                        add(sentence.text(), sentence.size(), sentence.paragraph());
228                }
229                
230                void add(WordRecord word) {
231                        if (!sourceRecords.add(word.id())) {
232                                throw new IllegalStateException("Duplicate source word: " + word);
233                        }
234                        if (size > 0) {
235                                add(getWordSeparator(), null);
236                        }
237                        add(word.text(), word.tokens().size(), word.paragraph());
238                }
239                
240                boolean isFull() {
241                        return size > PdfTextSplitter.this.size - tolerance;
242                }
243                
244                void add(ChunkImpl chunk) {
245                        add(chunk.getText(), chunk.size(), null);
246                        sources.addAll(chunk.getSources());
247                        for (Integer id: chunk.sourceRecords) {
248                                if (!sourceRecords.add(id)) {
249                                        throw new IllegalStateException("Duplicate source record in chunk: " + chunk);
250                                }                               
251                        }
252                }
253                
254                private List<EObject> sources = new ArrayList<>();
255                private Set<Integer> sourceRecords = new HashSet<>();
256
257                @Override
258                public List<EObject> getSources() {
259                        return sources.stream().filter(Objects::nonNull).distinct().toList();
260                }
261                
262                private int chunkOverlap;
263
264                @Override
265                public int overlap() {
266                        return chunkOverlap;
267                }
268                
269        }       
270        
271        /**
272         * Splits docuent into chunks. 
273         * @param document
274         * @return
275         */
276        public List<Chunk> split(Document document) {
277                int[] counter = { 0 };
278                List<ParagraphRecord> paragraphs = document
279                                .getPages()
280                                .stream()
281                                .flatMap(p -> p.getArticles().stream())
282                                .flatMap(a -> a.getParagraphs().stream())
283                                .map(p -> {
284                                        String text = p.getText(getLineSeparator(), getWordSeparator());
285                                        return new ParagraphRecord(
286                                                        counter[0]++,
287                                                        text, 
288                                                        tokenizer.apply(text).size(), 
289                                                        p,
290                                                        splitIntoSentences(text)
291                                                                .stream()
292                                                                .map(s -> new SentenceRecord(
293                                                                                counter[0]++,
294                                                                                s, 
295                                                                                tokenizer.apply(s).size(),
296                                                                                p,
297                                                                                splitIntoWords(s)
298                                                                                        .stream()
299                                                                                        .map(w -> new WordRecord(
300                                                                                                        counter[0]++, 
301                                                                                                        w, 
302                                                                                                        tokenizer.apply(w), 
303                                                                                                        p))
304                                                                                        .toList()))
305                                                                .toList());
306                                })
307                                .toList();              
308                
309                LinkedList<Chunk> chunks = new LinkedList<>();
310                chunks.add(new ChunkImpl(paragraphs, -1, -1, -1, -1));
311                for (int i = 0; i < paragraphs.size(); ++i) {
312                        ChunkImpl chunk = (ChunkImpl) chunks.getLast();
313                        ParagraphRecord paragraph = paragraphs.get(i);
314                        if (paragraph.size() + chunk.size() + tokenizer.apply(getParagraphSeparator()).size() < size) {
315                                // Paragraph fits into the chunk
316                                chunk.add(paragraph);
317                        } else {
318                                // The paragraph is too big to fit into the chunk.
319                                // Close the chunk if its size is within tolerance
320                                if (chunk.isFull()) {
321                                        chunk = new ChunkImpl(paragraphs, i - 1, -1, -1, -1);
322                                        chunks.add(chunk);
323                                        if (paragraph.size() + chunk.size() < size) {
324                                                // Paragraph fits into the chunk
325                                                chunk.add(paragraph);
326                                                continue;
327                                        }
328                                }
329                                
330                                // There is not enough space in the chunk for the entire paragraph - break down into sentences.
331                                for (int j = 0; j < paragraph.sentences().size(); ++j) {
332                                        SentenceRecord sentence = paragraph.sentences().get(j);
333                                        if (sentence.size() + chunk.size() < size) {
334                                                chunk.add(sentence);
335                                        } else {
336                                                if (chunk.isFull()) {
337                                                        chunk = new ChunkImpl(paragraphs, i, j - 1, -1, -1);
338                                                        chunks.add(chunk);
339                                                        if (sentence.size() + chunk.size() < size) {
340                                                                chunk.add(sentence);
341                                                                continue;
342                                                        }
343                                                }
344                                                
345                                                int wordSeparatorSize = tokenizer.apply(getWordSeparator()).size();
346                                                for (int k = 0; k < sentence.words().size(); ++k) {
347                                                        WordRecord word = sentence.words().get(k);
348                                                        if (word.tokens().size() + chunk.size() + wordSeparatorSize < size) {
349                                                                chunk.add(word);
350                                                                chunk.add(getWordSeparator(), word.paragraph());
351                                                        } else {
352                                                                if (chunk.isFull()) {
353                                                                        chunk = new ChunkImpl(paragraphs, i, j, k - 1, -1);
354                                                                        chunks.add(chunk);
355                                                                        if (word.tokens().size() + chunk.size() + wordSeparatorSize < size) {
356                                                                                chunk.add(word);
357                                                                                chunk.add(getWordSeparator(), word.paragraph());
358                                                                                continue;
359                                                                        }
360                                                                }
361                                                                
362                                                                // Breaking the word into tokens and adding individual tokens
363                                                                int w = 0;
364                                                                for (String token: word.tokens()) {
365                                                                        chunk.add(token, 1, word.paragraph());                                                                  
366                                                                        if (chunk.isFull()) {
367                                                                                chunk = new ChunkImpl(paragraphs, i, j, k, w);
368                                                                                chunks.add(chunk);
369                                                                                continue;
370                                                                        }
371                                                                        ++w;
372                                                                }
373                                                        }
374                                                }
375                                                
376                                        }
377                                }
378                        }
379                }
380                                
381                return chunks;
382        }
383        
384}