Changeset 3877


Ignore:
Timestamp:
Oct 25, 2007, 9:57:20 AM (14 years ago)
Author:
Nicklas Nordborg
Message:

References #797: Enhance performance for LOWESS and Medin-ratio plug-ins

LOWESS has now been optimized to use fewer database queries. Initial test results indicates at
least 80% reduced execution time, but I have only tested on a 1-4 bioassays.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/src/plugins/core/net/sf/basedb/plugins/LowessNormalization.java

    r3679 r3877  
    4242import net.sf.basedb.core.SpotBatcher;
    4343import net.sf.basedb.core.Transformation;
     44import net.sf.basedb.core.Type;
    4445import net.sf.basedb.core.VirtualColumn;
    4546import net.sf.basedb.core.plugin.About;
     
    5152import net.sf.basedb.core.plugin.Request;
    5253import net.sf.basedb.core.plugin.Response;
    53 import net.sf.basedb.core.query.Aggregations;
    5454import net.sf.basedb.core.query.Dynamic;
     55import net.sf.basedb.core.query.Expression;
    5556import net.sf.basedb.core.query.Expressions;
    56 import net.sf.basedb.core.query.JoinType;
    5757import net.sf.basedb.core.query.Orders;
    5858import net.sf.basedb.core.query.Restriction;
    5959import net.sf.basedb.core.query.Restrictions;
    60 import net.sf.basedb.core.query.Select;
    6160import net.sf.basedb.core.query.Selects;
    6261import net.sf.basedb.core.query.SqlResult;
    6362import net.sf.basedb.util.Values;
    6463
     64import java.sql.SQLException;
    6565import java.util.ArrayList;
    6666import java.util.Arrays;
     
    7373
    7474/**
    75    @author enell
     75   @author enell, Nicklas
    7676   @version 2.0
    7777   @base.modified $Date$
     
    103103    "method addressing single and multiple slide systematic " +
    104104    "variation. Nucleic Acids Res 2002, 30:e15.",
    105     "1.0",
     105    "2.5",
    106106    "2006, Base 2 development team",
    107107    null,
     
    193193    if (command.equals(Request.COMMAND_EXECUTE))
    194194    {
    195       SpotBatcher batcher = null;
    196195      DbControl dc = null;
    197196      try
     
    201200        String childName = Values.getString((String)job.getValue(CHILD_NAME), source.getName());
    202201        String childDescription = (String)job.getValue(CHILD_DESCRIPTION);
    203        
    204         // Create Transformation
    205         Transformation t = source.newTransformation(getCurrentJob(dc));
    206         t.setName(about.getName());
    207         dc.saveItem(t);
    208        
    209         // Create the normalized bioassay set
    210         BioAssaySet child = t.newProduct(null, "new", true);
    211         child.setName(childName);
    212         child.setDescription(childDescription);
    213         dc.saveItem(child);
    214        
    215         // Batcher for inserting normalized data
    216         batcher = child.getSpotBatcher();
    217 
    218202        int blockGroupSize = (Integer) job.getValue(blockGroupParameter.getName());
    219203        float delta = (Float) job.getValue(deltaParameter.getName());
    220204        float fitFraction = (Float) job.getValue(fitFractionParameter.getName());
    221         int iter = (Integer) job.getValue(iterParameter.getName());
    222 
    223         Select M = Selects.expression
    224         (
    225           Expressions.log2(Expressions.divide(Dynamic.column(VirtualColumn.channel(1)), Dynamic.column(VirtualColumn.channel(2)))),
    226           "m"
    227         );
    228         Select A = Selects.expression
    229         (
    230           Expressions.log10(Expressions.sqrt(Expressions.multiply(Dynamic.column(VirtualColumn.channel(1)), Dynamic.column(VirtualColumn.channel(2))))),
    231           "a"
    232         );
    233         Restriction intensityRestriction = Restrictions.and
    234         (
    235           Restrictions.gt(Dynamic.column(VirtualColumn.channel(1)), Expressions.aFloat(0)),
    236           Restrictions.gt(Dynamic.column(VirtualColumn.channel(2)), Expressions.aFloat(0))
    237         );
    238 
    239         DynamicSpotQuery query;
    240         DynamicResultIterator resultIter;
    241         query = source.getSpotData();
    242         query.restrict(intensityRestriction);
    243         long numSpots = query.count(dc);
    244         int normalizedSpots = 0;
    245         if (progress != null) progress.display((int)(normalizedSpots / numSpots * 100), normalizedSpots + " spots normalized");
    246        
    247         List<BioAssay> assays = source.getBioAssays().list(dc);
    248         for (BioAssay assay : assays)
    249         {
    250           query = assay.getSpotData();
    251           query.restrict(intensityRestriction);
    252           query.joinRawData(JoinType.LEFT);
    253           query.select(Selects.expression(Aggregations.max(Dynamic.rawData("block")), "max"));
    254           query.select(Selects.expression(Aggregations.min(Dynamic.rawData("block")), "min"));
    255           resultIter = query.iterate(dc);
    256           SqlResult result = resultIter.next();
    257           int maxBlock = result.getInt(resultIter.getIndex("max"));
    258           int minBlock = result.getInt(resultIter.getIndex("min"));
    259           resultIter.close();
    260 
    261           for (int i = minBlock; i <= maxBlock; i += blockGroupSize)
    262           {
    263             Restriction blockRestriction = Restrictions.between
    264             (
    265               Expressions.selected(Dynamic.selectRawData("block")),
    266               Expressions.integer(i),
    267               Expressions.integer(i+blockGroupSize)
    268             );
    269            
    270             query = assay.getSpotData();
    271             query.select(M);
    272             query.select(A);
    273             query.select(Dynamic.selectRawData("block"));
    274             query.joinRawData(JoinType.LEFT);
    275             query.restrictPermanent(intensityRestriction);
    276             query.restrictPermanent(blockRestriction);
    277             query.orderPermanent(Orders.asc(Expressions.selected(A)));
    278  
    279             int count = (int)query.count(dc);
    280             if (count <= 0) continue;
    281  
    282             List<Double> mValues = new ArrayList<Double>(count);
    283             List<Double> aValues = new ArrayList<Double>(count);
    284             DynamicResultIterator it = query.iterate(dc);
    285            
    286             int aIndex = it.getIndex(A.getAlias());
    287             int mIndex = it.getIndex(M.getAlias());
    288            
    289             while (it.hasNext())
    290             {
    291               SqlResult r = it.next();
    292               aValues.add((double) r.getFloat(aIndex));
    293               mValues.add((double) r.getFloat(mIndex));
    294             }
    295             it.close();
    296            
    297             List<Double> smoothCurve = lowess(aValues, mValues, fitFraction, iter, delta);
    298  
    299             query.reset();
    300             query.select(Dynamic.select(VirtualColumn.COLUMN));
    301             query.select(Dynamic.select(VirtualColumn.POSITION));
    302             query.select(Dynamic.select(VirtualColumn.channel(1)));
    303             query.select(Dynamic.select(VirtualColumn.channel(2)));
    304             query.select(A);
    305             query.joinRawData(JoinType.LEFT);
    306            
    307             it = query.iterate(dc);
    308             int column = it.getIndex(VirtualColumn.COLUMN.getName());
    309             int position = it.getIndex(VirtualColumn.POSITION.getName());
    310             int ch1Col = it.getIndex(VirtualColumn.channel(1).getName());
    311             int ch2Col = it.getIndex(VirtualColumn.channel(2).getName());
    312            
    313             for (int j = 0; j < smoothCurve.size() && it.hasNext(); j++)
    314             {
    315               SqlResult r = it.next();
    316  
    317               double factor = Math.exp(smoothCurve.get(j) * 0.5);
    318               double ch1 = r.getFloat(ch1Col)/factor;
    319               double ch2 = r.getFloat(ch2Col)*factor;
    320               batcher.insert(r.getShort(column), r.getInt(position), (float) ch1, (float) ch2);
    321             }
    322             it.close();
    323            
    324             normalizedSpots += smoothCurve.size();
    325             if (progress != null) progress.display((int)((normalizedSpots * 100) / numSpots), normalizedSpots + " spots normalized");
    326           }
    327         }
    328         batcher.close();
     205        int iterations = (Integer) job.getValue(iterParameter.getName());
     206        Job thisJob = getCurrentJob(dc);
     207       
     208        BioAssaySet child = normalize(dc, source, thisJob, fitFraction, delta, iterations, blockGroupSize, progress);
     209        child.setName(childName);
     210        child.setDescription(childDescription);
    329211        dc.commit();
     212        int normalizedSpots = child.getNumSpots();
    330213        if (progress != null) progress.display(100, normalizedSpots + " spots normalized\n");
    331214        response.setDone(normalizedSpots + " spots normalized, " + (source.getNumSpots() - normalizedSpots) + " spots removed");
     
    418301  // -------------------------------------------
    419302
     303  /**
     304    Normalise the source bioassay set using LOWESS normalization.
     305    @param dc The DbControl to use for database access
     306    @param source The source bioassay set that is going to be normalized
     307    @param job The job that is doing the normalization, or null
     308    @param fitFraction
     309    @param delta
     310    @param iterations
     311    @param blockGroupSize
     312    @return The normalized bioassayset
     313    @since 2.5
     314  */
     315  public BioAssaySet normalize(DbControl dc, BioAssaySet source, Job job, float fitFraction, float delta, int iterations, int blockGroupSize, ProgressReporter progress)
     316  {
     317    if (progress != null) progress.display(0, "Preparing to normalize...");
     318   
     319    // Create Transformation
     320    Transformation t = source.newTransformation(job);
     321    t.setName(about.getName());
     322    dc.saveItem(t);
     323   
     324    // Create the normalized bioassay set
     325    BioAssaySet child = t.newProduct(null, "new", true);
     326    dc.saveItem(child);
     327   
     328    // Batcher for inserting normalized data
     329    SpotBatcher batcher = child.getSpotBatcher();
     330
     331    // Expressions used to get data
     332    Expression block = Dynamic.rawData("block");
     333    Expression ch1 = Dynamic.column(VirtualColumn.channel(1));
     334    Expression ch2 = Dynamic.column(VirtualColumn.channel(2));
     335    // A = log10(sqrt(ch1 * ch2))
     336    Expression A = Expressions.log10(Expressions.sqrt(Expressions.multiply(ch1, ch2)));
     337
     338    // Create restriction: ch1 > 0 and ch2 > 0
     339    Restriction intensityRestriction = Restrictions.and(
     340        Restrictions.gt(ch1, Expressions.aFloat(0)),
     341        Restrictions.gt(ch2, Expressions.aFloat(0))
     342      );
     343
     344    // Create restriction: column = :bioAssayColumn
     345    Restriction bioAssayRestriction = Restrictions.eq(
     346        Dynamic.column(VirtualColumn.COLUMN),
     347        Expressions.parameter("bioAssayColumn")
     348      );
     349     
     350    // Count number of spots that is going to be normalized
     351    DynamicSpotQuery query = source.getSpotData();
     352    query.restrict(intensityRestriction);
     353    long numSpots = query.count(dc);
     354    int normalizedSpots = 0;
     355    if (progress != null) progress.display((int)(normalizedSpots / numSpots * 100), normalizedSpots + " spots normalized");
     356   
     357    // Create query to retrieve spot data: COLUMN, POSITION, ch1, ch2, block
     358    // We use a parameter to restrict the query to return data for one bioassay at a time
     359    query.select(Dynamic.select(VirtualColumn.POSITION));
     360    query.select(Selects.expression(ch1, "ch1"));
     361    query.select(Selects.expression(ch2, "ch2"));
     362    query.select(Selects.expression(block, "block"));
     363    query.restrict(bioAssayRestriction);
     364    query.order(Orders.asc(block));
     365    query.order(Orders.asc(A));
     366     
     367    // Normalize one bioassay at a time
     368    List<BioAssay> assays = source.getBioAssays().list(dc);
     369 
     370    try
     371    {
     372      for (BioAssay assay : assays)
     373      {
     374        // Prepare list for holding data
     375        int assaySpots = assay.getNumSpots();
     376        List<SpotData> data = new ArrayList<SpotData>(assaySpots);
     377       
     378        // Load spot data for this bioassay
     379        short bioassayColumn = assay.getDataCubeColumnNo();
     380        query.setParameter("bioAssayColumn", (int)bioassayColumn, Type.INT);
     381       
     382        DynamicResultIterator it = query.iterate(dc);
     383        int positionIndex = it.getIndex(VirtualColumn.POSITION.getName());
     384        int ch1Index = it.getIndex("ch1");
     385        int ch2Index = it.getIndex("ch2");
     386        int blockIndex = it.getIndex("block");
     387       
     388        // Copy bioassay data to SpotData objects
     389        while (it.hasNext())
     390        {
     391          SqlResult r = it.next();
     392          SpotData spot = new SpotData(r.getInt(positionIndex),
     393            r.getFloat(ch1Index), r.getFloat(ch2Index), r.getInt(blockIndex));
     394          data.add(spot);
     395        }
     396        it.close();
     397       
     398        // Continue with next bioassay if there is no data
     399        int dataSize = data.size();
     400        if (dataSize == 0) continue;
     401       
     402        // Get range of block numbers - NOTE! query must return spots sorted in block order
     403        int minBlock = data.get(0).block;
     404        int maxBlock = data.get(data.size()-1).block;
     405       
     406        int fromIndex = 0;
     407        int toIndex = 0;
     408        int fromBlock = minBlock;
     409        // Normalize each block range independently: fromBlock + blockGroupSize --> toBlock
     410        while (fromBlock <= maxBlock)
     411        {
     412          // Find start and end index for current block range
     413          int toBlock = fromBlock + blockGroupSize - 1;
     414          if (toBlock > maxBlock) toBlock = maxBlock;
     415          fromIndex = toIndex;
     416          // Data is sorted by block; find index of last spot with: block <= toBlock
     417          // spot given by toIndex should not be included
     418          while (toIndex < dataSize && data.get(toIndex).block <= toBlock)
     419          {
     420            ++toIndex;
     421          }
     422         
     423          if (toIndex > fromIndex)
     424          {
     425            List<Double> smoothCurve = lowess(data.subList(fromIndex, toIndex), fitFraction, iterations, delta);
     426            for (int j = 0; j < smoothCurve.size(); ++j)
     427            {
     428              SpotData spot = data.get(fromIndex + j);
     429              double factor = Math.exp(smoothCurve.get(j) * 0.5);
     430              double newCh1 = spot.ch1/factor;
     431              double newCh2 = spot.ch2*factor;
     432              batcher.insert(bioassayColumn, spot.position, (float) newCh1, (float) newCh2);
     433            }
     434            normalizedSpots += smoothCurve.size();
     435            if (progress != null) progress.display((int)((normalizedSpots * 100) / numSpots), normalizedSpots + " spots normalized");
     436          }
     437          fromBlock = toBlock + 1;
     438        }
     439      }
     440      batcher.flush();
     441      batcher.close();
     442    }
     443    catch (SQLException e)
     444    {
     445      throw new BaseException(e);
     446    }
     447    return child;
     448  }
     449 
    420450  private RequestInformation getConfigureJobParameters()
    421451  {
     
    455485    }
    456486    return configureJob;
    457   } 
    458  
    459 
    460   private static List<Double> lowess(List<Double> x, List<Double> y, double f, int iter, double delta)
    461   {
    462     Double[] smoothCurve = new Double[x.size()];
    463     int windowSize = Math.min(x.size(), (int) (x.size() * f + 0.5));
     487  }
     488
     489  private static List<Double> lowess(List<SpotData> data, double f, int iter, double delta)
     490  {
     491    int dataSize = data.size();
     492    Double[] smoothCurve = new Double[dataSize];
     493    int windowSize = Math.min(dataSize, (int) (dataSize * f + 0.5));
    464494
    465495    List<Double> wFit = new ArrayList<Double>();
    466     wFit.addAll(Collections.nCopies(x.size(), 1D));
    467     for (int iteration = 0; iteration < iter; iteration++)
     496    wFit.addAll(Collections.nCopies(dataSize, 1D));
     497    for (int iteration = 0; iteration < iter; ++iteration)
    468498    {
    469499      int windowStart = 0;
    470500      int i = 0;
    471501      int prevI = -1;
    472       while (prevI < x.size() - 1)
    473       {
    474         double xi = x.get(i);
    475         // center window around the i:th value in x
     502      while (prevI < dataSize - 1)
     503      {
     504        double Ai = data.get(i).A;
     505        // center window around the i:th value in A
    476506        // while distance from windowStart to i is greater then distance from i to windowEnd: move window
    477         while (windowStart + windowSize < x.size() && xi - x.get(windowStart) > x.get(windowStart + windowSize) - xi)
     507        while ((windowStart + windowSize < dataSize) &&
     508            (Ai - data.get(windowStart).A > data.get(windowStart + windowSize).A - Ai))
    478509        {
    479510          windowStart++;
    480511        }
    481512       
    482         List<Double> xWindow = x.subList(windowStart, windowStart + windowSize);
    483         List<Double> yWindow = y.subList(windowStart, windowStart + windowSize);
     513        List<SpotData> window = data.subList(windowStart, windowStart + windowSize);
    484514        List<Double> wFitWindow = wFit.subList(windowStart, windowStart + windowSize);
    485515       
    486         List<Double> w = calculateWeights(xWindow, xi, wFitWindow);
    487         double[] km = weightedLeastSquaresRegression(xWindow, yWindow, w);
    488         smoothCurve[i] = km[0] * xi + km[1];
     516        List<Double> w = calculateWeights(window, Ai, wFitWindow);
     517        double[] km = weightedLeastSquaresRegression(window, w);
     518        smoothCurve[i] = km[0] * Ai + km[1];
    489519
    490520        // Interpolate skipped points due to delta
    491521        if (prevI + 1 < i)
    492522        {
    493           double d = xi - x.get(prevI);
     523          double d = Ai - data.get(prevI).A;
    494524          if (d == 0)
    495525          {
     
    503533            for (int j = prevI + 1; j < i; j++)
    504534            {
    505               double a = (x.get(j) - x.get(prevI)) / d;
    506               smoothCurve[j] = a * smoothCurve[i] + (1D - a) * smoothCurve[prevI];
     535              double t = (data.get(j).A - data.get(prevI).A) / d;
     536              smoothCurve[j] = t * smoothCurve[i] + (1D - t) * smoothCurve[prevI];
    507537            }
    508538          }
     
    511541        // increase i, next x value must be at least delta greater then current
    512542        prevI = i;
    513         double cut = xi + delta;
    514         while (i < x.size()-1 && x.get(i) <= cut)
     543        double cut = Ai + delta;
     544        while (i < dataSize-1 && data.get(i).A <= cut)
    515545        {
    516546          i++;
     
    518548      }
    519549     
    520       double invYWRange = 1D/(ywRangeFactor  * medianCorrection(y, smoothCurve));
     550      double invYWRange = 1D/(ywRangeFactor  * medianCorrection(data, smoothCurve));
    521551      for (int j = 0; j < wFit.size(); j++)
    522552      {
    523         double w = Math.abs((smoothCurve[j] - y.get(j)) * invYWRange);
     553        double w = Math.abs((smoothCurve[j] - data.get(j).M) * invYWRange);
    524554        wFit.set(j, w < 1 ? Math.pow(Math.pow(1 - w, 2), 2) : 0);
    525555      }
    526556    }
    527 
    528557    return Arrays.asList(smoothCurve);
    529558  }
    530559
    531   private static double medianCorrection(List<Double> y, Double[] smoothCurve)
    532   {
    533     List<Double> temp = new ArrayList<Double>(y.size());
    534     for (int i = 0; i < y.size(); i++)
    535     {
    536       temp.add(Math.abs(y.get(i) - smoothCurve[i]));
     560  private static double medianCorrection(List<SpotData> data, Double[] smoothCurve)
     561  {
     562    List<Double> temp = new ArrayList<Double>(data.size());
     563    for (int i = 0; i < data.size(); i++)
     564    {
     565      temp.add(Math.abs(data.get(i).M - smoothCurve[i]));
    537566    }
    538567    Collections.sort(temp);
     
    550579
    551580 
    552   private static double[] weightedLeastSquaresRegression(List<Double> x, List<Double> y, List<Double> w)
     581  private static double[] weightedLeastSquaresRegression(List<SpotData> data, List<Double> w)
    553582  {
    554583    double k;
    555584    double m;
    556     double sumX = 0;
    557     double sumY = 0;
    558     double sumXX = 0;
    559     double sumXY = 0;
     585    double sumA = 0;
     586    double sumM = 0;
     587    double sumAA = 0;
     588    double sumAM = 0;
    560589    double sumW = 0;
    561     for (int j = 0; j < x.size(); j++)
     590    for (int j = 0; j < data.size(); j++)
    562591    {
    563592      double localW = w.get(j);
    564       double localX = x.get(j);
    565       double localY = y.get(j);
    566       sumX += localX * localW;
    567       sumY += localY * localW;
    568       sumXX += localX * localX * localW;
    569       sumXY += localX * localY * localW;
     593      double localA = data.get(j).A;
     594      double localM = data.get(j).M;
     595      sumA += localA * localW;
     596      sumM += localM * localW;
     597      sumAA += localA * localA * localW;
     598      sumAM += localA * localM * localW;
    570599      sumW += localW;
    571600    }
     
    574603      throw new BaseException("Sum of weigths in line_fit is not positive");
    575604    }
    576     double denom = sumW * sumXX - sumX * sumX;
     605    double denom = sumW * sumAA - sumA * sumA;
    577606    if (denom != 0.)
    578607    {
    579       k = (sumW * sumXY - sumX * sumY) / denom;
    580       m = (sumY - k * sumX) / sumW;
     608      k = (sumW * sumAM - sumA * sumM) / denom;
     609      m = (sumM - k * sumA) / sumW;
    581610    }
    582611    else
    583612    {
    584613      k = 0.;
    585       m = sumY / sumW;
     614      m = sumM / sumW;
    586615    }
    587616
     
    589618  }
    590619
    591   private static List<Double> calculateWeights(List<Double> values, double x1, List<Double> wFit)
    592   {
    593     List<Double> w = new ArrayList<Double>(values.size());
     620  private static List<Double> calculateWeights(List<SpotData> data, double A1, List<Double> wFit)
     621  {
     622    List<Double> w = new ArrayList<Double>(data.size());
    594623    double invRadius = 0;
    595     for (double x2 : values)
    596     {
    597       double abs = Math.abs(x1 - x2);
     624    for (SpotData spot : data)
     625    {
     626      double abs = Math.abs(A1 - spot.A);
    598627      if (abs > invRadius) invRadius = abs;
    599628    }
    600629    invRadius = 1 / invRadius;
    601     for (int i = 0; i < values.size(); i++)
    602     {
    603       double x2 = values.get(i);
    604       double distance = Math.abs(x1 - x2) * invRadius;
     630    for (int i = 0; i < data.size(); i++)
     631    {
     632      double A2 = data.get(i).A;
     633      double distance = Math.abs(A1 - A2) * invRadius;
    605634      w.add((distance < 1 ? Math.pow(1D - Math.pow(Math.abs(distance), 3), 3) : 0) * wFit.get(i));
    606635    }
    607636    return w;
    608637  }
    609  
     638
     639  private static class SpotData
     640  {
     641    final int position;
     642    final float ch1;
     643    final float ch2;
     644    final double M;
     645    final double A;
     646    final int block;
     647    static final double LN2 = Math.log(2);
     648
     649    public SpotData(int position, float ch1, float ch2, int block)
     650    {
     651      this.position = position;
     652      this.ch1 = ch1;
     653      this.ch2 = ch2;
     654      this.block = block;
     655      this.M = Math.log(ch1 / ch2) / LN2;
     656      this.A = Math.log10(Math.sqrt(ch1 * ch2));
     657    }
     658  }
     659
    610660}
Note: See TracChangeset for help on using the changeset viewer.