001/* -*- mode: Java; c-basic-offset: 2; indent-tabs-mode: nil; coding: utf-8-unix -*-
002 *
003 * Copyright © 2018 microBean.
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
014 * implied.  See the License for the specific language governing
015 * permissions and limitations under the License.
016 */
017package org.microbean.jpa.weld;
018
019import java.io.IOException;
020
021import java.net.URL;
022
023import java.util.ArrayList;
024import java.util.Collection;
025import java.util.Collections;
026import java.util.Enumeration;
027import java.util.HashMap;
028import java.util.HashSet;
029import java.util.Map;
030import java.util.Objects;
031import java.util.Set;
032
033import javax.enterprise.event.Observes;
034
035import javax.enterprise.inject.literal.NamedLiteral;
036
037import javax.enterprise.inject.spi.AfterBeanDiscovery;
038import javax.enterprise.inject.spi.AnnotatedType;
039import javax.enterprise.inject.spi.Bean;
040import javax.enterprise.inject.spi.BeanManager;
041import javax.enterprise.inject.spi.ProcessAnnotatedType;
042import javax.enterprise.inject.spi.WithAnnotations;
043
044import javax.inject.Singleton;
045
046import javax.persistence.Converter;
047import javax.persistence.Embeddable;
048import javax.persistence.Entity;
049import javax.persistence.MappedSuperclass;
050import javax.persistence.PersistenceUnit;
051
052import javax.persistence.spi.PersistenceProvider;
053import javax.persistence.spi.PersistenceProviderResolver;
054import javax.persistence.spi.PersistenceProviderResolverHolder;
055
056import javax.persistence.spi.PersistenceUnitInfo;
057
058import javax.sql.DataSource;
059
060import javax.xml.bind.JAXBContext;
061import javax.xml.bind.JAXBException;
062import javax.xml.bind.Unmarshaller;
063
064import org.microbean.jpa.jaxb.Persistence;
065
066public class Extension implements javax.enterprise.inject.spi.Extension {
067
068  private final Map<String, Set<Class<?>>> entityClassesByPersistenceUnitNames;
069  
070  public Extension() {
071    super();
072    this.entityClassesByPersistenceUnitNames = new HashMap<>();
073  }
074
075  private final void discoverManagedClasses(@Observes @WithAnnotations({ Converter.class, Entity.class, Embeddable.class, MappedSuperclass.class }) final ProcessAnnotatedType<?> event) {
076    if (event != null) {
077      final AnnotatedType<?> annotatedType = event.getAnnotatedType();
078      if (annotatedType != null) {
079        final Class<?> entityClass = annotatedType.getJavaClass();
080        assert entityClass != null;
081        final Set<PersistenceUnit> persistenceUnits = annotatedType.getAnnotations(PersistenceUnit.class);
082        if (persistenceUnits == null || persistenceUnits.isEmpty()) {
083          Set<Class<?>> entityClasses = this.entityClassesByPersistenceUnitNames.get("");
084          if (entityClasses == null) {
085            entityClasses = new HashSet<>();
086            this.entityClassesByPersistenceUnitNames.put("", entityClasses);
087          }
088          entityClasses.add(entityClass);
089        } else {
090          for (final PersistenceUnit persistenceUnit : persistenceUnits) {
091            String name = "";
092            if (persistenceUnit != null) {
093              name = persistenceUnit.unitName();
094              assert name != null;
095            }
096            Set<Class<?>> entityClasses = this.entityClassesByPersistenceUnitNames.get(name);
097            if (entityClasses == null) {
098              entityClasses = new HashSet<>();
099              this.entityClassesByPersistenceUnitNames.put(name, entityClasses);
100            }
101            entityClasses.add(entityClass);
102          }
103        }
104        event.veto(); // entities can't be beans
105      }
106    }
107  }
108
109  private final void afterBeanDiscovery(@Observes final AfterBeanDiscovery event, final BeanManager beanManager)
110    throws IOException, JAXBException, ReflectiveOperationException {
111    if (event != null && beanManager != null) {
112
113      // Add a bean for PersistenceProviderResolver.
114      final PersistenceProviderResolver resolver =
115        PersistenceProviderResolverHolder.getPersistenceProviderResolver();
116      event.addBean()
117        .types(PersistenceProviderResolver.class)
118        .scope(Singleton.class)
119        .createWith(cc -> resolver);
120
121      // Add a bean for each "generic" PersistenceProvider reachable
122      // from the resolver.  (Any PersistenceUnitInfo may also specify
123      // the class name of a PersistenceProvider whose class may not
124      // be among those loaded by the resolver; we deal with those
125      // later.)
126      final Collection<? extends PersistenceProvider> providers = resolver.getPersistenceProviders();
127      for (final PersistenceProvider provider : providers) {
128        event.addBean()
129          .addTransitiveTypeClosure(provider.getClass())
130          .scope(Singleton.class)
131          .createWith(cc -> provider);
132      }
133
134      // Discover all META-INF/persistence.xml resources, load them
135      // using JAXB, and turn them into PersistenceUnitInfo instances,
136      // and add beans for all of them.
137      final Enumeration<URL> urls =
138        Thread.currentThread().getContextClassLoader().getResources("META-INF/persistence.xml");
139      if (urls != null && urls.hasMoreElements()) {
140        final Unmarshaller unmarshaller =
141          JAXBContext.newInstance("org.microbean.jpa.jaxb").createUnmarshaller();
142        assert unmarshaller != null;
143        while (urls.hasMoreElements()) {
144          final URL url = urls.nextElement();
145          final Collection<? extends PersistenceUnitInfo> persistenceUnitInfos =
146            PersistenceUnitInfoBean.fromPersistence((Persistence)unmarshaller.unmarshal(url),
147                                                    new URL(url, "../.."),
148                                                    this.entityClassesByPersistenceUnitNames,
149                                                    jtaDataSourceName -> this.getJtaDataSource(jtaDataSourceName, beanManager),
150                                                    nonJtaDataSourceName -> this.getNonJtaDataSource(nonJtaDataSourceName, beanManager));
151          for (final PersistenceUnitInfo persistenceUnitInfo : persistenceUnitInfos) {
152            assert persistenceUnitInfo != null;
153
154            String persistenceUnitName = persistenceUnitInfo.getPersistenceUnitName();
155            if (persistenceUnitName == null) {
156              persistenceUnitName = "";
157            }
158
159            event.addBean()
160              .types(Collections.singleton(PersistenceUnitInfo.class))
161              .scope(Singleton.class)
162              .addQualifiers(NamedLiteral.of(persistenceUnitName))
163              .createWith(cc -> persistenceUnitInfo);
164
165            final String providerClassName = persistenceUnitInfo.getPersistenceProviderClassName();
166            if (providerClassName != null) {
167              @SuppressWarnings("unchecked")
168              final Class<? extends PersistenceProvider> c = (Class<? extends PersistenceProvider>)Class.forName(providerClassName, true, Thread.currentThread().getContextClassLoader());
169              assert c != null;
170              boolean add = true;
171              for (final PersistenceProvider provider : providers) {
172                if (c.equals(provider.getClass())) {
173                  add = false;
174                  break;
175                }
176              }
177              if (add) {
178                // The PersistenceProvider class in question is not
179                // one we already loaded.  Try to add a bean for it
180                // too.
181                final PersistenceProvider provider = c.newInstance();
182                event.addBean()
183                  .addTransitiveTypeClosure(provider.getClass())
184                  .scope(Singleton.class)
185                  .createWith(cc -> provider);
186              }
187            }
188            
189          }
190        }
191      }
192    }
193  }
194
195  private final DataSource getJtaDataSource(final String dataSourceName, final BeanManager beanManager) {
196    Objects.requireNonNull(dataSourceName);
197    Objects.requireNonNull(beanManager);
198    final Bean<?> bean = beanManager.resolve(beanManager.getBeans(DataSource.class, NamedLiteral.of(dataSourceName)));
199    DataSource returnValue = null;
200    if (bean != null) {
201      returnValue = (DataSource)beanManager.getReference(bean, DataSource.class, beanManager.createCreationalContext(bean));
202    }
203    return returnValue;
204  }
205
206  private final DataSource getNonJtaDataSource(final String dataSourceName, final BeanManager beanManager) {
207    Objects.requireNonNull(dataSourceName);
208    Objects.requireNonNull(beanManager);
209    final Bean<?> bean = beanManager.resolve(beanManager.getBeans(DataSource.class, NamedLiteral.of(dataSourceName)));
210    DataSource returnValue = null;
211    if (bean != null) {
212      returnValue = (DataSource)beanManager.getReference(bean, DataSource.class, beanManager.createCreationalContext(bean));
213    }
214    return returnValue;
215  }
216
217}