/*
 * Copyright Openmind http://www.openmindonline.it
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package it.openutils.testing;

import java.io.IOException;
import java.net.URL;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Map;

import javax.sql.DataSource;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
import org.dbunit.DatabaseUnitException;
import org.dbunit.database.DatabaseConnection;
import org.dbunit.database.DatabaseSequenceFilter;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.FilteredDataSet;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.filter.ITableFilter;
import org.dbunit.dataset.filter.SequenceTableFilter;
import org.dbunit.operation.DatabaseOperation;
import org.dbunit.operation.DeleteAllOperation;
import org.dbunit.operation.DeleteOperation;
import org.dbunit.operation.InsertOperation;
import org.dbunit.operation.RefreshOperation;
import org.dbunit.operation.TruncateTableOperation;
import org.dbunit.operation.UpdateOperation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;


/**
 * Helper that prepares the database for a test case using DbUnit. The executions to run are read from the
 * {@link DbUnitExecution} or {@link DbUnitConfiguration} annotation found on the test instance's class, and the
 * datasources are looked up in the given Spring application context.
 * @author fgiust
 * @version $Id: DbUnitTestContext.java 581 2008-01-31 22:19:56Z fgiust $
 */
public class DbUnitTestContext
{

    /**
     * Truncate dataset cache. This is kept as a static attribute since the creation of the dataset is very expensive
     * and it doesn't change across tests. Entries are keyed by datasource name and schema.
     */
    // NOTE(review): plain HashMap shared by every instance; not synchronized - confirm tests never run in parallel.
    protected static Map<String, IDataSet> truncateDataSetCache = new HashMap<String, IDataSet>();

    /**
     * Logger.
     */
    private static final Logger log = LoggerFactory.getLogger(DbUnitTestContext.class);

    /**
     * The test instance.
     */
    private final Object testcase;

    /**
     * Spring application context.
     */
    private final ApplicationContext applicationContext;

    /**
     * Dataset cache, keyed by dataset file name.
     */
    private final Map<String, IDataSet> datasetCache = new HashMap<String, IDataSet>();

    /**
     * Instantiates a new DbUnitTestContext
     * @param testcase test instance
     * @param applicationContext Spring application context
     */
    public DbUnitTestContext(Object testcase, ApplicationContext applicationContext)
    {
        this.testcase = testcase;
        this.applicationContext = applicationContext;
    }

    /**
     * Setup the Database before running the test method. A {@link DbUnitExecution} annotation on the test class is
     * used when present; otherwise the executions listed by a {@link DbUnitConfiguration} annotation are used. When
     * neither annotation is present nothing is done.
     * @throws Exception Any exception.
     */
    @SuppressWarnings("unchecked")
    public void setUpDbUnit() throws Exception
    {
        Class< ? > testClass = testcase.getClass();

        DbUnitExecution[] executions = null;
        DbUnitExecution singleDbUnitExecution = testClass.getAnnotation(DbUnitExecution.class);
        if (singleDbUnitExecution != null)
        {
            executions = new DbUnitExecution[]{singleDbUnitExecution };
        }
        else
        {
            DbUnitConfiguration dbUnitConfiguration = testClass.getAnnotation(DbUnitConfiguration.class);
            if (dbUnitConfiguration != null)
            {
                executions = dbUnitConfiguration.dbUnitExecutions();
            }
        }

        if (executions == null)
        {
            // no dbunit annotation on the test class
            return;
        }

        for (DbUnitExecution dbUnitExecution : executions)
        {
            String[] datasets = dbUnitExecution.datasets();
            String dataSourceName = dbUnitExecution.dataSource();

            // an empty schema attribute is passed to dbunit as null
            String schema = dbUnitExecution.schema();
            if (StringUtils.isEmpty(schema))
            {
                schema = null;
            }

            IDatabaseConnection connection = new DatabaseConnection(
                getDatasource(dataSourceName).getConnection(),
                schema);

            try
            {
                ITableFilter tableFilter = new RegExpTableFilter(dbUnitExecution.excludedTables());

                if (dbUnitExecution.truncateAll())
                {
                    DatabaseOperation truncateOperation = getDatabaseOperation(dbUnitExecution.truncateOperation());
                    truncateAll(connection, tableFilter, dataSourceName, truncateOperation);
                }

                DatabaseOperation dbOperation = getDatabaseOperation(dbUnitExecution.insertOperation());
                for (String datasetFile : datasets)
                {
                    importDataSet(createDataset(datasetFile), connection, tableFilter, dataSourceName, dbOperation);
                }
            }
            finally
            {
                // always release the connection, also when an operation fails
                connection.close();
            }
        }
    }

    /**
     * Instantiates the given Database Operation. Standard operations in the <code>org.dbunit.operation</code> package
     * have protected constructors and need to be accessed using the public static fields in
     * <code>org.dbunit.operation.DatabaseOperation</code>. Any other class is instantiated using its no-arg
     * constructor.
     * @param dboperationClass db operation class
     * @return db operation instance
     * @throws InstantiationException if the given db operation class cannot be instantiated
     * @throws IllegalAccessException if the given db operation class cannot be instantiated
     */
    private DatabaseOperation getDatabaseOperation(Class< ? extends DatabaseOperation> dboperationClass)
        throws InstantiationException, IllegalAccessException
    {
        if (UpdateOperation.class.equals(dboperationClass))
        {
            return DatabaseOperation.UPDATE;
        }
        if (InsertOperation.class.equals(dboperationClass))
        {
            return DatabaseOperation.INSERT;
        }
        if (RefreshOperation.class.equals(dboperationClass))
        {
            return DatabaseOperation.REFRESH;
        }
        if (DeleteOperation.class.equals(dboperationClass))
        {
            return DatabaseOperation.DELETE;
        }
        if (DeleteAllOperation.class.equals(dboperationClass))
        {
            return DatabaseOperation.DELETE_ALL;
        }
        if (TruncateTableOperation.class.equals(dboperationClass))
        {
            return DatabaseOperation.TRUNCATE_TABLE;
        }

        // custom operation: requires an accessible no-arg constructor
        return dboperationClass.newInstance();
    }

    /**
     * Creates a dataset instance by fetching a file from the classpath. Loaded datasets are kept in the per-instance
     * dataset cache, so each file is loaded only once for this context.
     * @param datasetFile name of the file, will be loaded from the classpath
     * @return IDataSet instance
     * @throws IOException propagated from the dataset loading
     * @throws DataSetException propagated from the dataset loading
     * @throws IllegalArgumentException if the file cannot be found in the classpath
     */
    private IDataSet createDataset(String datasetFile) throws IOException, DataSetException
    {
        IDataSet dataSet = datasetCache.get(datasetFile);
        if (dataSet != null)
        {
            return dataSet;
        }

        // NOTE(review): the resource is resolved against this class, not against the test case class, so relative
        // names are looked up in this class' package - confirm this is intended.
        URL datasetUrl = getClass().getResource(datasetFile);
        if (datasetUrl == null)
        {
            throw new IllegalArgumentException("Dataset " + datasetFile + " not found");
        }

        dataSet = DbUnitUtils.loadDataset(datasetUrl);
        datasetCache.put(datasetFile, dataSet);
        return dataSet;
    }

    /**
     * Executes the given operation on a dataset. When a sorted dataset has already been cached for the datasource and
     * schema, its table sequence is used to order the tables of the dataset; otherwise the dataset is only filtered
     * using the given table filter.
     * @param dataSet dataset to import, must not be null
     * @param connection dbunit connection
     * @param tableFilter filter applied to the dataset tables
     * @param dataSourceName datasource name, used for looking up the cached sorted dataset
     * @param databaseOperation operation to execute
     * @throws SQLException propagated from dbunit
     * @throws DataSetException propagated from dbunit
     * @throws DatabaseUnitException propagated from dbunit
     * @throws IllegalArgumentException if dataSet is null
     */
    private void importDataSet(IDataSet dataSet, IDatabaseConnection connection, ITableFilter tableFilter,
        String dataSourceName, DatabaseOperation databaseOperation) throws SQLException, DataSetException,
        DatabaseUnitException
    {

        if (dataSet == null)
        {
            throw new IllegalArgumentException("dataSet is null");
        }

        IDataSet orderedDataset = new FilteredDataSet(tableFilter, dataSet);
        if (log.isDebugEnabled())
        {
            log.debug("Tables: {}", ArrayUtils.toString(orderedDataset.getTableNames()));
        }

        // if a sorted dataset is available, use table sequence for sorting
        // NOTE(review): in this case the sequence filter replaces tableFilter instead of being combined with it;
        // this presumably relies on the cached dataset having been built with the same exclusions - confirm.
        IDataSet truncateDataSet = getTruncateDataset(dataSourceName, connection.getSchema());
        if (truncateDataSet != null)
        {
            ITableFilter filter = new SequenceTableFilter(truncateDataSet.getTableNames());
            orderedDataset = new FilteredDataSet(filter, dataSet);
        }

        databaseOperation.execute(connection, orderedDataset);
    }

    /**
     * Executes the given operation on all the tables accepted by the given filter. The sorted dataset listing those
     * tables is generated from the connection on first use and then stored in the static truncate dataset cache.
     * @param connection dbunit connection
     * @param tableFilter filter used for excluding unwanted tables
     * @param dataSourceName datasource name, used as part of the cache key
     * @param databaseOperation operation to execute on the sorted dataset
     * @throws SQLException propagated from dbunit
     * @throws DataSetException propagated from dbunit
     * @throws DatabaseUnitException propagated from dbunit
     */
    private void truncateAll(IDatabaseConnection connection, ITableFilter tableFilter, String dataSourceName,
        DatabaseOperation databaseOperation) throws SQLException, DataSetException, DatabaseUnitException
    {
        IDataSet truncateDataSet = getTruncateDataset(dataSourceName, connection.getSchema());
        if (truncateDataSet == null)
        {
            log.debug("Generating sorted dataset for initial cleanup");
            IDataSet unsortedDataSet = connection.createDataSet();

            if (log.isDebugEnabled())
            {
                log.debug("Unfiltered truncateDataSet: {}", ArrayUtils.toString(unsortedDataSet.getTableNames()));
            }

            // excluded unwanted tables
            unsortedDataSet = new FilteredDataSet(tableFilter, unsortedDataSet);

            // sort tables
            ITableFilter sortingFilter = new DatabaseSequenceFilter(connection, unsortedDataSet.getTableNames());
            truncateDataSet = new FilteredDataSet(sortingFilter, unsortedDataSet);

            storeTruncateDataset(dataSourceName, connection.getSchema(), truncateDataSet);
            log.debug("Sorted dataset generated");
        }

        if (log.isDebugEnabled())
        {
            log.debug("Tables truncateDataSet: {}", ArrayUtils.toString(truncateDataSet.getTableNames()));
        }
        databaseOperation.execute(connection, truncateDataSet);
    }

    /**
     * Looks up the cached sorted dataset for the given datasource and schema.
     * @param datasourceName datasource name
     * @param schema schema name, null is handled as an empty string
     * @return cached dataset, or null if none has been stored yet
     */
    private IDataSet getTruncateDataset(String datasourceName, String schema)
    {
        return truncateDataSetCache.get(datasourceName + "_" + StringUtils.defaultString(schema));
    }

    /**
     * Stores a sorted dataset in the static cache for the given datasource and schema.
     * @param datasourceName datasource name
     * @param schema schema name, null is handled as an empty string
     * @param dataset dataset to store
     * @return the dataset previously stored with the same key, or null
     */
    private IDataSet storeTruncateDataset(String datasourceName, String schema, IDataSet dataset)
    {
        return truncateDataSetCache.put(datasourceName + "_" + StringUtils.defaultString(schema), dataset);
    }

    /**
     * Loads a named datasource from the spring context. When no name is given the context must contain exactly one
     * bean of type DataSource, which is returned.
     * @param name datasource name.
     * @return Datasource instance
     * @throws RuntimeException if no name is given and the context doesn't contain exactly one DataSource
     */
    @SuppressWarnings("unchecked")
    protected DataSource getDatasource(String name)
    {
        if (StringUtils.isEmpty(name))
        {
            Map<String, DataSource> dsMap = applicationContext.getBeansOfType(DataSource.class);
            int found = dsMap == null ? 0 : dsMap.size();
            if (found != 1)
            {
                // both "none" and "more than one" end up here, so report how many were found
                throw new RuntimeException(
                    "Unable to find a single datasource in spring applicationContext (found "
                        + found
                        + "), please specify the datasource bean name "
                        + "using the \"dataSource\" attribute of @DbUnitExecution");
            }
            return dsMap.values().iterator().next();
        }
        return (DataSource) applicationContext.getBean(name);
    }

}
