001    /**
002     * Copyright 2010-2013 The Kuali Foundation
003     *
004     * Licensed under the Educational Community License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.opensource.org/licenses/ecl2.php
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    package org.kuali.common.jdbc.service;
017    
018    import java.io.IOException;
019    import java.sql.Connection;
020    import java.sql.DatabaseMetaData;
021    import java.sql.SQLException;
022    import java.sql.Statement;
023    import java.util.ArrayList;
024    import java.util.Arrays;
025    import java.util.Collections;
026    import java.util.List;
027    
028    import javax.sql.DataSource;
029    
030    import org.apache.commons.lang3.StringUtils;
031    import org.kuali.common.jdbc.listeners.NotifyingListener;
032    import org.kuali.common.jdbc.listeners.SqlListener;
033    import org.kuali.common.jdbc.listeners.ThreadSafeListener;
034    import org.kuali.common.jdbc.model.ExecutionResult;
035    import org.kuali.common.jdbc.model.ExecutionStats;
036    import org.kuali.common.jdbc.model.SqlBucket;
037    import org.kuali.common.jdbc.model.context.JdbcContext;
038    import org.kuali.common.jdbc.model.context.SqlBucketContext;
039    import org.kuali.common.jdbc.model.enums.CommitMode;
040    import org.kuali.common.jdbc.model.event.SqlEvent;
041    import org.kuali.common.jdbc.model.event.SqlExecutionEvent;
042    import org.kuali.common.jdbc.model.meta.Driver;
043    import org.kuali.common.jdbc.model.meta.JdbcMetaData;
044    import org.kuali.common.jdbc.model.meta.Product;
045    import org.kuali.common.jdbc.sql.model.SqlMetaData;
046    import org.kuali.common.jdbc.suppliers.SimpleStringSupplier;
047    import org.kuali.common.jdbc.suppliers.SqlSupplier;
048    import org.kuali.common.threads.ExecutionStatistics;
049    import org.kuali.common.threads.ThreadHandlerContext;
050    import org.kuali.common.threads.ThreadInvoker;
051    import org.kuali.common.util.Assert;
052    import org.kuali.common.util.CollectionUtils;
053    import org.kuali.common.util.FormatUtils;
054    import org.kuali.common.util.Str;
055    import org.kuali.common.util.inform.PercentCompleteInformer;
056    import org.kuali.common.util.nullify.NullUtils;
057    import org.slf4j.Logger;
058    import org.slf4j.LoggerFactory;
059    import org.springframework.jdbc.datasource.DataSourceUtils;
060    
061    public class DefaultJdbcService implements JdbcService {
062    
063            private static final Logger logger = LoggerFactory.getLogger(DefaultJdbcService.class);
064    
065            @Override
066            public ExecutionResult executeSql(JdbcContext context) {
067                    long start = System.currentTimeMillis();
068    
069                    // Log a message if provided
070                    if (context.getMessage().isPresent()) {
071                            logger.info(context.getMessage().get());
072                    }
073    
074                    // Make sure we have something to do
075                    if (CollectionUtils.isEmpty(context.getSuppliers())) {
076                            logger.info("Skipping execution.  No suppliers");
077                            return new ExecutionResult(0, start, System.currentTimeMillis(), 0);
078                    }
079    
080                    // Fire an event before executing any SQL
081                    long sqlStart = System.currentTimeMillis();
082                    context.getListener().beforeExecution(new SqlExecutionEvent(context, start, -1));
083    
084                    // Execute the SQL as dictated by the context
085                    ExecutionStats stats = null;
086                    if (context.isMultithreaded()) {
087                            stats = executeMultiThreaded(context);
088                    } else {
089                            stats = executeSequentially(context);
090                    }
091    
092                    // Fire an event now that all SQL execution is complete
093                    context.getListener().afterExecution(new SqlExecutionEvent(context, sqlStart, System.currentTimeMillis()));
094    
095                    return new ExecutionResult(stats.getUpdateCount(), start, System.currentTimeMillis(), stats.getStatementCount());
096            }
097    
098            protected ThreadsContext getThreadsContext(JdbcContext context) {
099                    // The tracking built into the kuali-threads package assumes "progress" equals one element from the list completing
100                    // It assumes you have a gigantic list where each element in the list = 1 unit of work
101                    // A large list of files that need to be posted to S3 (for example).
102                    // If we could randomly split up the SQL and execute it in whatever order we wanted, the built in tracking would work.
103                    // We cannot do that though, since the SQL in each file needs to execute sequentially in order
104                    // SQL from different files can execute concurrently, but the SQL inside each file needs to execute in order
105                    // For example OLE has ~250,000 SQL statements split up across ~300 files
106                    // In addition, the schema related DDL files need to execute first, then data, then constraints DDL files
107                    // Some files are HUGE, some are tiny. Printing a dot after each file completes doesn't make sense.
108                    // Our list of buckets is pretty small, even though the total number of SQL statements is quite large
109                    // Only printing a dot to the console when each bucket completes is not granular enough
110    
111                    // Calculate the total number of SQL statements across all of the suppliers
112                    long total = MetaDataUtils.getSqlCount(context.getSuppliers());
113    
114                    // Setup an informer based on the total number of SQL statements
115                    PercentCompleteInformer informer = new PercentCompleteInformer(total);
116    
117                    // Setup a thread safe listener that tracks SQL statements as they execute
118                    ThreadSafeListener threadSafeListener = new ThreadSafeListener(informer, context.isTrackProgressByUpdateCount());
119    
120                    // Add the thread safe listener to whatever listeners were already there
121                    List<SqlListener> listeners = new ArrayList<SqlListener>(Arrays.asList(context.getListener(), threadSafeListener));
122                    SqlListener listener = new NotifyingListener(listeners);
123                    return new ThreadsContext(informer, threadSafeListener, listener);
124            }
125    
126            protected ExecutionStats executeMultiThreaded(JdbcContext context) {
127    
128                    // Divide the SQL we have to execute up into buckets as "evenly" as possible
129                    List<SqlBucket> buckets = getSqlBuckets(context);
130    
131                    // Sort the buckets largest to smallest
132                    Collections.sort(buckets);
133                    Collections.reverse(buckets);
134    
135                    // This context includes a thread safe listener that prints a dot to the console each time 1% of the SQL gets completed
136                    ThreadsContext threadsContext = getThreadsContext(context);
137    
138                    // Provide some context for each bucket
139                    List<SqlBucketContext> sbcs = getSqlBucketContexts(buckets, context, threadsContext.getListener());
140    
141                    // Store some context for the thread handler
142                    ThreadHandlerContext<SqlBucketContext> thc = new ThreadHandlerContext<SqlBucketContext>();
143                    thc.setList(sbcs);
144                    thc.setHandler(new SqlBucketHandler());
145                    thc.setMax(buckets.size());
146                    thc.setMin(buckets.size());
147                    thc.setDivisor(1);
148    
149                    // Start threads to execute SQL from multiple suppliers concurrently
150                    ThreadInvoker invoker = new ThreadInvoker();
151                    threadsContext.getInformer().start();
152                    ExecutionStatistics stats = invoker.invokeThreads(thc);
153                    threadsContext.getInformer().stop();
154    
155                    ThreadSafeListener listener = threadsContext.getThreadSafeListener();
156                    logStats(listener, stats, buckets);
157                    return new ExecutionStats(listener.getAggregateUpdateCount(), listener.getAggregateSqlCount());
158            }
159    
160            protected void logStats(ThreadSafeListener listener, ExecutionStatistics stats, List<SqlBucket> buckets) {
161                    // Display thread related stats
162                    long aggregateTime = listener.getAggregateTime();
163                    long wallTime = stats.getExecutionTime();
164                    String avgMillis = FormatUtils.getTime(aggregateTime / buckets.size());
165                    String aTime = FormatUtils.getTime(aggregateTime);
166                    String wTime = FormatUtils.getTime(wallTime);
167                    String sqlCount = FormatUtils.getCount(listener.getAggregateSqlCount());
168                    String sqlSize = FormatUtils.getSize(listener.getAggregateSqlSize());
169                    Object[] args = { buckets.size(), wTime, aTime, avgMillis, sqlCount, sqlSize };
170                    logger.debug("Threads - [count: {}  time: {}  aggregate: {}  avg: {}  sql: {} - {}]", args);
171            }
172    
173            @Override
174            public ExecutionResult executeSql(DataSource dataSource, String sql) {
175                    return executeSql(dataSource, CollectionUtils.singletonList(sql));
176            }
177    
178            @Override
179            public ExecutionResult executeSql(DataSource dataSource, List<String> sql) {
180                    SqlSupplier supplier = new SimpleStringSupplier(sql);
181                    JdbcContext context = new JdbcContext.Builder(dataSource, supplier).build();
182                    return executeSql(context);
183            }
184    
185            protected List<SqlBucketContext> getSqlBucketContexts(List<SqlBucket> buckets, JdbcContext context, SqlListener listener) {
186                    List<SqlBucketContext> sbcs = new ArrayList<SqlBucketContext>();
187                    for (SqlBucket bucket : buckets) {
188                            JdbcContext newJdbcContext = getJdbcContext(context, bucket, listener);
189                            SqlBucketContext sbc = new SqlBucketContext(bucket, newJdbcContext, this);
190                            sbcs.add(sbc);
191                    }
192                    return sbcs;
193            }
194    
195            protected JdbcContext getJdbcContext(JdbcContext original, SqlBucket bucket, SqlListener listener) {
196                    boolean skip = original.isSkipSqlExecution();
197                    DataSource dataSource = original.getDataSource();
198                    List<SqlSupplier> suppliers = bucket.getSuppliers();
199                    CommitMode commitMode = original.getCommitMode();
200                    return new JdbcContext.Builder(dataSource, suppliers).listener(listener).commitMode(commitMode).skipSqlExecution(skip).build();
201            }
202    
203            protected List<SqlBucket> getSqlBuckets(JdbcContext context) {
204    
205                    // Pull out our list of suppliers
206                    List<SqlSupplier> suppliers = new ArrayList<SqlSupplier>(context.getSuppliers());
207    
208                    // number of buckets equals thread count, unless thread count > total number of sources
209                    int bucketCount = Math.min(context.getThreads(), suppliers.size());
210    
211                    // If bucket count is zero, we have issues
212                    Assert.isTrue(bucketCount > 0, "bucket count must be a positive integer");
213    
214                    // Sort the suppliers by SQL size
215                    Collections.sort(suppliers);
216    
217                    // Largest to smallest instead of smallest to largest
218                    Collections.reverse(suppliers);
219    
220                    // Allocate some buckets to hold the sql
221                    List<SqlBucket> buckets = CollectionUtils.getNewList(SqlBucket.class, bucketCount);
222    
223                    // Distribute the sources into buckets as evenly as possible
224                    // "Evenly" in this case means each bucket should be roughly the same size
225                    for (SqlSupplier supplier : suppliers) {
226    
227                            // Sort the buckets by size
228                            Collections.sort(buckets);
229    
230                            // First bucket in the list is the smallest
231                            SqlBucket smallest = buckets.get(0);
232    
233                            // Get a new bucket derived from the smallest bucket
234                            SqlBucket newBucket = getNewBucket(smallest, supplier);
235    
236                            // Remove the smallest bucket
237                            buckets.remove(0);
238    
239                            // Add the new bucket to the list
240                            buckets.add(newBucket);
241                    }
242    
243                    // Return the buckets
244                    return buckets;
245            }
246    
247            protected SqlBucket getNewBucket(SqlBucket bucket, SqlSupplier supplier) {
248                    List<SqlSupplier> list = new ArrayList<SqlSupplier>(bucket.getSuppliers());
249                    list.add(supplier);
250                    SqlMetaData smd = supplier.getMetaData();
251                    long count = bucket.getCount() + smd.getCount();
252                    long size = bucket.getSize() + smd.getSize();
253                    return new SqlBucket(count, size, list);
254            }
255    
256            protected ExecutionStats executeSequentially(JdbcContext context) {
257                    Connection conn = null;
258                    Statement statement = null;
259                    try {
260                            long updateCount = 0;
261                            long statementCount = 0;
262                            conn = DataSourceUtils.doGetConnection(context.getDataSource());
263                            boolean originalAutoCommitSetting = conn.getAutoCommit();
264                            conn.setAutoCommit(false);
265                            statement = conn.createStatement();
266                            List<SqlSupplier> suppliers = context.getSuppliers();
267                            for (SqlSupplier supplier : suppliers) {
268                                    ExecutionStats stats = excecuteSupplier(statement, context, supplier);
269                                    updateCount += stats.getUpdateCount();
270                                    statementCount += stats.getStatementCount();
271                                    conn.commit();
272                            }
273                            conn.setAutoCommit(originalAutoCommitSetting);
274                            return new ExecutionStats(updateCount, statementCount);
275                    } catch (Exception e) {
276                            throw new IllegalStateException(e);
277                    } finally {
278                            JdbcUtils.closeQuietly(context.getDataSource(), conn, statement);
279                    }
280            }
281    
282            protected ExecutionStats excecuteSupplier(Statement statement, JdbcContext context, SqlSupplier supplier) throws SQLException {
283                    try {
284                            long updateCount = 0;
285                            long statementCount = 0;
286                            supplier.open();
287                            List<String> sql = supplier.getSql();
288                            while (sql != null) {
289                                    for (String s : sql) {
290                                            updateCount += executeSql(statement, s, context);
291                                            statementCount++;
292                                    }
293                                    sql = supplier.getSql();
294                            }
295                            return new ExecutionStats(updateCount, statementCount);
296                    } catch (IOException e) {
297                            throw new IllegalStateException(e);
298                    } finally {
299                            supplier.close();
300                    }
301            }
302    
303            protected int executeSql(Statement statement, String sql, JdbcContext context) throws SQLException {
304                    try {
305                            int updateCount = 0;
306                            long start = System.currentTimeMillis();
307                            context.getListener().beforeExecuteSql(new SqlEvent(sql, start));
308                            if (!context.isSkipSqlExecution()) {
309    
310                                    // Execute the SQL
311                                    statement.execute(sql);
312    
313                                    // Get the number of rows updated as a result of executing this SQL
314                                    updateCount = statement.getUpdateCount();
315    
316                                    // Some SQL statements have nothing to do with updating rows
317                                    updateCount = (updateCount == -1) ? 0 : updateCount;
318                            }
319                            context.getListener().afterExecuteSql(new SqlEvent(sql, updateCount, start, System.currentTimeMillis()));
320                            return updateCount;
321                    } catch (SQLException e) {
322                            throw new SQLException("Error executing SQL [" + Str.flatten(sql) + "]", e);
323                    }
324            }
325    
326            @Override
327            public JdbcMetaData getJdbcMetaData(DataSource dataSource) {
328                    Connection conn = null;
329                    try {
330                            conn = DataSourceUtils.doGetConnection(dataSource);
331                            DatabaseMetaData dbmd = conn.getMetaData();
332                            return getJdbcMetaData(dbmd);
333                    } catch (Exception e) {
334                            throw new IllegalStateException(e);
335                    } finally {
336                            logger.trace("closing connection");
337                            JdbcUtils.closeQuietly(dataSource, conn);
338                    }
339            }
340    
341            protected JdbcMetaData getJdbcMetaData(DatabaseMetaData dbmd) throws SQLException {
342                    Product product = getProduct(dbmd);
343                    Driver driver = getDriver(dbmd);
344                    String url = dbmd.getURL();
345                    String username = dbmd.getUserName();
346                    if (username == null) {
347                            username = NullUtils.NULL;
348                    } else if (StringUtils.isBlank(username)) {
349                            username = NullUtils.NONE;
350                    }
351                    return new JdbcMetaData(product, driver, url, username);
352            }
353    
354            protected Product getProduct(DatabaseMetaData dbmd) throws SQLException {
355                    String name = dbmd.getDatabaseProductName();
356                    String version = dbmd.getDatabaseProductVersion();
357                    return new Product(name, version);
358            }
359    
360            protected Driver getDriver(DatabaseMetaData dbmd) throws SQLException {
361                    String name = dbmd.getDriverName();
362                    String version = dbmd.getDriverVersion();
363                    return new Driver(name, version);
364            }
365    }