/* -*- mode: Java; c-basic-offset: 2; indent-tabs-mode: nil; coding: utf-8-unix -*-
 *
 * Copyright © 2018 microBean.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.  See the License for the specific language governing
 * permissions and limitations under the License.
 */
package org.microbean.jpa.weld;

import java.io.IOException;

import java.net.URL;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import javax.enterprise.event.Observes;

import javax.enterprise.inject.literal.NamedLiteral;

import javax.enterprise.inject.spi.AfterBeanDiscovery;
import javax.enterprise.inject.spi.AnnotatedType;
import javax.enterprise.inject.spi.Bean;
import javax.enterprise.inject.spi.BeanManager;
import javax.enterprise.inject.spi.ProcessAnnotatedType;
import javax.enterprise.inject.spi.WithAnnotations;

import javax.inject.Singleton;

import javax.persistence.Converter;
import javax.persistence.Embeddable;
import javax.persistence.Entity;
import javax.persistence.MappedSuperclass;
import javax.persistence.PersistenceUnit;

import javax.persistence.spi.PersistenceProvider;
import javax.persistence.spi.PersistenceProviderResolver;
import javax.persistence.spi.PersistenceProviderResolverHolder;
import javax.persistence.spi.PersistenceUnitInfo;

import javax.sql.DataSource;

import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Unmarshaller;

import org.microbean.jpa.jaxb.Persistence;

/**
 * A CDI portable extension that collects JPA managed classes during
 * type discovery and, after bean discovery, registers beans for the
 * {@link PersistenceProviderResolver}, for {@link
 * PersistenceProvider}s, and for the {@link PersistenceUnitInfo}s
 * built from every {@code META-INF/persistence.xml} resource visible
 * to the class loader in use.
 */
public class Extension implements javax.enterprise.inject.spi.Extension {

  /**
   * Managed classes discovered so far, indexed by persistence unit
   * name.  Classes that carry no {@link PersistenceUnit} annotation
   * are stored under the empty string.
   */
  private final Map<String, Set<Class<?>>> entityClassesByPersistenceUnitNames;

  /**
   * Creates a new {@link Extension} with no discovered managed
   * classes.
   */
  public Extension() {
    super();
    this.entityClassesByPersistenceUnitNames = new HashMap<>();
  }

  /**
   * Records every type annotated with {@link Converter}, {@link
   * Entity}, {@link Embeddable} or {@link MappedSuperclass} under the
   * persistence unit name(s) named by its {@link PersistenceUnit}
   * annotations (or under the empty string when it has none), and
   * then vetoes the type.
   *
   * @param event the container lifecycle event being observed; if
   * {@code null} no action is taken
   */
  private final void discoverManagedClasses(@Observes
                                            @WithAnnotations({
                                              Converter.class,
                                              Entity.class,
                                              Embeddable.class,
                                              MappedSuperclass.class
                                            })
                                            final ProcessAnnotatedType<?> event) {
    if (event != null) {
      final AnnotatedType<?> annotatedType = event.getAnnotatedType();
      if (annotatedType != null) {
        final Class<?> entityClass = annotatedType.getJavaClass();
        assert entityClass != null;
        final Set<PersistenceUnit> persistenceUnits = annotatedType.getAnnotations(PersistenceUnit.class);
        if (persistenceUnits == null || persistenceUnits.isEmpty()) {
          this.addManagedClass("", entityClass);
        } else {
          for (final PersistenceUnit persistenceUnit : persistenceUnits) {
            String name = "";
            if (persistenceUnit != null) {
              name = persistenceUnit.unitName();
              assert name != null;
            }
            this.addManagedClass(name, entityClass);
          }
        }
        event.veto(); // entities can't be beans
      }
    }
  }

  /**
   * Adds the supplied class to the set of managed classes stored
   * under the supplied persistence unit name, creating that set if
   * it does not exist yet.
   *
   * @param persistenceUnitName the key to store the class under
   *
   * @param managedClass the class to store
   */
  private final void addManagedClass(final String persistenceUnitName, final Class<?> managedClass) {
    this.entityClassesByPersistenceUnitNames
      .computeIfAbsent(persistenceUnitName, ignored -> new HashSet<>())
      .add(managedClass);
  }

  /**
   * Registers {@link Singleton}-scoped beans for the current {@link
   * PersistenceProviderResolver}, for each {@link
   * PersistenceProvider} it reports, for each {@link
   * PersistenceUnitInfo} derived from a {@code
   * META-INF/persistence.xml} resource, and for each {@link
   * PersistenceProvider} class named by such a {@link
   * PersistenceUnitInfo} that has not been registered already.
   *
   * @param event the container lifecycle event being observed; if
   * {@code null} no action is taken
   *
   * @param beanManager the {@link BeanManager} used to look up
   * {@link DataSource} beans; if {@code null} no action is taken
   *
   * @exception IOException if the {@code META-INF/persistence.xml}
   * resources could not be enumerated or a root {@link URL} could not
   * be formed
   *
   * @exception JAXBException if a {@code META-INF/persistence.xml}
   * resource could not be unmarshalled
   *
   * @exception ReflectiveOperationException if a {@link
   * PersistenceProvider} class named by a persistence unit could not
   * be loaded or instantiated
   */
  private final void afterBeanDiscovery(@Observes final AfterBeanDiscovery event, final BeanManager beanManager)
    throws IOException, JAXBException, ReflectiveOperationException {
    if (event != null && beanManager != null) {

      // The context class loader may legitimately be null; fall back
      // to the loader of this class instead of failing with a
      // NullPointerException.
      ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
      if (contextClassLoader == null) {
        contextClassLoader = this.getClass().getClassLoader();
      }
      final ClassLoader classLoader = contextClassLoader;

      // Add a bean for PersistenceProviderResolver.
      final PersistenceProviderResolver resolver =
        PersistenceProviderResolverHolder.getPersistenceProviderResolver();
      event.addBean()
        .types(PersistenceProviderResolver.class)
        .scope(Singleton.class)
        .createWith(cc -> resolver);

      // Add a bean for each "generic" PersistenceProvider reachable
      // from the resolver.  (Any PersistenceUnitInfo may also specify
      // the class name of a PersistenceProvider whose class may not
      // be among those loaded by the resolver; we deal with those
      // later.)  Remember every provider class we add a bean for so
      // that no class is instantiated and registered again below.
      final Collection<? extends PersistenceProvider> providers = resolver.getPersistenceProviders();
      final Set<Class<?>> registeredProviderClasses = new HashSet<>();
      for (final PersistenceProvider provider : providers) {
        registeredProviderClasses.add(provider.getClass());
        addPersistenceProviderBean(event, provider);
      }

      // Discover all META-INF/persistence.xml resources, load them
      // using JAXB, and turn them into PersistenceUnitInfo instances,
      // and add beans for all of them.
      final Enumeration<URL> urls = classLoader.getResources("META-INF/persistence.xml");
      if (urls != null && urls.hasMoreElements()) {
        final Unmarshaller unmarshaller =
          JAXBContext.newInstance("org.microbean.jpa.jaxb").createUnmarshaller();
        assert unmarshaller != null;
        while (urls.hasMoreElements()) {
          final URL url = urls.nextElement();
          // NOTE(review): "../.." is presumably meant to yield the
          // persistence unit root for this persistence.xml; confirm
          // that it resolves as intended for both file: and jar:
          // URLs.
          final Collection<? extends PersistenceUnitInfo> persistenceUnitInfos =
            PersistenceUnitInfoBean.fromPersistence((Persistence)unmarshaller.unmarshal(url),
                                                    new URL(url, "../.."),
                                                    this.entityClassesByPersistenceUnitNames,
                                                    jtaDataSourceName -> this.getJtaDataSource(jtaDataSourceName, beanManager),
                                                    nonJtaDataSourceName -> this.getNonJtaDataSource(nonJtaDataSourceName, beanManager));
          for (final PersistenceUnitInfo persistenceUnitInfo : persistenceUnitInfos) {
            assert persistenceUnitInfo != null;

            String persistenceUnitName = persistenceUnitInfo.getPersistenceUnitName();
            if (persistenceUnitName == null) {
              persistenceUnitName = "";
            }

            event.addBean()
              .types(Collections.singleton(PersistenceUnitInfo.class))
              .scope(Singleton.class)
              .addQualifiers(NamedLiteral.of(persistenceUnitName))
              .createWith(cc -> persistenceUnitInfo);

            final String providerClassName = persistenceUnitInfo.getPersistenceProviderClassName();
            if (providerClassName != null) {
              // asSubclass() rejects a class that is not a
              // PersistenceProvider before anything is instantiated.
              final Class<? extends PersistenceProvider> c =
                Class.forName(providerClassName, true, classLoader).asSubclass(PersistenceProvider.class);
              assert c != null;
              if (registeredProviderClasses.add(c)) {
                // The PersistenceProvider class in question is not
                // one we already added a bean for, either from the
                // resolver or from an earlier persistence unit.  Try
                // to add a bean for it too.
                addPersistenceProviderBean(event, c.getDeclaredConstructor().newInstance());
              }
            }

          }
        }
      }
    }
  }

  /**
   * Adds a {@link Singleton}-scoped bean whose types are the
   * transitive type closure of the supplied provider's class and
   * whose sole instance is the supplied provider.
   *
   * @param event the event to add the bean to; must not be {@code
   * null}
   *
   * @param provider the {@link PersistenceProvider} to expose; must
   * not be {@code null}
   */
  private static void addPersistenceProviderBean(final AfterBeanDiscovery event, final PersistenceProvider provider) {
    event.addBean()
      .addTransitiveTypeClosure(provider.getClass())
      .scope(Singleton.class)
      .createWith(cc -> provider);
  }

  /**
   * Returns a contextual reference to the {@link DataSource} bean
   * qualified with the supplied name, for use as a JTA data source.
   *
   * @param dataSourceName the name the bean is qualified with; must
   * not be {@code null}
   *
   * @param beanManager the {@link BeanManager} to look the bean up
   * with; must not be {@code null}
   *
   * @return the {@link DataSource}, or {@code null} if no matching
   * bean was resolved
   *
   * @exception NullPointerException if either argument is {@code
   * null}
   */
  private final DataSource getJtaDataSource(final String dataSourceName, final BeanManager beanManager) {
    return getDataSource(dataSourceName, beanManager);
  }

  /**
   * Returns a contextual reference to the {@link DataSource} bean
   * qualified with the supplied name, for use as a non-JTA data
   * source.
   *
   * @param dataSourceName the name the bean is qualified with; must
   * not be {@code null}
   *
   * @param beanManager the {@link BeanManager} to look the bean up
   * with; must not be {@code null}
   *
   * @return the {@link DataSource}, or {@code null} if no matching
   * bean was resolved
   *
   * @exception NullPointerException if either argument is {@code
   * null}
   */
  private final DataSource getNonJtaDataSource(final String dataSourceName, final BeanManager beanManager) {
    return getDataSource(dataSourceName, beanManager);
  }

  /**
   * Resolves the {@link DataSource} bean qualified with the supplied
   * name and returns a contextual reference to it.
   *
   * @param dataSourceName the name the bean is qualified with; must
   * not be {@code null}
   *
   * @param beanManager the {@link BeanManager} to look the bean up
   * with; must not be {@code null}
   *
   * @return the {@link DataSource}, or {@code null} if no matching
   * bean was resolved
   *
   * @exception NullPointerException if either argument is {@code
   * null}
   */
  private static DataSource getDataSource(final String dataSourceName, final BeanManager beanManager) {
    Objects.requireNonNull(dataSourceName, "dataSourceName");
    Objects.requireNonNull(beanManager, "beanManager");
    final Bean<?> bean = beanManager.resolve(beanManager.getBeans(DataSource.class, NamedLiteral.of(dataSourceName)));
    DataSource returnValue = null;
    if (bean != null) {
      returnValue = (DataSource)beanManager.getReference(bean, DataSource.class, beanManager.createCreationalContext(bean));
    }
    return returnValue;
  }

}