diff options
Diffstat (limited to 'lib/stwcs/wcsutil/hstwcs.py')
-rw-r--r-- | lib/stwcs/wcsutil/hstwcs.py | 120 |
1 files changed, 96 insertions, 24 deletions
diff --git a/lib/stwcs/wcsutil/hstwcs.py b/lib/stwcs/wcsutil/hstwcs.py index 2150059..e6c812d 100644 --- a/lib/stwcs/wcsutil/hstwcs.py +++ b/lib/stwcs/wcsutil/hstwcs.py @@ -56,15 +56,45 @@ def build_default_wcsname(idctab): class NoConvergence(Exception): + """ + An error class used to report non-convergence and/or divergence of + numerical methods. It is used to report errors in the iterative solution + used by the :py:meth:`~stwcs.hstwcs.HSTWCS.all_sky2pix`\ . + + Attributes + ---------- + + best_solution : numpy.array + Best solution achieved by the method. + + accuracy : float + Accuracy of the :py:attr:`best_solution`\ . + + niter : int + Number of iterations performed by the numerical method to compute + :py:attr:`best_solution`\ . + + divergent : None, numpy.array + Indices of the points in :py:attr:`best_solution` array for which the + solution appears to be divergent. If the solution does not diverge, + `divergent` will be set to `None`. + nonconvergent : None, numpy.array + Indices of the points in :py:attr:`best_solution` array for which the + solution failed to converge within the specified maximum number + of iterations. If there are no non-converging poits (i.e., if + the required accuracy has been achieved for all points) then + `nonconvergent` will be set to `None`. + + """ def __init__(self, *args, **kwargs): super(NoConvergence, self).__init__(*args) self.best_solution = kwargs.pop('best_solution', None) - self.error_estimate = kwargs.pop('error_estimate', None) + self.accuracy = kwargs.pop('accuracy', None) self.niter = kwargs.pop('niter', None) - self.divergent = kwargs.pop('divergent', False) - self.offenders = kwargs.pop('offenders', None) + self.divergent = kwargs.pop('divergent', None) + self.nonconvergent = kwargs.pop('nonconvergent', None) # @@ -414,16 +444,16 @@ class HSTWCS(WCS): def pc2cd(self): self.wcs.cd = self.wcs.pc.copy() - def all_sky2pix(self,*args, **kwargs): + def all_sky2pix(self, *args, **kwargs): """ - all_sky2pix(*arg, accuracy=1.0e-3, maxiter=20, adaptive=False, quiet=False) + all_sky2pix(*arg, accuracy=1.0e-4, maxiter=20, adaptive=False, quiet=False) Performs full inverse transformation using iterative solution on full forward transformation with complete distortion model. Parameters ---------- - accuracy : float, optional (Default = 1.0e-3) + accuracy : float, optional (Default = 1.0e-4) Required accuracy of the solution. maxiter : int, optional (Default = 20) @@ -448,14 +478,15 @@ class HSTWCS(WCS): converged to the required accuracy. However, for the HST's ACS/WFC detector, which has the strongest distortions of all HST instruments, testing has shown that enabling this option - would lead to a 30-50\% penalty in computational time. + would lead to a 10-30\% penalty in computational time. Therefore, for HST instruments, it is recommended to set `adaptive` = `False`\ . quiet : bool, optional (Default = False) - Do not throw exceptions when the method does not converge to a - solution with the required accuracy within a specified number - of maximum iterations set by `maxiter` parameter. + Do not throw :py:class:`NoConvergence` exceptions when the method + does not converge to a solution with the required accuracy + within a specified number of maximum iterations set by `maxiter` + parameter. Instead, simply return the found solution. Raises ------ @@ -540,7 +571,7 @@ accuracy after 3 iterations. .format(nargs)) # process optional arguments: - accuracy = kwargs.pop('accuracy', 1.0e-3) + accuracy = kwargs.pop('accuracy', 1.0e-4) maxiter = kwargs.pop('maxiter', 20) quiet = kwargs.pop('quiet', False) adaptive = kwargs.pop('adaptive', False) @@ -565,8 +596,8 @@ accuracy after 3 iterations. # initial correction: dx, dy = self.pix2foc(x, y, origin) - # If pix2foc does not apply all distortion corrections - # then replace the above line with: + # If pix2foc does not apply all the required distortion + # corrections then replace the above line with: #r0, d0 = self.all_pix2sky(x, y, origin) #dx, dy = self.wcs_sky2pix(r0, d0, origin ) dx -= x0 @@ -576,23 +607,37 @@ accuracy after 3 iterations. x -= dx y -= dy + # norn (L2) squared of the correction: + dn2prev = dx**2+dy**2 + dn2 = dn2prev + # process all coordinates simultaneously: iterlist = range(1, maxiter+1) accuracy2 = accuracy**2 ind = None + inddiv = None + + divergent = False if not adaptive: for k in iterlist: # check convergence: - if np.max(dx**2+dy**2) < accuracy2: + if np.max(dn2) < accuracy2: + break + + # check for divergence: + inddiv, = np.where((dn2 > dn2prev) & (dn2 >= accuracy2)) + if inddiv.shape[0] > 0: + divergent = True break # find correction to the previous solution: dx, dy = self.pix2foc(x, y, origin) - # If pix2foc does not apply all distortion corrections - # then replace the above line with: + # If pix2foc does not apply all the required distortion + # corrections then replace the above line with: #r0, d0 = self.all_pix2sky(x, y, origin) #dx, dy = self.wcs_sky2pix(r0, d0, origin ) + dx -= x0 dy -= y0 @@ -600,18 +645,28 @@ accuracy after 3 iterations. x -= dx y -= dy + # update norn (L2) squared of the correction: + dn2prev = dn2.copy() + dn2 = dx**2+dy**2 + else: - ind, = np.where(dx**2+dy**2 >= accuracy2) + ind, = np.where(dn2 >= accuracy2) for k in iterlist: # check convergence: if ind.shape[0] == 0: break + # check for divergence: + inddiv = ind[np.where(dn2[ind] > dn2prev[ind])] + if inddiv.shape[0] > 0: + divergent = True + break + # find correction to the previous solution: dx[ind], dy[ind] = self.pix2foc(x[ind], y[ind], origin) - # If pix2foc does not apply all distortion corrections - # then replace the above line with: + # If pix2foc does not apply all the required distortion + # corrections then replace the above line with: #r0[ind], d0[ind] = self.all_pix2sky(x[ind], y[ind], origin) #dx[ind], dy[ind] = self.wcs_sky2pix(r0[ind], d0[ind], origin ) dx[ind] -= x0[ind] @@ -621,8 +676,12 @@ accuracy after 3 iterations. x[ind] -= dx[ind] y[ind] -= dy[ind] + # update norn (L2) squared of the correction: + dn2prev = dn2.copy() + dn2 = dx**2+dy**2 + # update indices of elements that still need correction: - ind, = np.where(dx**2+dy**2 >= accuracy2) + ind, = np.where(dn2 >= accuracy2) #ind = ind[np.where(dx[ind]**2+dy[ind]**2 >= accuracy2)] if k >= maxiter and not quiet: @@ -636,10 +695,23 @@ accuracy after 3 iterations. if ind is None: ind, = np.where(dx**2+dy**2 >= accuracy2) - raise NoConvergence("'HSTWCS.all_sky2pix' failed to converge " \ - "after {:d} iterations.".format(k), \ - best_solution = sol, error_estimate = err, \ - niter = k, offenders = ind) + if inddiv is None: + inddiv, = np.where(dn2[ind] > dn2prev[ind]) + + if ind.shape[0] == 0: + ind = None + inddiv = None + + elif inddiv.shape[0] == 0: + inddiv = None + + assert(ind is not None or inddiv is not None) # <-- sanity check + + raise NoConvergence("'HSTWCS.all_sky2pix' failed to converge to "\ + "requested accuracy after {:d} iterations." \ + .format(k), best_solution = sol, \ + accuracy = err, niter = k, \ + nonconvergent = ind, divergent = inddiv) if vect1D: return [x, y] |