#
# LSST Data Management System
# Copyright 2016 LSST Corporation.
#
# This product includes software developed by the
# LSST Project (http://www.lsst.org/).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the LSST License Statement and
# the GNU General Public License along with this program.  If not,
# see <http://www.lsstcorp.org/LegalNotices/>.
#
__all__ = ["BoxGrid", "makeSipIwcToPixel", "makeSipPixelToIwc"]
import itertools
import math
import os
import pickle
import astshim as ast
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal
from astshim.test import makeForwardPolyMap, makeTwoWayPolyMap
from lsst.afw.geom.wcsUtils import getCdMatrixFromMetadata
from .box import Box2I, Box2D
import lsst.afw.geom as afwGeom
from lsst.pex.exceptions import InvalidParameterError
import lsst.utils
import lsst.utils.tests
class BoxGrid:
    """Divide a box into nx by ny sub-boxes that tile the region

    The sub-boxes will be of the same type as `box` and will exactly tile `box`;
    they will also all be the same size, to the extent possible (some variation
    is inevitable for integer boxes that cannot be evenly divided).

    Parameters
    ----------
    box : `lsst.afw.geom.Box2I` or `lsst.afw.geom.Box2D`
        the box to subdivide; the boxes in the grid will be of the same type
    numColRow : pair of `int`
        number of columns and rows

    Raises
    ------
    RuntimeError
        If `numColRow` does not have exactly two elements, or if `box` is
        neither a `Box2I` nor a `Box2D`.
    """

    def __init__(self, box, numColRow):
        if len(numColRow) != 2:
            raise RuntimeError(
                "numColRow=%r; must be a sequence of two integers" % (numColRow,))
        self._numColRow = tuple(int(val) for val in numColRow)

        # For Box2I the max is treated as inclusive: grid edges are laid out
        # up to max + 1 and each sub-box end is pulled back by stopDelta
        # (see __getitem__) so that adjacent integer sub-boxes do not overlap.
        if isinstance(box, Box2I):
            stopDelta = 1
        elif isinstance(box, Box2D):
            stopDelta = 0
        else:
            raise RuntimeError("Unknown class %s of box %s" % (type(box), box))
        self.boxClass = type(box)
        self.stopDelta = stopDelta

        minPoint = box.getMin()
        maxPoint = box.getMax()
        self.pointClass = type(minPoint)
        # use the same dtype as the box's coordinates
        dtype = np.array(minPoint).dtype
        # _divList[i] holds the numColRow[i] + 1 grid edges along axis i
        self._divList = [np.linspace(start=minPoint[i],
                                     stop=maxPoint[i] + self.stopDelta,
                                     num=self._numColRow[i] + 1,
                                     endpoint=True,
                                     dtype=dtype) for i in range(2)]

    @property
    def numColRow(self):
        """Number of columns and rows, as a tuple of two `int`"""
        return self._numColRow

    def __getitem__(self, indXY):
        """Return the box at the specified x,y index

        Parameters
        ----------
        indXY : pair of `ints`
            the x,y index to return

        Returns
        -------
        subBox : `lsst.afw.geom.Box2I` or `lsst.afw.geom.Box2D`
        """
        beg = self.pointClass(*[self._divList[i][indXY[i]] for i in range(2)])
        end = self.pointClass(
            *[self._divList[i][indXY[i] + 1] - self.stopDelta for i in range(2)])
        return self.boxClass(beg, end)

    def __len__(self):
        """Return the total number of sub-boxes (columns * rows)"""
        numCol, numRow = self.numColRow
        return numCol * numRow

    def __iter__(self):
        """Return an iterator over all boxes, where column varies most quickly
        """
        for row in range(self.numColRow[1]):
            for col in range(self.numColRow[0]):
                yield self[col, row]
class FrameSetInfo:
    """Summarize the base and current frames of an `ast.FrameSet`

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet to summarize

    Notes
    -----
    **Fields**

    baseInd : `int`
        Index of the base frame
    currInd : `int`
        Index of the current frame
    isBaseSkyFrame : `bool`
        True if the base frame's class name is "SkyFrame"
    isCurrSkyFrame : `bool`
        True if the current frame's class name is "SkyFrame"
    """

    def __init__(self, frameSet):
        def frameIsSky(frameIndex):
            return frameSet.getFrame(frameIndex).className == "SkyFrame"

        baseIndex = frameSet.base
        currentIndex = frameSet.current
        self.baseInd = baseIndex
        self.currInd = currentIndex
        self.isBaseSkyFrame = frameIsSky(baseIndex)
        self.isCurrSkyFrame = frameIsSky(currentIndex)
def makeSipPolyMapCoeffs(metadata, name):
    """Return a list of ast.PolyMap coefficients for the specified SIP matrix

    The returned coefficients describe an ast.PolyMap that computes one
    output axis of:
        f(dxy) = dxy + sipPolynomial(dxy)
        where dxy = pixelPosition - pixelOrigin
        and sipPolynomial is a polynomial with terms `<name>_n_m` for x^n y^m
            (e.g. A_2_0 is the coefficient for x^2 y^0)

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with the specified SIP coefficients
    name : str
        The desired SIP terms: one of A, B, AP, BP

    Returns
    -------
    list
        A list of ``[coeff, outAxis, xPower, yPower]`` entries for an
        ast.PolyMap. The first entry is the identity term (out = in); it is
        followed by one entry per SIP coefficient present in `metadata`,
        ordered by xPower, then yPower.

    Raises
    ------
    RuntimeError
        If `name` is not a supported SIP name, or if no `<name>_n_m`
        coefficients are present in `metadata`.

    Note
    ----
    This is an internal function for use by makeSipIwcToPixel
    and makeSipPixelToIwc
    """
    # A and AP terms feed output axis 1 (x); B and BP terms feed axis 2 (y)
    outAxis = {"A": 1, "B": 2, "AP": 1, "BP": 2}.get(name)
    if outAxis is None:
        raise RuntimeError("%s not a supported SIP name" % (name,))
    width = metadata.getAsInt("%s_ORDER" % (name,)) + 1

    # identity term: out = in for this output axis
    identityTerm = [1.0, outAxis, 1, 0] if outAxis == 1 else [1.0, outAxis, 0, 1]

    distortionTerms = []
    for xPower, yPower in itertools.product(range(width), repeat=2):
        coeffName = "%s_%s_%s" % (name, xPower, yPower)
        if metadata.exists(coeffName):
            distortionTerms.append(
                [metadata.getAsDouble(coeffName), outAxis, xPower, yPower])
    if not distortionTerms:
        raise RuntimeError("No %s coefficients found" % (name,))
    return [identityTerm] + distortionTerms
def makeSipIwcToPixel(metadata):
    """Make an IWC to pixel transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with inverse SIP coefficients

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from IWC position to pixel position (zero-based)
        in the forward direction. The inverse direction is not defined.

    Raises
    ------
    RuntimeError
        If no AP or no BP coefficients are found in `metadata`.

    Notes
    -----
    The inverse SIP terms APn_m, BPn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial inverseSipPolynomial, the returned transformation is:

        pixelPosition = pixel origin + uv + inverseSipPolynomial(uv)
        where uv = inverseCdMatrix * iwcPosition
    """
    # CRPIX is 1-based in FITS; subtract 1 for the zero-based pixel origin
    crpix = (metadata.get("CRPIX1") - 1, metadata.get("CRPIX2") - 1)
    pixelRelativeToAbsoluteMap = ast.ShiftMap(crpix)
    cdMatrix = getCdMatrixFromMetadata(metadata)
    cdMatrixMap = ast.MatrixMap(cdMatrix.copy())
    # AP terms feed the x output and BP terms the y output; each list
    # already includes its own identity (out = in) term
    coeffList = makeSipPolyMapCoeffs(metadata, "AP") + makeSipPolyMapCoeffs(metadata, "BP")
    coeffArr = np.array(coeffList, dtype=float)
    # forward-only PolyMap; no iterative inverse is requested
    sipPolyMap = ast.PolyMap(coeffArr, 2, "IterInverse=0")
    # IWC -> uv (inverse CD matrix) -> uv + SIP(uv) -> absolute pixel position
    iwcToPixelMap = cdMatrixMap.getInverse().then(sipPolyMap).then(pixelRelativeToAbsoluteMap)
    return afwGeom.TransformPoint2ToPoint2(iwcToPixelMap)
def makeSipPixelToIwc(metadata):
    """Make a pixel to IWC transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with forward SIP coefficients

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from pixel position (zero-based) to IWC position
        in the forward direction. The inverse direction is not defined.

    Raises
    ------
    RuntimeError
        If no A or no B coefficients are found in `metadata`.

    Notes
    -----
    The forward SIP terms An_m, Bn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial sipPolynomial, the returned transformation is:

        iwcPosition = cdMatrix * (dxy + sipPolynomial(dxy))
        where dxy = pixelPosition - pixelOrigin
    """
    # CRPIX is 1-based in FITS; subtract 1 for the zero-based pixel origin
    crpix = (metadata.get("CRPIX1") - 1, metadata.get("CRPIX2") - 1)
    pixelAbsoluteToRelativeMap = ast.ShiftMap(crpix).getInverse()
    cdMatrix = getCdMatrixFromMetadata(metadata)
    cdMatrixMap = ast.MatrixMap(cdMatrix.copy())
    # A terms feed the x output and B terms the y output; each list
    # already includes its own identity (out = in) term
    coeffList = makeSipPolyMapCoeffs(metadata, "A") + makeSipPolyMapCoeffs(metadata, "B")
    coeffArr = np.array(coeffList, dtype=float)
    # forward-only PolyMap; no iterative inverse is requested
    sipPolyMap = ast.PolyMap(coeffArr, 2, "IterInverse=0")
    # pixel -> dxy (subtract origin) -> dxy + SIP(dxy) -> CD matrix -> IWC
    pixelToIwcMap = pixelAbsoluteToRelativeMap.then(sipPolyMap).then(cdMatrixMap)
    return afwGeom.TransformPoint2ToPoint2(pixelToIwcMap)
class PermutedFrameSet:
    """A FrameSet with base or current frame possibly permuted, with associated
    information

    Only two-axis frames will be permuted.

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet you wish to permute. A deep copy is made.
    permuteBase : `bool`
        Permute the base frame's axes?
    permuteCurr : `bool`
        Permute the current frame's axes?

    Raises
    ------
    RuntimeError
        If you try to permute a frame that does not have 2 axes

    Notes
    -----
    **Fields**

    frameSet : `ast.FrameSet`
        The FrameSet that may be permuted. A local copy is made.
    isBaseSkyFrame : `bool`
        Is the base frame an `ast.SkyFrame`?
    isCurrSkyFrame : `bool`
        Is the current frame an `ast.SkyFrame`?
    isBasePermuted : `bool`
        Are the base frame axes permuted?
    isCurrPermuted : `bool`
        Are the current frame axes permuted?
    """

    def __init__(self, frameSet, permuteBase, permuteCurr):
        self.frameSet = frameSet.copy()
        fsInfo = FrameSetInfo(self.frameSet)
        self.isBaseSkyFrame = fsInfo.isBaseSkyFrame
        self.isCurrSkyFrame = fsInfo.isCurrSkyFrame
        if permuteBase:
            baseNAxes = self.frameSet.getFrame(fsInfo.baseInd).nAxes
            if baseNAxes != 2:
                raise RuntimeError("Base frame has {} axes; 2 required to permute".format(baseNAxes))
            # permAxes is applied via the current frame, so temporarily make
            # the base frame current, swap its axes, then restore the
            # original current frame
            self.frameSet.current = fsInfo.baseInd
            self.frameSet.permAxes([2, 1])
            self.frameSet.current = fsInfo.currInd
        if permuteCurr:
            currNAxes = self.frameSet.getFrame(fsInfo.currInd).nAxes
            if currNAxes != 2:
                raise RuntimeError("Current frame has {} axes; 2 required to permute".format(currNAxes))
            # the current frame is already current; swap its two axes
            self.frameSet.permAxes([2, 1])
        self.isBasePermuted = permuteBase
        self.isCurrPermuted = permuteCurr
class TransformTestBaseClass(lsst.utils.tests.TestCase):
    """Base class for unit tests of Transform<X>To<Y>
    Subclasses must call `TransformTestBaseClass.setUp(self)`
    if they provide their own version.
    If a package other than afw uses this class then it must
    override the `getTestDir` method to avoid writing into
    afw's test directory.
    """
    def getTestDir(self):
        """Return a directory in which tests may write temporary files

        By default this is the ``tests`` directory of the ``afw`` package.
        A package other than ``afw`` that uses this class must override
        this method so that it does not write into afw's test directory.
        """
        afwPackageDir = lsst.utils.getPackageDir("afw")
        return os.path.join(afwPackageDir, "tests")
    def setUp(self):
        """Prepare the per-test state used by the check* helpers

        Subclasses that override setUp should call this method.
        """
        # Append the ``msg`` argument of assertions to the default failure
        # message instead of replacing it
        self.longMessage = True

        # Endpoint class name prefixes; the class name is prefix + "Endpoint"
        self.endpointPrefixes = ("Generic", "Point2", "SpherePoint")

        # Endpoint prefix -> axis counts that are valid for that endpoint
        # (GenericEndpoint accepts any count; 1-4 are exercised)
        self.goodNAxes = dict(
            Generic=(1, 2, 3, 4),
            Point2=(2,),
            SpherePoint=(2,),
        )

        # Endpoint prefix -> axis counts that are invalid for that endpoint
        # (empty for GenericEndpoint, which accepts any count)
        self.badNAxes = dict(
            Generic=(),
            Point2=(1, 3, 4),
            SpherePoint=(1, 3, 4),
        )

        # Frame index -> Ident given to that frame by makeFrameSet
        frameNames = ("baseFrame", "frame2", "frame3", "currFrame")
        self.frameIdentDict = dict(enumerate(frameNames, start=1))
    @staticmethod
    def makeRawArrayData(nPoints, nAxes, delta=0.123):
        """Make an array of generic point data

        The data will be suitable for spherical points

        Parameters
        ----------
        nPoints : `int`
            Number of points in the array
        nAxes : `int`
            Number of axes in the point
        delta : `float`, optional
            Increment between successive axis values of a point

        Returns
        -------
        np.array of floats with shape (nAxes, nPoints)
            The values are as follows; if nAxes != 2:
                The first point has values `[0, delta, 2*delta, ..., (nAxes-1)*delta]`
                The Nth point has those values + N
            if nAxes == 2 then the data is scaled so that the max value of axis 1
                is a bit less than pi/2 (no scaling is applied if there are no
                points or that max value is 0)
        """
        # oneAxis = [0, 1, 2, ..., nPoints-1]
        oneAxis = np.arange(nPoints, dtype=float)
        # rawData = [oneAxis, oneAxis + delta, oneAxis + 2 delta, ...]
        rawData = np.array([j * delta + oneAxis for j in range(nAxes)], dtype=float)
        if nAxes == 2 and nPoints > 0:
            # scale rawData so that max value of 2nd axis is a bit less than pi/2,
            # thus making the data safe for SpherePoint
            maxLatitude = np.max(rawData[1])
            if maxLatitude != 0:
                rawData *= math.pi * 0.4999 / maxLatitude
        return rawData
    @staticmethod
    def makeRawPointData(nAxes, delta=0.123):
        """Make one generic point

        Parameters
        ----------
        nAxes : `int`
            Number of axes in the point
        delta : `float`
            Increment between axis values

        Returns
        -------
        list
            `nAxes` values: ``[0, delta, 2*delta, ..., (nAxes-1)*delta]``
        """
        return [delta * axisIndex for axisIndex in range(nAxes)]
    @staticmethod
    def makeEndpoint(name, nAxes=None):
        """Make an endpoint

        Parameters
        ----------
        name : `str`
            Endpoint class name prefix; the full class name is name + "Endpoint"
        nAxes : `int` or `None`, optional
            number of axes; an int is required if `name` == "Generic";
            otherwise ignored

        Returns
        -------
        subclass of `lsst.afw.geom.BaseEndpoint`
            The constructed endpoint

        Raises
        ------
        AttributeError
            If `lsst.afw.geom` has no class named name + "Endpoint"
        TypeError
            If `name` == "Generic" and `nAxes` is None. Other values of
            `nAxes` (including non-positive ones) are not checked here; they
            are passed unchanged to the `GenericEndpoint` constructor.
        """
        EndpointClass = getattr(afwGeom, name + "Endpoint")
        if name != "Generic":
            return EndpointClass()
        if nAxes is None:
            raise TypeError("nAxes must be an integer for GenericEndpoint")
        return EndpointClass(nAxes)
    @classmethod
    def makeGoodFrame(cls, name, nAxes=None):
        """Return a frame that is valid for the named endpoint

        The frame is built by the endpoint itself (see `makeEndpoint`).

        Parameters
        ----------
        name : `str`
            Endpoint class name prefix; the full class name is name + "Endpoint"
        nAxes : `int` or `None`, optional
            number of axes; an int is required if `name` == "Generic";
            otherwise ignored

        Returns
        -------
        `ast.Frame`
            The constructed frame

        Raises
        ------
        TypeError
            If `name` == "Generic" and `nAxes` is `None`
        """
        endpoint = cls.makeEndpoint(name, nAxes)
        return endpoint.makeFrame()
    @staticmethod
    def makeBadFrames(name):
        """Return frames that do not match the named endpoint

        For "Point2" these are a SkyFrame and Frames with 1 and 3 axes;
        for "SpherePoint" they are plain Frames with 1, 2 and 3 axes;
        for "Generic" there are none.

        Parameters
        ----------
        name : `str`
            Endpoint class name prefix; the full class name is name + "Endpoint"

        Returns
        -------
        `list` of `ast.Frame`
            0 or more frames

        Raises
        ------
        KeyError
            If `name` is not one of "Generic", "Point2" or "SpherePoint"
        """
        frameFactories = {
            "Generic": lambda: [],
            "Point2": lambda: [ast.SkyFrame(), ast.Frame(1), ast.Frame(3)],
            "SpherePoint": lambda: [ast.Frame(numAxes) for numAxes in (1, 2, 3)],
        }
        return frameFactories[name]()
    def makeFrameSet(self, baseFrame, currFrame):
        """Make a FrameSet of 4 frames connected by 3 mappings

        The Ident of each frame is taken from self.frameIdentDict.

        Frame       Index   Mapping from this frame to the next
        `baseFrame`   1     `ast.UnitMap(nIn)`
        Frame(nIn)    2     `polyMap`
        Frame(nOut)   3     `ast.UnitMap(nOut)`
        `currFrame`   4

        where:
        - `nIn` = `baseFrame.nAxes`
        - `nOut` = `currFrame.nAxes`
        - `polyMap` = `makeTwoWayPolyMap(nIn, nOut)`

        Parameters
        ----------
        baseFrame : `ast.Frame`
            base frame (a copy is used; the argument is not modified)
        currFrame : `ast.Frame`
            current frame (a copy is used; the argument is not modified)

        Returns
        -------
        `ast.FrameSet`
            The FrameSet as described above
        """
        nIn = baseFrame.nAxes
        nOut = currFrame.nAxes
        polyMap = makeTwoWayPolyMap(nIn, nOut)

        # A frame's Ident must be set before it is added to the FrameSet;
        # use copies so the caller's frames are left untouched
        firstFrame = baseFrame.copy()
        firstFrame.ident = self.frameIdentDict[1]
        lastFrame = currFrame.copy()
        lastFrame.ident = self.frameIdentDict[4]

        frame2 = ast.Frame(nIn)
        frame2.ident = self.frameIdentDict[2]
        frame3 = ast.Frame(nOut)
        frame3.ident = self.frameIdentDict[3]

        frameSet = ast.FrameSet(firstFrame)
        # each frame is attached to the current frame, producing the
        # 1 -> 2 -> 3 -> 4 chain shown in the table above
        links = (
            (ast.UnitMap(nIn), frame2),
            (polyMap, frame3),
            (ast.UnitMap(nOut), lastFrame),
        )
        for mapping, frame in links:
            frameSet.addFrame(ast.FrameSet.CURRENT, mapping, frame)
        return frameSet
    @staticmethod
    def permuteFrameSetIter(frameSet):
        """Iterate over copies of a FrameSet with SkyFrame axes permuted

        Only the base and current frames are considered, and only those that
        are SkyFrames. Every permuted/unpermuted combination of those frames
        is produced, including the fully unpermuted one. If neither frame is
        a SkyFrame, nothing is produced.

        Parameters
        ----------
        frameSet : `ast.FrameSet`
            The FrameSet to permute; it is not modified

        Returns
        -------
        iterator over `PermutedFrameSet`
        """
        fsInfo = FrameSetInfo(frameSet)
        if fsInfo.isBaseSkyFrame or fsInfo.isCurrSkyFrame:
            baseChoices = (False, True) if fsInfo.isBaseSkyFrame else (False,)
            currChoices = (False, True) if fsInfo.isCurrSkyFrame else (False,)
            for permuteBase, permuteCurr in itertools.product(baseChoices, currChoices):
                yield PermutedFrameSet(frameSet, permuteBase, permuteCurr)
    @staticmethod
    def makeJacobian(nIn, nOut, inPoint):
        """Make a Jacobian matrix for the equation described by
        `makeTwoWayPolyMap`.

        Entry [i, j] is ``2 * 0.001 * (j + 1 + i) * inPoint[j]``.

        Parameters
        ----------
        nIn, nOut : `int`
            the dimensions of the input and output data; see makeTwoWayPolyMap
        inPoint : `numpy.ndarray`
            an array of size `nIn` representing the point at which the Jacobian
            is measured

        Returns
        -------
        J : `numpy.ndarray`
            an `nOut` x `nIn` array of first derivatives
        """
        # 0.001 must match the coefficient scale used by makeTwoWayPolyMap;
        # NOTE(review): the factor 2 presumably comes from differentiating
        # quadratic terms -- confirm against makeTwoWayPolyMap
        basePolyMapCoeff = 0.001
        baseCoeff = 2.0 * basePolyMapCoeff
        jacobian = np.empty((nOut, nIn))
        for outIdx, row in enumerate(jacobian):
            rowOffset = baseCoeff * outIdx
            for inIdx in range(nIn):
                row[inIdx] = baseCoeff * (inIdx + 1) + rowOffset
                row[inIdx] *= inPoint[inIdx]
        assert jacobian.shape == (nOut, nIn)
        return jacobian
    def checkTransformation(self, transform, mapping, msg=""):
        """Check applyForward and applyInverse for a transform

        First checks that the transform's endpoints have the same number of
        axes as `mapping` (``nIn`` and ``nOut``). Then, for each direction
        that `mapping` supports, applies the transform to a single point and
        to an array of points. The results must agree (approximately) with
        applying both `mapping` and ``transform.getMapping()`` to the
        equivalent raw data. For each direction `mapping` does not support,
        checks that the transform reports that direction as unavailable.

        Parameters
        ----------
        transform : `lsst.afw.geom.Transform`
            The transform to check
        mapping : `ast.Mapping`
            The mapping the transform should use. This mapping
            must contain valid forward or inverse transformations,
            but they need not match if both present. Hence the
            mappings returned by make*PolyMap are acceptable.
        msg : `str`
            Error message suffix describing test parameters
        """
        fromEndpoint = transform.fromEndpoint
        toEndpoint = transform.toEndpoint
        mappingFromTransform = transform.getMapping()
        nIn = mapping.nIn
        nOut = mapping.nOut
        # the transform's endpoints must match the mapping's dimensions
        self.assertEqual(nIn, fromEndpoint.nAxes, msg=msg)
        self.assertEqual(nOut, toEndpoint.nAxes, msg=msg)
        # forward transformation of one point
        rawInPoint = self.makeRawPointData(nIn)
        inPoint = fromEndpoint.pointFromData(rawInPoint)
        # forward transformation of an array of points
        nPoints = 7  # arbitrary
        rawInArray = self.makeRawArrayData(nPoints, nIn)
        inArray = fromEndpoint.arrayFromData(rawInArray)
        if mapping.hasForward:
            self.assertTrue(transform.hasForward)
            # single point: the transform result, converted back to raw data,
            # must match both the given mapping and the transform's own mapping
            outPoint = transform.applyForward(inPoint)
            rawOutPoint = toEndpoint.dataFromPoint(outPoint)
            assert_allclose(rawOutPoint, mapping.applyForward(rawInPoint), err_msg=msg)
            assert_allclose(rawOutPoint, mappingFromTransform.applyForward(rawInPoint), err_msg=msg)
            # array of points: same comparison
            outArray = transform.applyForward(inArray)
            rawOutArray = toEndpoint.dataFromArray(outArray)
            self.assertFloatsAlmostEqual(rawOutArray, mapping.applyForward(rawInArray), msg=msg)
            self.assertFloatsAlmostEqual(rawOutArray, mappingFromTransform.applyForward(rawInArray), msg=msg)
        else:
            # Need outPoint, but don't need it to be consistent with inPoint
            rawOutPoint = self.makeRawPointData(nOut)
            outPoint = toEndpoint.pointFromData(rawOutPoint)
            rawOutArray = self.makeRawArrayData(nPoints, nOut)
            outArray = toEndpoint.arrayFromData(rawOutArray)
            self.assertFalse(transform.hasForward)
        if mapping.hasInverse:
            self.assertTrue(transform.hasInverse)
            # inverse transformation of one point;
            # remember that the inverse need not give the original values
            # (see the description of the `mapping` parameter)
            inversePoint = transform.applyInverse(outPoint)
            rawInversePoint = fromEndpoint.dataFromPoint(inversePoint)
            assert_allclose(rawInversePoint, mapping.applyInverse(rawOutPoint), err_msg=msg)
            assert_allclose(rawInversePoint, mappingFromTransform.applyInverse(rawOutPoint), err_msg=msg)
            # inverse transformation of an array of points;
            # remember that the inverse will not give the original values
            # (see the description of the `mapping` parameter)
            inverseArray = transform.applyInverse(outArray)
            rawInverseArray = fromEndpoint.dataFromArray(inverseArray)
            self.assertFloatsAlmostEqual(rawInverseArray, mapping.applyInverse(rawOutArray), msg=msg)
            self.assertFloatsAlmostEqual(rawInverseArray, mappingFromTransform.applyInverse(rawOutArray),
                                         msg=msg)
        else:
            # no inverse in the mapping, so the transform must not claim one
            self.assertFalse(transform.hasInverse)
    def checkInverseTransformation(self, forward, inverse, msg=""):
        """Check that two Transforms are each others' inverses.

        Checks that `inverse` has the endpoints of `forward` swapped and the
        availability of the forward and inverse directions swapped. Then, for
        each direction `forward` supports, applying `forward` in that
        direction must give exactly the same result as applying `inverse` in
        the opposite direction. The comparison uses exact equality, not
        approximate. It is made for the Transforms and for their underlying
        Mappings, for a single point and for an array of points.

        Parameters
        ----------
        forward : `lsst.afw.geom.Transform`
            the reference Transform to test
        inverse : `lsst.afw.geom.Transform`
            the transform that should be the inverse of `forward`
        msg : `str`
            error message suffix describing test parameters
        """
        fromEndpoint = forward.fromEndpoint
        toEndpoint = forward.toEndpoint
        forwardMapping = forward.getMapping()
        inverseMapping = inverse.getMapping()
        # properties
        # endpoints and direction availability must be swapped
        self.assertEqual(forward.fromEndpoint,
                         inverse.toEndpoint, msg=msg)
        self.assertEqual(forward.toEndpoint,
                         inverse.fromEndpoint, msg=msg)
        self.assertEqual(forward.hasForward, inverse.hasInverse, msg=msg)
        self.assertEqual(forward.hasInverse, inverse.hasForward, msg=msg)
        # transformations of one point
        # we don't care about whether the transformation itself is correct
        # (see checkTransformation), so inPoint/outPoint need not be related
        rawInPoint = self.makeRawPointData(fromEndpoint.nAxes)
        inPoint = fromEndpoint.pointFromData(rawInPoint)
        rawOutPoint = self.makeRawPointData(toEndpoint.nAxes)
        outPoint = toEndpoint.pointFromData(rawOutPoint)
        # transformations of arrays of points
        nPoints = 7  # arbitrary
        rawInArray = self.makeRawArrayData(nPoints, fromEndpoint.nAxes)
        inArray = fromEndpoint.arrayFromData(rawInArray)
        rawOutArray = self.makeRawArrayData(nPoints, toEndpoint.nAxes)
        outArray = toEndpoint.arrayFromData(rawOutArray)
        # forward.applyForward must equal inverse.applyInverse exactly
        if forward.hasForward:
            self.assertEqual(forward.applyForward(inPoint),
                             inverse.applyInverse(inPoint), msg=msg)
            self.assertEqual(forwardMapping.applyForward(rawInPoint),
                             inverseMapping.applyInverse(rawInPoint), msg=msg)
            # Assertions must work with both lists and numpy arrays
            assert_array_equal(forward.applyForward(inArray),
                               inverse.applyInverse(inArray),
                               err_msg=msg)
            assert_array_equal(forwardMapping.applyForward(rawInArray),
                               inverseMapping.applyInverse(rawInArray),
                               err_msg=msg)
        # forward.applyInverse must equal inverse.applyForward exactly
        if forward.hasInverse:
            self.assertEqual(forward.applyInverse(outPoint),
                             inverse.applyForward(outPoint), msg=msg)
            self.assertEqual(forwardMapping.applyInverse(rawOutPoint),
                             inverseMapping.applyForward(rawOutPoint), msg=msg)
            assert_array_equal(forward.applyInverse(outArray),
                               inverse.applyForward(outArray),
                               err_msg=msg)
            assert_array_equal(forwardMapping.applyInverse(rawOutArray),
                               inverseMapping.applyForward(rawOutArray),
                               err_msg=msg)
    def checkTransformFromMapping(self, fromName, toName):
        """Check Transform_<fromName>_<toName> using the Mapping constructor

        For every valid combination of input/output axis counts, a two-way,
        a forward-only and an inverse-only PolyMap are wrapped in the
        Transform class and checked with `checkTransformation`. Mappings with
        an invalid number of inputs or outputs must be rejected with
        `InvalidParameterError`.

        Parameters
        ----------
        fromName, toName : `str`
            Endpoint name prefix for "from" and "to" endpoints, respectively,
            e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
        """
        transformClassName = "Transform{}To{}".format(fromName, toName)
        TransformClass = getattr(afwGeom, transformClassName)
        baseMsg = "TransformClass={}".format(TransformClass.__name__)

        # check valid numbers of inputs and outputs
        for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                           self.goodNAxes[toName]):
            msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
            polyMap = makeTwoWayPolyMap(nIn, nOut)
            transform = TransformClass(polyMap)

            # desired output from `str(transform)`
            desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut)
            self.assertEqual("{}".format(transform), desStr)
            self.assertEqual(repr(transform), "lsst.afw.geom." + desStr)

            self.checkTransformation(transform, polyMap, msg=msg)

            # Forward transform but no inverse
            polyMap = makeForwardPolyMap(nIn, nOut)
            transform = TransformClass(polyMap)
            self.checkTransformation(transform, polyMap, msg=msg)

            # Inverse transform but no forward
            polyMap = makeForwardPolyMap(nOut, nIn).getInverse()
            transform = TransformClass(polyMap)
            self.checkTransformation(transform, polyMap, msg=msg)

        # check invalid # of output against valid # of inputs
        for nIn, badNOut in itertools.product(self.goodNAxes[fromName],
                                              self.badNAxes[toName]):
            badPolyMap = makeTwoWayPolyMap(nIn, badNOut)
            msg = "{}, nIn={}, badNOut={}".format(baseMsg, nIn, badNOut)
            with self.assertRaises(InvalidParameterError, msg=msg):
                TransformClass(badPolyMap)

        # check invalid # of inputs against valid and invalid # of outputs
        for badNIn, nOut in itertools.product(self.badNAxes[fromName],
                                              self.goodNAxes[toName] + self.badNAxes[toName]):
            badPolyMap = makeTwoWayPolyMap(badNIn, nOut)
            # report the bad input count actually used (not a stale nIn)
            msg = "{}, badNIn={}, nOut={}".format(baseMsg, badNIn, nOut)
            with self.assertRaises(InvalidParameterError, msg=msg):
                TransformClass(badPolyMap)
    def checkTransformFromFrameSet(self, fromName, toName):
        """Check Transform<fromName>To<toName> using the FrameSet constructor

        For every pair of valid input/output axis counts this builds a good
        FrameSet. It then checks the following:

        - FrameSets with a bad base frame, a bad current frame, or both are
          rejected with `InvalidParameterError`.
        - The transform's str/repr are correct.
        - The transform survives persistence.
        - A transform rebuilt from ``getMapping()`` has the same type.
        - The transformation itself is correct, including after permuting
          SkyFrame axes.

        Parameters
        ----------
        fromName, toName : `str`
            Endpoint name prefix for "from" and "to" endpoints, respectively,
            e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
        """
        transformClassName = "Transform{}To{}".format(fromName, toName)
        TransformClass = getattr(afwGeom, transformClassName)
        baseMsg = "TransformClass={}".format(TransformClass.__name__)
        for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                           self.goodNAxes[toName]):
            msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)

            baseFrame = self.makeGoodFrame(fromName, nIn)
            currFrame = self.makeGoodFrame(toName, nOut)
            frameSet = self.makeFrameSet(baseFrame, currFrame)
            self.assertEqual(frameSet.nFrame, 4, msg=msg)

            # construct 0 or more frame sets that are invalid for this transform class;
            # each rejection carries a message saying which frame(s) were bad
            for badBaseFrame in self.makeBadFrames(fromName):
                badFrameSet = self.makeFrameSet(badBaseFrame, currFrame)
                with self.assertRaises(InvalidParameterError, msg=msg + ", bad base frame"):
                    TransformClass(badFrameSet)
                for badCurrFrame in self.makeBadFrames(toName):
                    reallyBadFrameSet = self.makeFrameSet(badBaseFrame, badCurrFrame)
                    with self.assertRaises(InvalidParameterError,
                                           msg=msg + ", bad base and current frames"):
                        TransformClass(reallyBadFrameSet)
            for badCurrFrame in self.makeBadFrames(toName):
                badFrameSet = self.makeFrameSet(baseFrame, badCurrFrame)
                with self.assertRaises(InvalidParameterError, msg=msg + ", bad current frame"):
                    TransformClass(badFrameSet)

            transform = TransformClass(frameSet)

            desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut)
            self.assertEqual("{}".format(transform), desStr)
            self.assertEqual(repr(transform), "lsst.afw.geom." + desStr)

            self.checkPersistence(transform)

            # a transform rebuilt from the extracted mapping must have the same type
            mappingFromTransform = transform.getMapping()
            transformCopy = TransformClass(mappingFromTransform)
            self.assertEqual(type(transform), type(transformCopy))
            self.assertEqual(transform.getMapping(), mappingFromTransform)

            # NOTE(review): this assumes makeFrameSet connects the frames with a
            # mapping equivalent to makeTwoWayPolyMap(nIn, nOut) -- confirm there
            polyMap = makeTwoWayPolyMap(nIn, nOut)
            self.checkTransformation(transform, mapping=polyMap, msg=msg)

            # If the base and/or current frame of frameSet is a SkyFrame,
            # try permuting that frame (in place, so the connected mappings are
            # correctly updated). The Transform constructor should undo the permutation,
            # (via SpherePointEndpoint.normalizeFrame) in its internal copy of frameSet,
            # forcing the axes of the SkyFrame into standard (longitude, latitude) order
            for permutedFS in self.permuteFrameSetIter(frameSet):
                # distinct names so the good frames built above are not shadowed
                if permutedFS.isBaseSkyFrame:
                    permBaseFrame = permutedFS.frameSet.getFrame(ast.FrameSet.BASE)
                    # desired base-frame longitude axis
                    desBaseLonAxis = 2 if permutedFS.isBasePermuted else 1
                    self.assertEqual(permBaseFrame.lonAxis, desBaseLonAxis, msg=msg)
                if permutedFS.isCurrSkyFrame:
                    permCurrFrame = permutedFS.frameSet.getFrame(ast.FrameSet.CURRENT)
                    # desired current-frame longitude axis
                    desCurrLonAxis = 2 if permutedFS.isCurrPermuted else 1
                    self.assertEqual(permCurrFrame.lonAxis, desCurrLonAxis, msg=msg)
                permTransform = TransformClass(permutedFS.frameSet)
                self.checkTransformation(permTransform, mapping=polyMap, msg=msg)
    def checkGetInverse(self, fromName, toName):
        """Exercise Transform<fromName>To<toName>.getInverse

        For every pair of valid input/output axis counts, the method calls
        `checkInverseMapping` three times, once for each mapping variant:

        - "TwoWay": ``makeTwoWayPolyMap(nIn, nOut)``
        - "Forward": ``makeForwardPolyMap(nIn, nOut)``
        - "Inverse": the inverse of ``makeForwardPolyMap(nOut, nIn)``

        Parameters
        ----------
        fromName, toName : `str`
            Endpoint name prefix for the "from" and "to" endpoints,
            e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
        """
        TransformClass = getattr(afwGeom, "Transform{}To{}".format(fromName, toName))
        baseMsg = "TransformClass={}".format(TransformClass.__name__)
        # (label, factory) pairs; every factory is called as factory(nIn, nOut)
        mappingFactories = (
            ("TwoWay", lambda numIn, numOut: makeTwoWayPolyMap(numIn, numOut)),
            ("Forward", lambda numIn, numOut: makeForwardPolyMap(numIn, numOut)),
            ("Inverse", lambda numIn, numOut: makeForwardPolyMap(numOut, numIn).getInverse()),
        )
        axisCombos = itertools.product(self.goodNAxes[fromName], self.goodNAxes[toName])
        for nIn, nOut in axisCombos:
            msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
            for label, makeMapping in mappingFactories:
                self.checkInverseMapping(TransformClass,
                                         makeMapping(nIn, nOut),
                                         "{}, Map={}".format(msg, label))
    def checkInverseMapping(self, TransformClass, mapping, msg):
        """Check getInverse of a ``TransformClass`` built from one mapping.

        The method makes three checks:

        - `checkInverseTransformation` on (transform, inverse).
        - `checkInverseTransformation` on (inverse, inverse of the inverse).
        - `checkTransformation` comparing the doubly-inverted transform
          against ``mapping``.

        Parameters
        ----------
        TransformClass : `type`
            The class of transform to test, such as TransformPoint2ToPoint2
        mapping : `ast.Mapping`
            The mapping from which the transform is constructed
        msg : `str`
            Error message suffix
        """
        original = TransformClass(mapping)
        inverted = original.getInverse()
        doublyInverted = inverted.getInverse()
        for forward, backward in ((original, inverted), (inverted, doublyInverted)):
            self.checkInverseTransformation(forward, backward, msg=msg)
        self.checkTransformation(doublyInverted, mapping, msg=msg)
    def checkGetJacobian(self, fromName, toName):
        """Exercise Transform<fromName>To<toName>.getJacobian

        The method first builds a transform from ``makeForwardPolyMap(nIn, nOut)``.
        It then compares its Jacobian with ``self.makeJacobian`` at two raw
        points: ``makeRawPointData(nIn)`` and ``makeRawPointData(nIn, 0.111)``.

        Parameters
        ----------
        fromName, toName : `str`
            Endpoint name prefix for the "from" and "to" endpoints,
            e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
        """
        TransformClass = getattr(afwGeom, "Transform{}To{}".format(fromName, toName))
        baseMsg = "TransformClass={}".format(TransformClass.__name__)
        axisCombos = itertools.product(self.goodNAxes[fromName], self.goodNAxes[toName])
        for nIn, nOut in axisCombos:
            msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
            transform = TransformClass(makeForwardPolyMap(nIn, nOut))
            fromEndpoint = transform.fromEndpoint
            # Test multiple points to ensure correct functional form
            for extraArgs in ((), (0.111,)):
                rawInPoint = self.makeRawPointData(nIn, *extraArgs)
                jacobian = transform.getJacobian(fromEndpoint.pointFromData(rawInPoint))
                assert_allclose(jacobian, self.makeJacobian(nIn, nOut, rawInPoint),
                                err_msg=msg)
    def checkThen(self, fromName, midName, toName):
        """Test Transform<fromName>To<midName>.then(Transform<midName>To<toName>)

        The method makes three kinds of check:

        - Applying the concatenated transform, forward and inverse, matches
          applying the two transforms one after the other.
        - When ``midName`` is "Generic", concatenating transforms whose axis
          counts disagree raises `InvalidParameterError`.
        - When ``fromName != midName``, concatenating transforms whose shared
          endpoint types differ raises `InvalidParameterError`.

        Parameters
        ----------
        fromName : `str`
            the prefix of the starting endpoint (e.g., "Point2" for a
            Point2Endpoint) for the final, concatenated Transform
        midName : `str`
            the prefix for the shared endpoint where two Transforms will be
            concatenated
        toName : `str`
            the prefix of the ending endpoint for the final, concatenated
            Transform
        """
        TransformClass1 = getattr(afwGeom,
                                  "Transform{}To{}".format(fromName, midName))
        TransformClass2 = getattr(afwGeom,
                                  "Transform{}To{}".format(midName, toName))
        baseMsg = "{}.then({})".format(TransformClass1.__name__,
                                       TransformClass2.__name__)
        for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName],
                                                 self.goodNAxes[midName],
                                                 self.goodNAxes[toName]):
            msg = "{}, nIn={}, nMid={}, nOut={}".format(
                baseMsg, nIn, nMid, nOut)
            transform1 = TransformClass1(makeTwoWayPolyMap(nIn, nMid))
            transform2 = TransformClass2(makeTwoWayPolyMap(nMid, nOut))
            transform = transform1.then(transform2)

            fromEndpoint = transform1.fromEndpoint
            toEndpoint = transform2.toEndpoint

            inPoint = fromEndpoint.pointFromData(self.makeRawPointData(nIn))
            outPointMerged = transform.applyForward(inPoint)
            outPointSeparate = transform2.applyForward(
                transform1.applyForward(inPoint))
            assert_allclose(toEndpoint.dataFromPoint(outPointMerged),
                            toEndpoint.dataFromPoint(outPointSeparate),
                            err_msg=msg)

            outPoint = toEndpoint.pointFromData(self.makeRawPointData(nOut))
            inPointMerged = transform.applyInverse(outPoint)
            inPointSeparate = transform1.applyInverse(
                transform2.applyInverse(outPoint))
            assert_allclose(
                fromEndpoint.dataFromPoint(inPointMerged),
                fromEndpoint.dataFromPoint(inPointSeparate),
                err_msg=msg)

        # Mismatched number of axes should fail: transform1 has 3 output axes
        # while transform2 has 2 input axes
        if midName == "Generic":
            nIn = self.goodNAxes[fromName][0]
            nOut = self.goodNAxes[toName][0]
            transform1 = TransformClass1(makeTwoWayPolyMap(nIn, 3))
            transform2 = TransformClass2(makeTwoWayPolyMap(2, nOut))
            msg = "{}, mismatched nAxes: nIn={}, nMid=3 vs 2, nOut={}".format(
                baseMsg, nIn, nOut)
            with self.assertRaises(InvalidParameterError, msg=msg):
                transform1.then(transform2)

        # Mismatched types of endpoints should fail
        if fromName != midName:
            # Use TransformClass1 for both args to keep test logic simple:
            # transform1 ends at a midName endpoint but transform2 starts at a
            # fromName endpoint, and these endpoint types differ
            mismatchBaseMsg = "{0}.then({0})".format(TransformClass1.__name__)
            outName = midName
            # nMid must be valid both as transform1's output (midName endpoint)
            # and as transform2's input (fromName endpoint)
            joinNAxes = sorted(set(self.goodNAxes[fromName]).intersection(
                self.goodNAxes[outName]))
            for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName],
                                                     joinNAxes,
                                                     self.goodNAxes[outName]):
                transform1 = TransformClass1(makeTwoWayPolyMap(nIn, nMid))
                transform2 = TransformClass1(makeTwoWayPolyMap(nMid, nOut))
                msg = "{}, nIn={}, nMid={}, nOut={}".format(
                    mismatchBaseMsg, nIn, nMid, nOut)
                with self.assertRaises(InvalidParameterError, msg=msg):
                    transform1.then(transform2)
    def assertTransformsEqual(self, transform1, transform2):
        """Assert that two transforms are equal

        The two transforms must match in each of the following respects:

        - They have the same type.
        - They have equal ``fromEndpoint`` and ``toEndpoint``.
        - They have equal mappings.
        - They give results that agree per `assert_allclose` on 7 generated
          points, forward if ``transform1``'s mapping has a forward
          transformation and inverse if it has an inverse.
        """
        self.assertEqual(type(transform1), type(transform2))
        for endpointAttr in ("fromEndpoint", "toEndpoint"):
            self.assertEqual(getattr(transform1, endpointAttr),
                             getattr(transform2, endpointAttr))
        mapping = transform1.getMapping()
        self.assertEqual(mapping, transform2.getMapping())

        fromEndpoint = transform1.fromEndpoint
        toEndpoint = transform1.toEndpoint
        numPoints = 7  # arbitrary

        if mapping.hasForward:
            inArray = fromEndpoint.arrayFromData(
                self.makeRawArrayData(numPoints, mapping.nIn))
            assert_allclose(toEndpoint.dataFromArray(transform1.applyForward(inArray)),
                            toEndpoint.dataFromArray(transform2.applyForward(inArray)))
        if mapping.hasInverse:
            outArray = toEndpoint.arrayFromData(
                self.makeRawArrayData(numPoints, mapping.nOut))
            assert_allclose(fromEndpoint.dataFromArray(transform1.applyInverse(outArray)),
                            fromEndpoint.dataFromArray(transform2.applyInverse(outArray)))
    def checkPersistence(self, transform):
        """Check persistence of a transform

        The method covers four persistence routes:

        - ``writeString``/``readString``: the header is checked, and strings
          with version "2" or a corrupted class name must raise
          `InvalidParameterError`.
        - `afwGeom.transformFromString`.
        - pickling.
        - A FITS round trip via ``writeFits``/``readFits``.

        Each restored transform is compared to the original with
        `assertTransformsEqual`.
        """
        className = type(transform).__name__

        # writeString output is expected to be "<version> <className> <rest>"
        transformStr = transform.writeString()
        serialVersion, serialClassName, serialRest = transformStr.split(" ", 2)
        self.assertEqual(int(serialVersion), 1)
        self.assertEqual(serialClassName, className)

        # an unsupported version and an unknown class name must both be rejected
        corruptedStrs = (
            " ".join(["2", serialClassName, serialRest]),
            " ".join(["1", "x" + serialClassName, serialRest]),
        )
        for corruptedStr in corruptedStrs:
            with self.assertRaises(InvalidParameterError):
                transform.readString(corruptedStr)

        # in-memory round trips, each checked before the next is attempted
        roundTrips = (
            lambda: transform.readString(transformStr),
            lambda: afwGeom.transformFromString(transformStr),
            lambda: pickle.loads(pickle.dumps(transform)),
        )
        for restore in roundTrips:
            self.assertTransformsEqual(transform, restore())

        # afw::table::io persistence round trip through a temporary FITS file
        with lsst.utils.tests.getTempFilePath(".fits") as fitsPath:
            transform.writeFits(fitsPath)
            self.assertTransformsEqual(transform, type(transform).readFits(fitsPath))