diff options
Diffstat (limited to 'lib/test/test_intersection.py')
-rw-r--r-- | lib/test/test_intersection.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/lib/test/test_intersection.py b/lib/test/test_intersection.py index d25ce2e..cf4c3d5 100644 --- a/lib/test/test_intersection.py +++ b/lib/test/test_intersection.py @@ -29,11 +29,18 @@ class intersection_test: def __call__(self, func): @functools.wraps(func) def run(*args, **kwargs): + if GRAPH_MODE: + from mpl_toolkits.basemap import Basemap + from matplotlib import pyplot as plt + polys = func(*args, **kwargs) intersections = [] num_permutations = math.factorial(len(polys)) step_size = int(max(float(num_permutations) / 20.0, 1.0)) + + areas = [x.area() for x in polys] + if GRAPH_MODE: print("%d permutations" % num_permutations) for method in ('parallel', 'serial'): @@ -48,10 +55,7 @@ class intersection_test: intersection = polygon.SphericalPolygon.multi_intersection( permutation, method=method) intersections.append(intersection) - areas = [x.area() for x in permutation] intersection_area = intersection.area() - assert np.all(intersection_area < areas) - if GRAPH_MODE: fig = plt.figure() m = Basemap(projection=self._proj, @@ -67,10 +71,12 @@ class intersection_test: plt.savefig(filename) fig.clear() + assert np.all(intersection_area < areas) + lengths = np.array([len(x._points) for x in intersections]) assert np.all(lengths == [lengths[0]]) areas = np.array([x.area() for x in intersections]) - assert_array_almost_equal(areas, areas[0], decimal=3) + assert_array_almost_equal(areas, areas[0], decimal=1) return run |