001package org.nasdanika.ai;
002
003import java.awt.image.BufferedImage;
004import java.io.ByteArrayOutputStream;
005import java.io.File;
006import java.io.IOException;
007import java.io.InputStream;
008import java.net.URL;
009import java.util.ArrayList;
010import java.util.Arrays;
011import java.util.Collections;
012import java.util.List;
013
014import javax.imageio.ImageIO;
015
016import reactor.core.publisher.Mono;
017
018public interface Chat extends Model {
019        
020        /**
021         * Chat requirement.
022         * String attributes match any value if null.
023         */
024        record Requirement(
025                String provider,
026                String model,
027                String version) {}      
028        
029        interface Message {
030                
031                String getRole();
032                
033                String getContent();
034                
035                /**
036                 * Images encoded as base64 url
037                 * @return
038                 */
039                List<String> getImages();
040                
041                /**
042                 * Adds an image encoded as base64 data URL
043                 * @param dataUrl
044                 * @return this message
045                 */
046                Message addImage(String dataUrl);
047                
048                default Message addImage(File file) {
049                        try {
050                                return addImage(ImageIO.read(file));
051                        } catch (IOException e) {
052                                throw new IllegalArgumentException("Cannot read image from file '" + file.getAbsolutePath() + "': " + e, e);
053                        }
054                }
055
056                default Message addImage(InputStream inputStream) {
057                        try {
058                                return addImage(ImageIO.read(inputStream));
059                        } catch (IOException e) {
060                                throw new IllegalArgumentException("Cannot read image from input stream: " + e, e);
061                        }
062                }
063
064                default Message addImage(URL url) {
065                        try {
066                                return addImage(ImageIO.read(url));
067                        } catch (IOException e) {
068                                throw new IllegalArgumentException("Cannot read image from URL '" + url + "': " + e, e);
069                        }
070                }
071
072                default Message addImage(BufferedImage image) {
073                        ByteArrayOutputStream baos = new ByteArrayOutputStream();
074                        try {
075                                try (baos) {
076                                        ImageIO.write(image, "PNG", baos);                                                                                              
077                                }
078                        String base64Image = java.util.Base64.getEncoder().encodeToString(baos.toByteArray());
079                            return addImage("data:image/png;base64," + base64Image);
080                        } catch (Exception e) {
081                                throw new IllegalArgumentException("Cannot write image: " + e, e);
082                        }
083                        
084                }
085                
086                /**
087                 * Creates a message
088                 * @param role
089                 * @param content Message content. Can be null.
090                 * @return
091                 */
092                static Message create(String role, String content) {
093                                                                        
094                        return new Message() {
095                                
096                                @Override
097                                public String getRole() {
098                                        return role;
099                                }
100                                
101                                @Override
102                                public String getContent() {
103                                        return content;
104                                }
105                                
106                                private List<String> images = new ArrayList<>();
107
108                                @Override
109                                public List<String> getImages() {
110                                        return Collections.unmodifiableList(images);
111                                }
112
113                                @Override
114                                public Message addImage(String dataUrl) {
115                                        images.add(dataUrl);
116                                        return this;
117                                }
118                                
119                        };
120                        
121                }
122                
123        }
124        
125        interface ResponseMessage extends Message {
126                
127                String getRefusal();
128                
129                String getFinishReason();
130                
131                @Override
132                default Message addImage(String dataUrl) {
133                        throw new UnsupportedOperationException();                      
134                }
135                
136                @Override
137                default List<String> getImages() {
138                        return Collections.emptyList();
139                }
140                
141        }
142                
143        enum Role {
144                
145                system,
146                assistant,
147                user,
148                function,
149                tool,
150                developer;              
151                
152                public Message createMessage(String content) {
153                        return Message.create(name(), content);
154                }
155                
156        }
157        
158        Mono<List<ResponseMessage>> chatAsync(List<Message> messages);
159        
160        default Mono<List<ResponseMessage>> chatAsync(Message... messages) {
161                return chatAsync(Arrays.asList(messages));
162        }               
163        
164        default List<ResponseMessage> chat(List<Message> messages) {
165                return chatAsync(messages).block();
166        }
167                
168        default List<ResponseMessage> chat(Message... messages) {
169                return chat(Arrays.asList(messages));
170        }       
171        
172        int getMaxOutputTokens();
173
174}