Source code for lsst.afw.geom.testUtils

#
# LSST Data Management System
# Copyright 2016 LSST Corporation.
#
# This product includes software developed by the
# LSST Project (http://www.lsst.org/).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the LSST License Statement and
# the GNU General Public License along with this program.  If not,
# see <http://www.lsstcorp.org/LegalNotices/>.
#

__all__ = ["BoxGrid", "makeSipIwcToPixel", "makeSipPixelToIwc"]

import itertools
import math
import os
import pickle

import astshim as ast
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal
from astshim.test import makeForwardPolyMap, makeTwoWayPolyMap
from lsst.afw.geom.wcsUtils import getCdMatrixFromMetadata

from .box import Box2I, Box2D
import lsst.afw.geom as afwGeom
from lsst.pex.exceptions import InvalidParameterError
import lsst.utils
import lsst.utils.tests


class BoxGrid:
    """Divide a box into nx by ny sub-boxes that tile the region

    The sub-boxes will be of the same type as `box` and will exactly tile `box`;
    they will also all be the same size, to the extent possible (some variation
    is inevitable for integer boxes that cannot be evenly divided).

    Parameters
    ----------
    box : `lsst.afw.geom.Box2I` or `lsst.afw.geom.Box2D`
        the box to subdivide; the boxes in the grid will be of the same type
    numColRow : pair of `int`
        number of columns and rows

    Raises
    ------
    RuntimeError
        If `numColRow` does not have exactly two elements, or if `box`
        is neither a `Box2I` nor a `Box2D`
    """

    def __init__(self, box, numColRow):
        if len(numColRow) != 2:
            raise RuntimeError(
                "numColRow=%r; must be a sequence of two integers" % (numColRow,))
        self._numColRow = tuple(int(val) for val in numColRow)

        # Integer boxes have an inclusive maximum, so the dividing points
        # are computed on max + 1 and the 1 is subtracted again in __getitem__
        if isinstance(box, Box2I):
            stopDelta = 1
        elif isinstance(box, Box2D):
            stopDelta = 0
        else:
            raise RuntimeError("Unknown class %s of box %s" % (type(box), box))
        self.boxClass = type(box)
        self.stopDelta = stopDelta

        minPoint = box.getMin()
        self.pointClass = type(minPoint)
        dtype = np.array(minPoint).dtype

        # One array of dividing positions per axis (x, then y);
        # each has one more entry than the number of sub-boxes along that axis
        self._divList = [np.linspace(start=box.getMin()[i],
                                     stop=box.getMax()[i] + self.stopDelta,
                                     num=self._numColRow[i] + 1,
                                     endpoint=True,
                                     dtype=dtype) for i in range(2)]

    @property
    def numColRow(self):
        """The number of columns and rows, as a pair of `int`"""
        return self._numColRow

    def __getitem__(self, indXY):
        """Return the box at the specified x,y index

        Parameters
        ----------
        indXY : pair of `ints`
            the x,y index to return

        Returns
        -------
        subBox : `lsst.afw.geom.Box2I` or `lsst.afw.geom.Box2D`
        """
        beg = self.pointClass(*[self._divList[i][indXY[i]] for i in range(2)])
        end = self.pointClass(
            *[self._divList[i][indXY[i] + 1] - self.stopDelta for i in range(2)])
        return self.boxClass(beg, end)

    def __len__(self):
        """Return the total number of sub-boxes (columns times rows)"""
        numCol, numRow = self.numColRow
        return numCol*numRow

    def __iter__(self):
        """Return an iterator over all boxes, where column varies most quickly
        """
        for row in range(self.numColRow[1]):
            for col in range(self.numColRow[0]):
                yield self[col, row]
class FrameSetInfo:
    """Summary of the base and current frames of a FrameSet

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet to describe

    Notes
    -----
    **Fields**

    baseInd : `int`
        Index of base frame
    currInd : `int`
        Index of current frame
    isBaseSkyFrame : `bool`
        Is the base frame an `ast.SkyFrame`?
    isCurrSkyFrame : `bool`
        Is the current frame an `ast.SkyFrame`?
    """

    def __init__(self, frameSet):
        def isSkyFrame(frameIndex):
            # a frame counts as a SkyFrame when its class name says so
            return frameSet.getFrame(frameIndex).className == "SkyFrame"

        self.baseInd = frameSet.base
        self.currInd = frameSet.current
        self.isBaseSkyFrame = isSkyFrame(self.baseInd)
        self.isCurrSkyFrame = isSkyFrame(self.currInd)


def makeSipPolyMapCoeffs(metadata, name):
    """Build the ast.PolyMap coefficient list for one SIP matrix

    The returned coefficients describe an ast.PolyMap that computes:

        f(dxy) = dxy + sipPolynomial(dxy)

    where dxy = pixelPosition - pixelOrigin
    and sipPolynomial is a polynomial with terms `<name>n_m for x^n y^m`
    (e.g. A2_0 is the coefficient for x^2 y^0)

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with the specified SIP coefficients
    name : str
        The desired SIP terms: one of A, B, AP, BP

    Returns
    -------
    list
        One `[coefficient, output axis, x power, y power]` entry per term.
        The first entry is the identity term (out = in); the SIP terms
        found in the metadata follow, ordered by x power, then y power.

    Raises
    ------
    RuntimeError
        If `name` is not a supported SIP name, or if the metadata contains
        no coefficient for `name`

    Note
    ----
    This is an internal function for use by makeSipIwcToPixel
    and makeSipPixelToIwc
    """
    # 1-based output axis computed by each family of SIP terms
    axisByName = {"A": 1, "AP": 1, "B": 2, "BP": 2}
    if name not in axisByName:
        raise RuntimeError("%s not a supported SIP name" % (name,))
    outAxis = axisByName[name]

    numPowers = metadata.getAsInt("%s_ORDER" % (name,)) + 1

    # begin with the term for out = in along this axis
    identityPowers = [1, 0] if outAxis == 1 else [0, 1]
    coeffs = [[1.0, outAxis] + identityPowers]

    # append every SIP distortion term present in the metadata
    for xPower, yPower in itertools.product(range(numPowers), repeat=2):
        coeffName = "%s_%s_%s" % (name, xPower, yPower)
        if metadata.exists(coeffName):
            coeffs.append([metadata.getAsDouble(coeffName), outAxis, xPower, yPower])

    # only the identity term is present: nothing was found in the metadata
    if len(coeffs) == 1:
        raise RuntimeError("No %s coefficients found" % (name,))
    return coeffs
def makeSipIwcToPixel(metadata):
    """Build an IWC to pixel transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with inverse SIP coefficients

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from IWC position to pixel position (zero-based)
        in the forward direction. The inverse direction is not defined.

    Notes
    -----

    The inverse SIP terms APn_m, BPn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial inverseSipPolynomial, the returned transformation is:

        pixelPosition = pixel origin + uv + inverseSipPolynomial(uv)

    where uv = inverseCdMatrix * iwcPosition
    """
    # FITS CRPIX is 1-based; subtract 1 to get the zero-based pixel origin
    pixelOrigin = tuple(metadata.get(key) - 1 for key in ("CRPIX1", "CRPIX2"))
    addPixelOrigin = ast.ShiftMap(pixelOrigin)

    cdMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())

    xTerms = makeSipPolyMapCoeffs(metadata, "AP")
    yTerms = makeSipPolyMapCoeffs(metadata, "BP")
    inverseSipMap = ast.PolyMap(np.array(xTerms + yTerms, dtype=float), 2, "IterInverse=0")

    # chain: inverse CD matrix, then inverse SIP polynomial, then shift by the origin
    iwcToRelativePixel = cdMap.getInverse().then(inverseSipMap)
    iwcToPixel = iwcToRelativePixel.then(addPixelOrigin)
    return afwGeom.TransformPoint2ToPoint2(iwcToPixel)
def makeSipPixelToIwc(metadata):
    """Build a pixel to IWC transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with forward SIP coefficients

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from pixel position (zero-based) to IWC position
        in the forward direction. The inverse direction is not defined.

    Notes
    -----

    The forward SIP terms An_m, Bn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial sipPolynomial, the returned transformation is:

        iwcPosition = cdMatrix * (dxy + sipPolynomial(dxy))

    where dxy = pixelPosition - pixelOrigin
    """
    # FITS CRPIX is 1-based; subtract 1 to get the zero-based pixel origin
    pixelOrigin = tuple(metadata.get(key) - 1 for key in ("CRPIX1", "CRPIX2"))
    subtractPixelOrigin = ast.ShiftMap(pixelOrigin).getInverse()

    cdMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())

    xTerms = makeSipPolyMapCoeffs(metadata, "A")
    yTerms = makeSipPolyMapCoeffs(metadata, "B")
    forwardSipMap = ast.PolyMap(np.array(xTerms + yTerms, dtype=float), 2, "IterInverse=0")

    # chain: remove the pixel origin, then SIP polynomial, then CD matrix
    pixelToDistorted = subtractPixelOrigin.then(forwardSipMap)
    pixelToIwc = pixelToDistorted.then(cdMap)
    return afwGeom.TransformPoint2ToPoint2(pixelToIwc)
class PermutedFrameSet:
    """A FrameSet with base or current frame possibly permuted, with associated
    information

    Only two-axis frames will be permuted.

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet you wish to permute. A deep copy is made.
    permuteBase : `bool`
        Permute the base frame's axes?
    permuteCurr : `bool`
        Permute the current frame's axes?

    Raises
    ------
    RuntimeError
        If you try to permute a frame that does not have 2 axes

    Notes
    -----
    **Fields**

    frameSet : `ast.FrameSet`
        The FrameSet that may be permuted. A local copy is made.
    isBaseSkyFrame : `bool`
        Is the base frame an `ast.SkyFrame`?
    isCurrSkyFrame : `bool`
        Is the current frame an `ast.SkyFrame`?
    isBasePermuted : `bool`
        Are the base frame axes permuted?
    isCurrPermuted : `bool`
        Are the current frame axes permuted?
    """

    def __init__(self, frameSet, permuteBase, permuteCurr):
        self.frameSet = frameSet.copy()
        fsInfo = FrameSetInfo(self.frameSet)
        self.isBaseSkyFrame = fsInfo.isBaseSkyFrame
        self.isCurrSkyFrame = fsInfo.isCurrSkyFrame
        if permuteBase:
            baseNAxes = self.frameSet.getFrame(fsInfo.baseInd).nAxes
            if baseNAxes != 2:
                raise RuntimeError("Base frame has {} axes; 2 required to permute".format(baseNAxes))
            # permAxes is applied with the base frame temporarily made current;
            # the original current frame is restored afterwards
            self.frameSet.current = fsInfo.baseInd
            self.frameSet.permAxes([2, 1])
            self.frameSet.current = fsInfo.currInd
        if permuteCurr:
            currNAxes = self.frameSet.getFrame(fsInfo.currInd).nAxes
            if currNAxes != 2:
                raise RuntimeError("Current frame has {} axes; 2 required to permute".format(currNAxes))
            self.frameSet.permAxes([2, 1])
        self.isBasePermuted = permuteBase
        self.isCurrPermuted = permuteCurr


class TransformTestBaseClass(lsst.utils.tests.TestCase):
    """Base class for unit tests of Transform<X>To<Y>

    Subclasses must call `TransformTestBaseClass.setUp(self)`
    if they provide their own version.

    If a package other than afw uses this class then it must
    override the `getTestDir` method to avoid writing into
    afw's test directory.
    """

    def getTestDir(self):
        """Return a directory where temporary test files can be written

        The default implementation returns the test directory of the `afw`
        package.

        If this class is used by a test in a package other than `afw`
        then the subclass must override this method.
        """
        return os.path.join(lsst.utils.getPackageDir("afw"), "tests")

    def setUp(self):
        """Set up a test

        Subclasses should call this method if they override setUp.
        """
        # tell unittest to use the msg argument of asserts as a supplement
        # to the error message, rather than as the whole error message
        self.longMessage = True

        # list of endpoint class name prefixes; the full name is prefix + "Endpoint"
        self.endpointPrefixes = ("Generic", "Point2", "SpherePoint")

        # goodNAxes is dict of endpoint class name prefix:
        #    tuple containing 0 or more valid numbers of axes
        self.goodNAxes = {
            "Generic": (1, 2, 3, 4),  # all numbers of axes are valid for GenericEndpoint
            "Point2": (2,),
            "SpherePoint": (2,),
        }

        # badNAxes is dict of endpoint class name prefix:
        #    tuple containing 0 or more invalid numbers of axes
        self.badNAxes = {
            "Generic": (),  # all numbers of axes are valid for GenericEndpoint
            "Point2": (1, 3, 4),
            "SpherePoint": (1, 3, 4),
        }

        # Dict of frame index: identity name for frames created by makeFrameSet
        self.frameIdentDict = {
            1: "baseFrame",
            2: "frame2",
            3: "frame3",
            4: "currFrame",
        }

    @staticmethod
    def makeRawArrayData(nPoints, nAxes, delta=0.123):
        """Make an array of generic point data

        The data will be suitable for spherical points

        Parameters
        ----------
        nPoints : `int`
            Number of points in the array
        nAxes : `int`
            Number of axes in the point
        delta : `float`
            Increment between axis values

        Returns
        -------
        np.array of floats with shape (nAxes, nPoints)
            The values are as follows; if nAxes != 2:
                The first point has values `[0, delta, 2*delta, ..., (nAxes-1)*delta]`
                The Nth point has those values + N
            if nAxes == 2 then the data is scaled so that the max value of axis 1
                is a bit less than pi/2
        """
        # oneAxis = [0, 1, 2, ...nPoints-1]
        oneAxis = np.arange(nPoints, dtype=float)  # [0, 1, 2...]
        # rawData = [oneAxis, oneAxis + delta, oneAxis + 2 delta, ...]
        rawData = np.array([j * delta + oneAxis for j in range(nAxes)], dtype=float)
        if nAxes == 2:
            # scale rawData so that max value of 2nd axis is a bit less than pi/2,
            # thus making the data safe for SpherePoint
            maxLatitude = np.max(rawData[1])
            rawData *= math.pi * 0.4999 / maxLatitude
        return rawData

    @staticmethod
    def makeRawPointData(nAxes, delta=0.123):
        """Make one generic point

        Parameters
        ----------
        nAxes : `int`
            Number of axes in the point
        delta : `float`
            Increment between axis values

        Returns
        -------
        A list of `nAxes` floats with values `[0, delta, ..., (nAxes-1)*delta]`
        """
        return [i*delta for i in range(nAxes)]

    @staticmethod
    def makeEndpoint(name, nAxes=None):
        """Make an endpoint

        Parameters
        ----------
        name : `str`
            Endpoint class name prefix; the full class name is name + "Endpoint"
        nAxes : `int` or `None`, optional
            number of axes; an int is required if `name` == "Generic";
            otherwise ignored

        Returns
        -------
        subclass of `lsst.afw.geom.BaseEndpoint`
            The constructed endpoint

        Raises
        ------
        TypeError
            If `name` == "Generic" and `nAxes` is None. Other invalid values
            of `nAxes` are passed to the endpoint constructor unchecked.
        """
        EndpointClassName = name + "Endpoint"
        EndpointClass = getattr(afwGeom, EndpointClassName)
        if name == "Generic":
            if nAxes is None:
                raise TypeError("nAxes must be an integer for GenericEndpoint")
            return EndpointClass(nAxes)
        return EndpointClass()

    @classmethod
    def makeGoodFrame(cls, name, nAxes=None):
        """Return the appropriate frame for the given name and nAxes

        Parameters
        ----------
        name : `str`
            Endpoint class name prefix; the full class name is name + "Endpoint"
        nAxes : `int` or `None`, optional
            number of axes; an int is required if `name` == "Generic";
            otherwise ignored

        Returns
        -------
        `ast.Frame`
            The constructed frame

        Raises
        ------
        TypeError
            If `name` == "Generic" and `nAxes` is `None`
            (raised by `makeEndpoint`)
        """
        return cls.makeEndpoint(name, nAxes).makeFrame()

    @staticmethod
    def makeBadFrames(name):
        """Return a list of 0 or more frames that are not a valid match for the
        named endpoint

        Parameters
        ----------
        name : `str`
            Endpoint class name prefix; the full class name is name + "Endpoint"

        Returns
        -------
        Collection of `ast.Frame`
            A collection of 0 or more frames
        """
        return {
            "Generic": [],
            "Point2": [
                ast.SkyFrame(),
                ast.Frame(1),
                ast.Frame(3),
            ],
            "SpherePoint": [
                ast.Frame(1),
                ast.Frame(2),
                ast.Frame(3),
            ],
        }[name]

    def makeFrameSet(self, baseFrame, currFrame):
        """Make a FrameSet

        The FrameSet will contain 4 frames and three transforms connecting them.
        The identity of each frame is provided by self.frameIdentDict

        Frame       Index   Mapping from this frame to the next
        `baseFrame`   1     `ast.UnitMap(nIn)`
        Frame(nIn)    2     `polyMap`
        Frame(nOut)   3     `ast.UnitMap(nOut)`
        `currFrame`   4

        where:
        - `nIn` = `baseFrame.nAxes`
        - `nOut` = `currFrame.nAxes`
        - `polyMap` = `makeTwoWayPolyMap(nIn, nOut)`

        Returns
        -------
        `ast.FrameSet`
            The FrameSet as described above

        Parameters
        ----------
        baseFrame : `ast.Frame`
            base frame
        currFrame : `ast.Frame`
            current frame
        """
        nIn = baseFrame.nAxes
        nOut = currFrame.nAxes
        polyMap = makeTwoWayPolyMap(nIn, nOut)

        # The only way to set the Ident of a frame in a FrameSet is to set it in advance,
        # and I don't want to modify the inputs, so replace the input frames with copies
        baseFrame = baseFrame.copy()
        baseFrame.ident = self.frameIdentDict[1]
        currFrame = currFrame.copy()
        currFrame.ident = self.frameIdentDict[4]

        frameSet = ast.FrameSet(baseFrame)
        frame2 = ast.Frame(nIn)
        frame2.ident = self.frameIdentDict[2]
        frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nIn), frame2)
        frame3 = ast.Frame(nOut)
        frame3.ident = self.frameIdentDict[3]
        frameSet.addFrame(ast.FrameSet.CURRENT, polyMap, frame3)
        frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nOut), currFrame)
        return frameSet

    @staticmethod
    def permuteFrameSetIter(frameSet):
        """Iterator over 0 or more frameSets with SkyFrames axes permuted

        Only base and current SkyFrames are permuted. If neither the base nor
        the current frame is a SkyFrame then no frames are returned.

        Parameters
        ----------
        frameSet : `ast.FrameSet`
            The FrameSet to permute; each yielded item works on its own copy

        Returns
        -------
        iterator over `PermutedFrameSet`
        """

        fsInfo = FrameSetInfo(frameSet)
        if not (fsInfo.isBaseSkyFrame or fsInfo.isCurrSkyFrame):
            return

        permuteBaseList = [False, True] if fsInfo.isBaseSkyFrame else [False]
        permuteCurrList = [False, True] if fsInfo.isCurrSkyFrame else [False]
        for permuteBase in permuteBaseList:
            for permuteCurr in permuteCurrList:
                yield PermutedFrameSet(frameSet, permuteBase, permuteCurr)

    @staticmethod
    def makeJacobian(nIn, nOut, inPoint):
        """Make a Jacobian matrix for the equation described by
        `makeTwoWayPolyMap`.

        Parameters
        ----------
        nIn, nOut : `int`
            the dimensions of the input and output data; see makeTwoWayPolyMap
        inPoint : `numpy.ndarray`
            an array of size `nIn` representing the point at which the Jacobian
            is measured

        Returns
        -------
        J : `numpy.ndarray`
            an `nOut` x `nIn` array of first derivatives
        """
        basePolyMapCoeff = 0.001  # see makeTwoWayPolyMap
        baseCoeff = 2.0 * basePolyMapCoeff
        coeffs = np.empty((nOut, nIn))
        for iOut in range(nOut):
            coeffOffset = baseCoeff * iOut
            for iIn in range(nIn):
                coeffs[iOut, iIn] = baseCoeff * (iIn + 1) + coeffOffset
                coeffs[iOut, iIn] *= inPoint[iIn]
        assert coeffs.ndim == 2
        # Avoid spurious errors when comparing to a simplified array
        assert coeffs.shape == (nOut, nIn)
        return coeffs

    def checkTransformation(self, transform, mapping, msg=""):
        """Check applyForward and applyInverse for a transform

        Parameters
        ----------
        transform : `lsst.afw.geom.Transform`
            The transform to check
        mapping : `ast.Mapping`
            The mapping the transform should use. This mapping
            must contain valid forward or inverse transformations,
            but they need not match if both present. Hence the
            mappings returned by make*PolyMap are acceptable.
        msg : `str`
            Error message suffix describing test parameters
        """
        fromEndpoint = transform.fromEndpoint
        toEndpoint = transform.toEndpoint
        mappingFromTransform = transform.getMapping()

        nIn = mapping.nIn
        nOut = mapping.nOut
        self.assertEqual(nIn, fromEndpoint.nAxes, msg=msg)
        self.assertEqual(nOut, toEndpoint.nAxes, msg=msg)

        # forward transformation of one point
        rawInPoint = self.makeRawPointData(nIn)
        inPoint = fromEndpoint.pointFromData(rawInPoint)

        # forward transformation of an array of points
        nPoints = 7  # arbitrary
        rawInArray = self.makeRawArrayData(nPoints, nIn)
        inArray = fromEndpoint.arrayFromData(rawInArray)

        if mapping.hasForward:
            self.assertTrue(transform.hasForward)
            outPoint = transform.applyForward(inPoint)
            rawOutPoint = toEndpoint.dataFromPoint(outPoint)
            assert_allclose(rawOutPoint, mapping.applyForward(rawInPoint), err_msg=msg)
            assert_allclose(rawOutPoint, mappingFromTransform.applyForward(rawInPoint), err_msg=msg)

            outArray = transform.applyForward(inArray)
            rawOutArray = toEndpoint.dataFromArray(outArray)
            self.assertFloatsAlmostEqual(rawOutArray, mapping.applyForward(rawInArray), msg=msg)
            self.assertFloatsAlmostEqual(rawOutArray, mappingFromTransform.applyForward(rawInArray), msg=msg)
        else:
            # Need outPoint, but don't need it to be consistent with inPoint
            rawOutPoint = self.makeRawPointData(nOut)
            outPoint = toEndpoint.pointFromData(rawOutPoint)
            rawOutArray = self.makeRawArrayData(nPoints, nOut)
            outArray = toEndpoint.arrayFromData(rawOutArray)

            self.assertFalse(transform.hasForward)

        if mapping.hasInverse:
            self.assertTrue(transform.hasInverse)
            # inverse transformation of one point;
            # remember that the inverse need not give the original values
            # (see the description of the `mapping` parameter)
            inversePoint = transform.applyInverse(outPoint)
            rawInversePoint = fromEndpoint.dataFromPoint(inversePoint)
            assert_allclose(rawInversePoint, mapping.applyInverse(rawOutPoint), err_msg=msg)
            assert_allclose(rawInversePoint, mappingFromTransform.applyInverse(rawOutPoint), err_msg=msg)

            # inverse transformation of an array of points;
            # remember that the inverse will not give the original values
            # (see the description of the `mapping` parameter)
            inverseArray = transform.applyInverse(outArray)
            rawInverseArray = fromEndpoint.dataFromArray(inverseArray)
            self.assertFloatsAlmostEqual(rawInverseArray, mapping.applyInverse(rawOutArray), msg=msg)
            self.assertFloatsAlmostEqual(rawInverseArray, mappingFromTransform.applyInverse(rawOutArray),
                                         msg=msg)
        else:
            self.assertFalse(transform.hasInverse)

    def checkInverseTransformation(self, forward, inverse, msg=""):
        """Check that two Transforms are each others' inverses.

        Parameters
        ----------
        forward : `lsst.afw.geom.Transform`
            the reference Transform to test
        inverse : `lsst.afw.geom.Transform`
            the transform that should be the inverse of `forward`
        msg : `str`
            error message suffix describing test parameters
        """
        fromEndpoint = forward.fromEndpoint
        toEndpoint = forward.toEndpoint
        forwardMapping = forward.getMapping()
        inverseMapping = inverse.getMapping()

        # properties
        self.assertEqual(forward.fromEndpoint,
                         inverse.toEndpoint, msg=msg)
        self.assertEqual(forward.toEndpoint,
                         inverse.fromEndpoint, msg=msg)
        self.assertEqual(forward.hasForward, inverse.hasInverse, msg=msg)
        self.assertEqual(forward.hasInverse, inverse.hasForward, msg=msg)

        # transformations of one point
        # we don't care about whether the transformation itself is correct
        # (see checkTransformation), so inPoint/outPoint need not be related
        rawInPoint = self.makeRawPointData(fromEndpoint.nAxes)
        inPoint = fromEndpoint.pointFromData(rawInPoint)
        rawOutPoint = self.makeRawPointData(toEndpoint.nAxes)
        outPoint = toEndpoint.pointFromData(rawOutPoint)

        # transformations of arrays of points
        nPoints = 7  # arbitrary
        rawInArray = self.makeRawArrayData(nPoints, fromEndpoint.nAxes)
        inArray = fromEndpoint.arrayFromData(rawInArray)
        rawOutArray = self.makeRawArrayData(nPoints, toEndpoint.nAxes)
        outArray = toEndpoint.arrayFromData(rawOutArray)

        if forward.hasForward:
            self.assertEqual(forward.applyForward(inPoint),
                             inverse.applyInverse(inPoint), msg=msg)
            self.assertEqual(forwardMapping.applyForward(rawInPoint),
                             inverseMapping.applyInverse(rawInPoint), msg=msg)
            # Assertions must work with both lists and numpy arrays
            assert_array_equal(forward.applyForward(inArray),
                               inverse.applyInverse(inArray),
                               err_msg=msg)
            assert_array_equal(forwardMapping.applyForward(rawInArray),
                               inverseMapping.applyInverse(rawInArray),
                               err_msg=msg)

        if forward.hasInverse:
            self.assertEqual(forward.applyInverse(outPoint),
                             inverse.applyForward(outPoint), msg=msg)
            self.assertEqual(forwardMapping.applyInverse(rawOutPoint),
                             inverseMapping.applyForward(rawOutPoint), msg=msg)
            assert_array_equal(forward.applyInverse(outArray),
                               inverse.applyForward(outArray),
                               err_msg=msg)
            assert_array_equal(forwardMapping.applyInverse(rawOutArray),
                               inverseMapping.applyForward(rawOutArray),
                               err_msg=msg)

    def checkTransformFromMapping(self, fromName, toName):
        """Check Transform_<fromName>_<toName> using the Mapping constructor

        Parameters
        ----------
        fromName, toName : `str`
            Endpoint name prefix for "from" and "to" endpoints, respectively,
            e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
        """
        transformClassName = "Transform{}To{}".format(fromName, toName)
        TransformClass = getattr(afwGeom, transformClassName)
        baseMsg = "TransformClass={}".format(TransformClass.__name__)

        # check valid numbers of inputs and outputs
        for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                           self.goodNAxes[toName]):
            msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
            polyMap = makeTwoWayPolyMap(nIn, nOut)
            transform = TransformClass(polyMap)

            # desired output from `str(transform)`
            desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut)
            self.assertEqual("{}".format(transform), desStr)
            self.assertEqual(repr(transform), "lsst.afw.geom." + desStr)

            self.checkTransformation(transform, polyMap, msg=msg)

            # Forward transform but no inverse
            polyMap = makeForwardPolyMap(nIn, nOut)
            transform = TransformClass(polyMap)
            self.checkTransformation(transform, polyMap, msg=msg)

            # Inverse transform but no forward
            polyMap = makeForwardPolyMap(nOut, nIn).getInverse()
            transform = TransformClass(polyMap)
            self.checkTransformation(transform, polyMap, msg=msg)

        # check invalid # of output against valid # of inputs
        for nIn, badNOut in itertools.product(self.goodNAxes[fromName],
                                              self.badNAxes[toName]):
            badPolyMap = makeTwoWayPolyMap(nIn, badNOut)
            msg = "{}, nIn={}, badNOut={}".format(baseMsg, nIn, badNOut)
            with self.assertRaises(InvalidParameterError, msg=msg):
                TransformClass(badPolyMap)

        # check invalid # of inputs against valid and invalid # of outputs
        for badNIn, nOut in itertools.product(self.badNAxes[fromName],
                                              self.goodNAxes[toName] + self.badNAxes[toName]):
            badPolyMap = makeTwoWayPolyMap(badNIn, nOut)
            # report the bad input count actually under test, not a value
            # left over from an earlier loop
            msg = "{}, badNIn={}, nOut={}".format(baseMsg, badNIn, nOut)
            with self.assertRaises(InvalidParameterError, msg=msg):
                TransformClass(badPolyMap)

    def checkTransformFromFrameSet(self, fromName, toName):
        """Check Transform_<fromName>_<toName> using the FrameSet constructor

        Parameters
        ----------
        fromName, toName : `str`
            Endpoint name prefix for "from" and "to" endpoints, respectively,
            e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
        """
        transformClassName = "Transform{}To{}".format(fromName, toName)
        TransformClass = getattr(afwGeom, transformClassName)
        baseMsg = "TransformClass={}".format(TransformClass.__name__)
        for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                           self.goodNAxes[toName]):
            msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)

            baseFrame = self.makeGoodFrame(fromName, nIn)
            currFrame = self.makeGoodFrame(toName, nOut)
            frameSet = self.makeFrameSet(baseFrame, currFrame)
            self.assertEqual(frameSet.nFrame, 4)

            # construct 0 or more frame sets that are invalid for this transform class
            for badBaseFrame in self.makeBadFrames(fromName):
                badFrameSet = self.makeFrameSet(badBaseFrame, currFrame)
                with self.assertRaises(InvalidParameterError):
                    TransformClass(badFrameSet)
                for badCurrFrame in self.makeBadFrames(toName):
                    reallyBadFrameSet = self.makeFrameSet(badBaseFrame, badCurrFrame)
                    with self.assertRaises(InvalidParameterError):
                        TransformClass(reallyBadFrameSet)
            for badCurrFrame in self.makeBadFrames(toName):
                badFrameSet = self.makeFrameSet(baseFrame, badCurrFrame)
                with self.assertRaises(InvalidParameterError):
                    TransformClass(badFrameSet)

            transform = TransformClass(frameSet)

            desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut)
            self.assertEqual("{}".format(transform), desStr)
            self.assertEqual(repr(transform), "lsst.afw.geom." + desStr)

            self.checkPersistence(transform)

            mappingFromTransform = transform.getMapping()
            transformCopy = TransformClass(mappingFromTransform)
            self.assertEqual(type(transform), type(transformCopy))
            self.assertEqual(transform.getMapping(), mappingFromTransform)

            polyMap = makeTwoWayPolyMap(nIn, nOut)

            self.checkTransformation(transform, mapping=polyMap, msg=msg)

            # If the base and/or current frame of frameSet is a SkyFrame,
            # try permuting that frame (in place, so the connected mappings are
            # correctly updated). The Transform constructor should undo the permutation,
            # (via SpherePointEndpoint.normalizeFrame) in its internal copy of frameSet,
            # forcing the axes of the SkyFrame into standard (longitude, latitude) order
            for permutedFS in self.permuteFrameSetIter(frameSet):
                if permutedFS.isBaseSkyFrame:
                    baseFrame = permutedFS.frameSet.getFrame(ast.FrameSet.BASE)
                    # desired base longitude axis
                    desBaseLonAxis = 2 if permutedFS.isBasePermuted else 1
                    self.assertEqual(baseFrame.lonAxis, desBaseLonAxis)
                if permutedFS.isCurrSkyFrame:
                    currFrame = permutedFS.frameSet.getFrame(ast.FrameSet.CURRENT)
                    # desired current base longitude axis
                    desCurrLonAxis = 2 if permutedFS.isCurrPermuted else 1
                    self.assertEqual(currFrame.lonAxis, desCurrLonAxis)

                permTransform = TransformClass(permutedFS.frameSet)
                self.checkTransformation(permTransform, mapping=polyMap, msg=msg)

    def checkGetInverse(self, fromName, toName):
        """Test Transform<fromName>To<toName>.getInverse

        Parameters
        ----------
        fromName, toName : `str`
            Endpoint name prefix for "from" and "to" endpoints, respectively,
            e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
        """
        transformClassName = "Transform{}To{}".format(fromName, toName)
        TransformClass = getattr(afwGeom, transformClassName)
        baseMsg = "TransformClass={}".format(TransformClass.__name__)
        for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                           self.goodNAxes[toName]):
            msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
            self.checkInverseMapping(
                TransformClass,
                makeTwoWayPolyMap(nIn, nOut),
                "{}, Map={}".format(msg, "TwoWay"))
            self.checkInverseMapping(
                TransformClass,
                makeForwardPolyMap(nIn, nOut),
                "{}, Map={}".format(msg, "Forward"))
            self.checkInverseMapping(
                TransformClass,
                makeForwardPolyMap(nOut, nIn).getInverse(),
                "{}, Map={}".format(msg, "Inverse"))

    def checkInverseMapping(self, TransformClass, mapping, msg):
        """Test Transform<fromName>To<toName>.getInverse for a specific
        mapping.

        Parameters
        ----------
        TransformClass : `type`
            The class of transform to test, such as TransformPoint2ToPoint2
        mapping : `ast.Mapping`
            The mapping to use for the transform
        msg : `str`
            Error message suffix
        """
        transform = TransformClass(mapping)
        inverse = transform.getInverse()
        inverseInverse = inverse.getInverse()

        self.checkInverseTransformation(transform, inverse, msg=msg)
        self.checkInverseTransformation(inverse, inverseInverse, msg=msg)
        self.checkTransformation(inverseInverse, mapping, msg=msg)

    def checkGetJacobian(self, fromName, toName):
        """Test Transform<fromName>To<toName>.getJacobian

        Parameters
        ----------
        fromName, toName : `str`
            Endpoint name prefix for "from" and "to" endpoints, respectively,
            e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
        """
        transformClassName = "Transform{}To{}".format(fromName, toName)
        TransformClass = getattr(afwGeom, transformClassName)
        baseMsg = "TransformClass={}".format(TransformClass.__name__)
        for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                           self.goodNAxes[toName]):
            msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
            polyMap = makeForwardPolyMap(nIn, nOut)
            transform = TransformClass(polyMap)
            fromEndpoint = transform.fromEndpoint

            # Test multiple points to ensure correct functional form
            rawInPoint = self.makeRawPointData(nIn)
            inPoint = fromEndpoint.pointFromData(rawInPoint)
            jacobian = transform.getJacobian(inPoint)
            assert_allclose(jacobian, self.makeJacobian(nIn, nOut, rawInPoint),
                            err_msg=msg)

            rawInPoint = self.makeRawPointData(nIn, 0.111)
            inPoint = fromEndpoint.pointFromData(rawInPoint)
            jacobian = transform.getJacobian(inPoint)
            assert_allclose(jacobian, self.makeJacobian(nIn, nOut, rawInPoint),
                            err_msg=msg)

    def checkThen(self, fromName, midName, toName):
        """Test Transform<fromName>To<midName>.then(Transform<midName>To<toName>)

        Parameters
        ----------
        fromName : `str`
            the prefix of the starting endpoint (e.g., "Point2" for a
            Point2Endpoint) for the final, concatenated Transform
        midName : `str`
            the prefix for the shared endpoint where two Transforms will be
            concatenated
        toName : `str`
            the prefix of the ending endpoint for the final, concatenated
            Transform
        """
        TransformClass1 = getattr(afwGeom,
                                  "Transform{}To{}".format(fromName, midName))
        TransformClass2 = getattr(afwGeom,
                                  "Transform{}To{}".format(midName, toName))
        baseMsg = "{}.then({})".format(TransformClass1.__name__,
                                       TransformClass2.__name__)
        for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName],
                                                 self.goodNAxes[midName],
                                                 self.goodNAxes[toName]):
            msg = "{}, nIn={}, nMid={}, nOut={}".format(
                baseMsg, nIn, nMid, nOut)
            polyMap1 = makeTwoWayPolyMap(nIn, nMid)
            transform1 = TransformClass1(polyMap1)
            polyMap2 = makeTwoWayPolyMap(nMid, nOut)
            transform2 = TransformClass2(polyMap2)
            transform = transform1.then(transform2)

            fromEndpoint = transform1.fromEndpoint
            toEndpoint = transform2.toEndpoint

            inPoint = fromEndpoint.pointFromData(self.makeRawPointData(nIn))
            outPointMerged = transform.applyForward(inPoint)
            outPointSeparate = transform2.applyForward(
                transform1.applyForward(inPoint))
            assert_allclose(toEndpoint.dataFromPoint(outPointMerged),
                            toEndpoint.dataFromPoint(outPointSeparate),
                            err_msg=msg)

            outPoint = toEndpoint.pointFromData(self.makeRawPointData(nOut))
            inPointMerged = transform.applyInverse(outPoint)
            inPointSeparate = transform1.applyInverse(
                transform2.applyInverse(outPoint))
            assert_allclose(
                fromEndpoint.dataFromPoint(inPointMerged),
                fromEndpoint.dataFromPoint(inPointSeparate),
                err_msg=msg)

        # Mismatched number of axes should fail
        if midName == "Generic":
            nIn = self.goodNAxes[fromName][0]
            nOut = self.goodNAxes[toName][0]
            polyMap = makeTwoWayPolyMap(nIn, 3)
            transform1 = TransformClass1(polyMap)
            polyMap = makeTwoWayPolyMap(2, nOut)
            transform2 = TransformClass2(polyMap)
            with self.assertRaises(InvalidParameterError):
                transform1.then(transform2)

        # Mismatched types of endpoints should fail
        if fromName != midName:
            # Use TransformClass1 for both args to keep test logic simple
            outName = midName
            joinNAxes = set(self.goodNAxes[fromName]).intersection(
                self.goodNAxes[outName])

            for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName],
                                                     joinNAxes,
                                                     self.goodNAxes[outName]):
                polyMap = makeTwoWayPolyMap(nIn, nMid)
                transform1 = TransformClass1(polyMap)
                polyMap = makeTwoWayPolyMap(nMid, nOut)
                transform2 = TransformClass1(polyMap)
                with self.assertRaises(InvalidParameterError):
                    transform1.then(transform2)

    def assertTransformsEqual(self, transform1, transform2):
        """Assert that two transforms are equal

        The transforms must have the same type, endpoints and mapping,
        and must give matching results (to within `assert_allclose`
        tolerance) in each direction the mapping of `transform1` supports.

        Parameters
        ----------
        transform1, transform2 : `lsst.afw.geom.Transform`
            The transforms to compare
        """
        self.assertEqual(type(transform1), type(transform2))
        self.assertEqual(transform1.fromEndpoint, transform2.fromEndpoint)
        self.assertEqual(transform1.toEndpoint, transform2.toEndpoint)
        self.assertEqual(transform1.getMapping(), transform2.getMapping())

        fromEndpoint = transform1.fromEndpoint
        toEndpoint = transform1.toEndpoint
        mapping = transform1.getMapping()
        nIn = mapping.nIn
        nOut = mapping.nOut

        if mapping.hasForward:
            nPoints = 7  # arbitrary
            rawInArray = self.makeRawArrayData(nPoints, nIn)
            inArray = fromEndpoint.arrayFromData(rawInArray)
            outArray = transform1.applyForward(inArray)
            outData = toEndpoint.dataFromArray(outArray)
            outArrayRoundTrip = transform2.applyForward(inArray)
            outDataRoundTrip = toEndpoint.dataFromArray(outArrayRoundTrip)
            assert_allclose(outData, outDataRoundTrip)

        if mapping.hasInverse:
            nPoints = 7  # arbitrary
            rawOutArray = self.makeRawArrayData(nPoints, nOut)
            outArray = toEndpoint.arrayFromData(rawOutArray)
            inArray = transform1.applyInverse(outArray)
            inData = fromEndpoint.dataFromArray(inArray)
            inArrayRoundTrip = transform2.applyInverse(outArray)
            inDataRoundTrip = fromEndpoint.dataFromArray(inArrayRoundTrip)
            assert_allclose(inData, inDataRoundTrip)

    def checkPersistence(self, transform):
        """Check persistence of a transform

        Exercises the string form (`writeString`/`readString` and
        `afwGeom.transformFromString`), pickling, and a FITS round trip.

        Parameters
        ----------
        transform : `lsst.afw.geom.Transform`
            The transform to persist and restore
        """
        className = type(transform).__name__

        # check writeString and readString
        transformStr = transform.writeString()
        serialVersion, serialClassName, serialRest = transformStr.split(" ", 2)
        self.assertEqual(int(serialVersion), 1)
        self.assertEqual(serialClassName, className)
        # a string with an unsupported version number must be rejected
        badStr1 = " ".join(["2", serialClassName, serialRest])
        with self.assertRaises(InvalidParameterError):
            transform.readString(badStr1)
        # a string naming a different class must be rejected
        badClassName = "x" + serialClassName
        badStr2 = " ".join(["1", badClassName, serialRest])
        with self.assertRaises(InvalidParameterError):
            transform.readString(badStr2)
        transformFromStr1 = transform.readString(transformStr)
        self.assertTransformsEqual(transform, transformFromStr1)

        # check transformFromString
        transformFromStr2 = afwGeom.transformFromString(transformStr)
        self.assertTransformsEqual(transform, transformFromStr2)

        # Check pickling
        self.assertTransformsEqual(transform, pickle.loads(pickle.dumps(transform)))

        # Check afw::table::io persistence round-trip
        with lsst.utils.tests.getTempFilePath(".fits") as filename:
            transform.writeFits(filename)
            self.assertTransformsEqual(transform, type(transform).readFits(filename))