apply black code formatting

This commit is contained in:
Christoph Rieke
2021-08-30 15:28:51 +02:00
parent 29f57e894b
commit 8312319f51
5 changed files with 323 additions and 194 deletions

View File

@@ -8,8 +8,9 @@ class CurvedText(mtext.Text):
""" """
A text object that follows an arbitrary curve. A text object that follows an arbitrary curve.
""" """
def __init__(self, x, y, text, axes, **kwargs): def __init__(self, x, y, text, axes, **kwargs):
super(CurvedText, self).__init__(x[0],y[0],' ', **kwargs) super(CurvedText, self).__init__(x[0], y[0], " ", **kwargs)
axes.add_artist(self) axes.add_artist(self)
@@ -21,29 +22,28 @@ class CurvedText(mtext.Text):
##creating the text objects ##creating the text objects
self.__Characters = [] self.__Characters = []
for c in text: for c in text:
if c == ' ': if c == " ":
##make this an invisible 'a': ##make this an invisible 'a':
t = mtext.Text(0,0,'a') t = mtext.Text(0, 0, "a")
t.set_alpha(0.0) t.set_alpha(0.0)
else: else:
t = mtext.Text(0,0,c, **kwargs) t = mtext.Text(0, 0, c, **kwargs)
#resetting unnecessary arguments # resetting unnecessary arguments
t.set_ha('center') t.set_ha("center")
t.set_rotation(0) t.set_rotation(0)
t.set_zorder(self.__zorder +1) t.set_zorder(self.__zorder + 1)
self.__Characters.append((c,t)) self.__Characters.append((c, t))
axes.add_artist(t) axes.add_artist(t)
##overloading some member functions, to assure correct functionality ##overloading some member functions, to assure correct functionality
##on update ##on update
def set_zorder(self, zorder): def set_zorder(self, zorder):
super(CurvedText, self).set_zorder(zorder) super(CurvedText, self).set_zorder(zorder)
self.__zorder = self.get_zorder() self.__zorder = self.get_zorder()
for c,t in self.__Characters: for c, t in self.__Characters:
t.set_zorder(self.__zorder+1) t.set_zorder(self.__zorder + 1)
def draw(self, renderer, *args, **kwargs): def draw(self, renderer, *args, **kwargs):
""" """
@@ -53,12 +53,12 @@ class CurvedText(mtext.Text):
""" """
self.update_positions(renderer) self.update_positions(renderer)
def update_positions(self,renderer): def update_positions(self, renderer):
""" """
Update positions and rotations of the individual text elements. Update positions and rotations of the individual text elements.
""" """
#preparations # preparations
##determining the aspect ratio: ##determining the aspect ratio:
##from https://stackoverflow.com/a/42014041/2454357 ##from https://stackoverflow.com/a/42014041/2454357
@@ -71,94 +71,98 @@ class CurvedText(mtext.Text):
## Ratio of display units ## Ratio of display units
_, _, w, h = self.axes.get_position().bounds _, _, w, h = self.axes.get_position().bounds
##final aspect ratio ##final aspect ratio
aspect = ((figW * w)/(figH * h))*(ylim[1]-ylim[0])/(xlim[1]-xlim[0]) aspect = ((figW * w) / (figH * h)) * (ylim[1] - ylim[0]) / (xlim[1] - xlim[0])
#points of the curve in figure coordinates: # points of the curve in figure coordinates:
x_fig,y_fig = ( x_fig, y_fig = (
np.array(l) for l in zip(*self.axes.transData.transform([ np.array(l)
(i,j) for i,j in zip(self.__x,self.__y) for l in zip(
])) *self.axes.transData.transform(
[(i, j) for i, j in zip(self.__x, self.__y)]
)
)
) )
#point distances in figure coordinates # point distances in figure coordinates
x_fig_dist = (x_fig[1:]-x_fig[:-1]) x_fig_dist = x_fig[1:] - x_fig[:-1]
y_fig_dist = (y_fig[1:]-y_fig[:-1]) y_fig_dist = y_fig[1:] - y_fig[:-1]
r_fig_dist = np.sqrt(x_fig_dist**2+y_fig_dist**2) r_fig_dist = np.sqrt(x_fig_dist ** 2 + y_fig_dist ** 2)
#arc length in figure coordinates # arc length in figure coordinates
l_fig = np.insert(np.cumsum(r_fig_dist),0,0) l_fig = np.insert(np.cumsum(r_fig_dist), 0, 0)
#angles in figure coordinates # angles in figure coordinates
rads = np.arctan2((y_fig[1:] - y_fig[:-1]),(x_fig[1:] - x_fig[:-1])) rads = np.arctan2((y_fig[1:] - y_fig[:-1]), (x_fig[1:] - x_fig[:-1]))
degs = np.rad2deg(rads) degs = np.rad2deg(rads)
rel_pos = 10 rel_pos = 10
for c,t in self.__Characters: for c, t in self.__Characters:
#finding the width of c: # finding the width of c:
t.set_rotation(0) t.set_rotation(0)
t.set_va('center') t.set_va("center")
bbox1 = t.get_window_extent(renderer=renderer) bbox1 = t.get_window_extent(renderer=renderer)
w = bbox1.width w = bbox1.width
h = bbox1.height h = bbox1.height
#ignore all letters that don't fit: # ignore all letters that don't fit:
if rel_pos+w/2 > l_fig[-1]: if rel_pos + w / 2 > l_fig[-1]:
t.set_alpha(0.0) t.set_alpha(0.0)
rel_pos += w rel_pos += w
continue continue
elif c != ' ': elif c != " ":
t.set_alpha(1.0) t.set_alpha(1.0)
#finding the two data points between which the horizontal # finding the two data points between which the horizontal
#center point of the character will be situated # center point of the character will be situated
#left and right indices: # left and right indices:
il = np.where(rel_pos+w/2 >= l_fig)[0][-1] il = np.where(rel_pos + w / 2 >= l_fig)[0][-1]
ir = np.where(rel_pos+w/2 <= l_fig)[0][0] ir = np.where(rel_pos + w / 2 <= l_fig)[0][0]
#if we exactly hit a data point: # if we exactly hit a data point:
if ir == il: if ir == il:
ir += 1 ir += 1
#how much of the letter width was needed to find il: # how much of the letter width was needed to find il:
used = l_fig[il]-rel_pos used = l_fig[il] - rel_pos
rel_pos = l_fig[il] rel_pos = l_fig[il]
#relative distance between il and ir where the center # relative distance between il and ir where the center
#of the character will be # of the character will be
fraction = (w/2-used)/r_fig_dist[il] fraction = (w / 2 - used) / r_fig_dist[il]
##setting the character position in data coordinates: ##setting the character position in data coordinates:
##interpolate between the two points: ##interpolate between the two points:
x = self.__x[il]+fraction*(self.__x[ir]-self.__x[il]) x = self.__x[il] + fraction * (self.__x[ir] - self.__x[il])
y = self.__y[il]+fraction*(self.__y[ir]-self.__y[il]) y = self.__y[il] + fraction * (self.__y[ir] - self.__y[il])
#getting the offset when setting correct vertical alignment # getting the offset when setting correct vertical alignment
#in data coordinates # in data coordinates
t.set_va(self.get_va()) t.set_va(self.get_va())
bbox2 = t.get_window_extent(renderer=renderer) bbox2 = t.get_window_extent(renderer=renderer)
bbox1d = self.axes.transData.inverted().transform(bbox1) bbox1d = self.axes.transData.inverted().transform(bbox1)
bbox2d = self.axes.transData.inverted().transform(bbox2) bbox2d = self.axes.transData.inverted().transform(bbox2)
dr = np.array(bbox2d[0]-bbox1d[0]) dr = np.array(bbox2d[0] - bbox1d[0])
#the rotation/stretch matrix # the rotation/stretch matrix
rad = rads[il] rad = rads[il]
rot_mat = np.array([ rot_mat = np.array(
[math.cos(rad), math.sin(rad)*aspect], [
[-math.sin(rad)/aspect, math.cos(rad)] [math.cos(rad), math.sin(rad) * aspect],
]) [-math.sin(rad) / aspect, math.cos(rad)],
]
)
##computing the offset vector of the rotated character ##computing the offset vector of the rotated character
drp = np.dot(dr,rot_mat) drp = np.dot(dr, rot_mat)
#setting final position and rotation: # setting final position and rotation:
t.set_position(np.array([x,y])+drp) t.set_position(np.array([x, y]) + drp)
t.set_rotation(degs[il]) t.set_rotation(degs[il])
t.set_va('center') t.set_va("center")
t.set_ha('center') t.set_ha("center")
#updating rel_pos to right edge of character # updating rel_pos to right edge of character
rel_pos += w-used rel_pos += w - used

View File

@@ -19,88 +19,93 @@ from .fetch import get_perimeter, get_layer
def get_hash(key): def get_hash(key):
return frozenset(key.items()) if type(key) == dict else key return frozenset(key.items()) if type(key) == dict else key
# Drawing functions # Drawing functions
def show_palette(palette, description = ''): def show_palette(palette, description=""):
''' """
Helper to display palette in Markdown Helper to display palette in Markdown
''' """
colorboxes = [ colorboxes = [
f'![](https://placehold.it/30x30/{c[1:]}/{c[1:]}?text=)' f"![](https://placehold.it/30x30/{c[1:]}/{c[1:]}?text=)" for c in palette
for c in palette
] ]
display(Markdown((description))) display(Markdown((description)))
display(Markdown(tabulate(pd.DataFrame(colorboxes), showindex = False))) display(Markdown(tabulate(pd.DataFrame(colorboxes), showindex=False)))
def get_patch(shape, **kwargs): def get_patch(shape, **kwargs):
''' """
Convert shapely object to matplotlib patch Convert shapely object to matplotlib patch
''' """
#if type(shape) == Path: # if type(shape) == Path:
# return patches.PathPatch(shape, **kwargs) # return patches.PathPatch(shape, **kwargs)
if type(shape) == Polygon and shape.area > 0: if type(shape) == Polygon and shape.area > 0:
return PolygonPatch(list(zip(*shape.exterior.xy)), **kwargs) return PolygonPatch(list(zip(*shape.exterior.xy)), **kwargs)
else: else:
return None return None
# Plot a single shape # Plot a single shape
def plot_shape(shape, ax, vsketch = None, **kwargs): def plot_shape(shape, ax, vsketch=None, **kwargs):
''' """
Plot shapely object Plot shapely object
''' """
if isinstance(shape, Iterable) and type(shape) != MultiLineString: if isinstance(shape, Iterable) and type(shape) != MultiLineString:
for shape_ in shape: for shape_ in shape:
plot_shape(shape_, ax, vsketch = vsketch, **kwargs) plot_shape(shape_, ax, vsketch=vsketch, **kwargs)
else: else:
if not shape.is_empty: if not shape.is_empty:
if vsketch is None: if vsketch is None:
ax.add_patch(PolygonPatch(shape, **kwargs)) ax.add_patch(PolygonPatch(shape, **kwargs))
else: else:
if ('draw' not in kwargs) or kwargs['draw']: if ("draw" not in kwargs) or kwargs["draw"]:
if 'stroke' in kwargs: if "stroke" in kwargs:
vsketch.stroke(kwargs['stroke']) vsketch.stroke(kwargs["stroke"])
else: else:
vsketch.stroke(1) vsketch.stroke(1)
if 'penWidth' in kwargs: if "penWidth" in kwargs:
vsketch.penWidth(kwargs['penWidth']) vsketch.penWidth(kwargs["penWidth"])
else: else:
vsketch.penWidth(0.3) vsketch.penWidth(0.3)
if 'fill' in kwargs: if "fill" in kwargs:
vsketch.fill(kwargs['fill']) vsketch.fill(kwargs["fill"])
else: else:
vsketch.noFill() vsketch.noFill()
vsketch.geometry(shape) vsketch.geometry(shape)
# Plot a collection of shapes # Plot a collection of shapes
def plot_shapes(shapes, ax, vsketch = None, palette = None, **kwargs): def plot_shapes(shapes, ax, vsketch=None, palette=None, **kwargs):
''' """
Plot collection of shapely objects (optionally, use a color palette) Plot collection of shapely objects (optionally, use a color palette)
''' """
if not isinstance(shapes, Iterable): if not isinstance(shapes, Iterable):
shapes = [shapes] shapes = [shapes]
for shape in shapes: for shape in shapes:
if palette is None: if palette is None:
plot_shape(shape, ax, vsketch = vsketch, **kwargs) plot_shape(shape, ax, vsketch=vsketch, **kwargs)
else: else:
plot_shape(shape, ax, vsketch = vsketch, fc = choice(palette), **kwargs) plot_shape(shape, ax, vsketch=vsketch, fc=choice(palette), **kwargs)
# Parse query (by coordinates, OSMId or name) # Parse query (by coordinates, OSMId or name)
def parse_query(query): def parse_query(query):
if isinstance(query, GeoDataFrame): if isinstance(query, GeoDataFrame):
return 'polygon' return "polygon"
elif isinstance(query, tuple): elif isinstance(query, tuple):
return 'coordinates' return "coordinates"
elif re.match('''[A-Z][0-9]+''', query): elif re.match("""[A-Z][0-9]+""", query):
return 'osmid' return "osmid"
else: else:
return 'address' return "address"
# Apply transformation (translation & scale) to layers # Apply transformation (translation & scale) to layers
def transform(layers, x, y, scale_x, scale_y, rotation): def transform(layers, x, y, scale_x, scale_y, rotation):
@@ -118,38 +123,46 @@ def transform(layers, x, y, scale_x, scale_y, rotation):
layers = dict(zip(k, v)) layers = dict(zip(k, v))
return layers return layers
def draw_text(ax, text, x, y, **kwargs): def draw_text(ax, text, x, y, **kwargs):
ax.text(x, y, text, **kwargs) ax.text(x, y, text, **kwargs)
# Plot # Plot
def plot( def plot(
# Address # Address
query, query,
# Whether to use a backup for the layers # Whether to use a backup for the layers
backup = None, backup=None,
# Custom postprocessing function on layers # Custom postprocessing function on layers
postprocessing = None, postprocessing=None,
# Radius (in case of circular plot) # Radius (in case of circular plot)
radius = None, radius=None,
# Which layers to plot # Which layers to plot
layers = {'perimeter': {}}, layers={"perimeter": {}},
# Drawing params for each layer (matplotlib params such as 'fc', 'ec', 'fill', etc.) # Drawing params for each layer (matplotlib params such as 'fc', 'ec', 'fill', etc.)
drawing_kwargs = {}, drawing_kwargs={},
# OSM Caption parameters # OSM Caption parameters
osm_credit = {}, osm_credit={},
# Figure parameters # Figure parameters
figsize = (10, 10), ax = None, title = None, figsize=(10, 10),
ax=None,
title=None,
# Vsketch parameters # Vsketch parameters
vsketch = None, vsketch=None,
# Transform (translation & scale) params # Transform (translation & scale) params
x = None, y = None, scale_x = None, scale_y = None, rotation = None, x=None,
): y=None,
scale_x=None,
scale_y=None,
rotation=None,
):
# Interpret query # Interpret query
query_mode = parse_query(query) query_mode = parse_query(query)
# Save maximum dilation for later use # Save maximum dilation for later use
dilations = [kwargs['dilate'] for kwargs in layers.values() if 'dilate' in kwargs] dilations = [kwargs["dilate"] for kwargs in layers.values() if "dilate" in kwargs]
max_dilation = max(dilations) if len(dilations) > 0 else 0 max_dilation = max(dilations) if len(dilations) > 0 else 0
#################### ####################
@@ -164,20 +177,20 @@ def plot(
# Define base kwargs # Define base kwargs
if radius: if radius:
base_kwargs = { base_kwargs = {
'point': query if query_mode == 'coordinates' else ox.geocode(query), "point": query if query_mode == "coordinates" else ox.geocode(query),
'radius': radius "radius": radius,
} }
else: else:
base_kwargs = { base_kwargs = {
'perimeter': query if query_mode == 'polygon' else get_perimeter(query, by_osmid = query_mode == 'osmid') "perimeter": query
if query_mode == "polygon"
else get_perimeter(query, by_osmid=query_mode == "osmid")
} }
# Fetch layers # Fetch layers
layers = { layers = {
layer: get_layer( layer: get_layer(
layer, layer, **base_kwargs, **(kwargs if type(kwargs) == dict else {})
**base_kwargs,
**(kwargs if type(kwargs) == dict else {})
) )
for layer, kwargs in layers.items() for layer, kwargs in layers.items()
} }
@@ -196,22 +209,22 @@ def plot(
# Matplot-specific stuff (only run if vsketch mode isn't activated) # Matplot-specific stuff (only run if vsketch mode isn't activated)
if vsketch is None: if vsketch is None:
# Ajust axis # Ajust axis
ax.axis('off') ax.axis("off")
ax.axis('equal') ax.axis("equal")
ax.autoscale() ax.autoscale()
# Plot background # Plot background
if 'background' in drawing_kwargs: if "background" in drawing_kwargs:
geom = scale(box(*layers['perimeter'].bounds), 2, 2) geom = scale(box(*layers["perimeter"].bounds), 2, 2)
if vsketch is None: if vsketch is None:
ax.add_patch(PolygonPatch(geom, **drawing_kwargs['background'])) ax.add_patch(PolygonPatch(geom, **drawing_kwargs["background"]))
else: else:
vsketch.geometry(geom) vsketch.geometry(geom)
# Adjust bounds # Adjust bounds
xmin, ymin, xmax, ymax = layers['perimeter'].buffer(max_dilation).bounds xmin, ymin, xmax, ymax = layers["perimeter"].buffer(max_dilation).bounds
dx, dy = xmax-xmin, ymax-ymin dx, dy = xmax - xmin, ymax - ymin
if vsketch is None: if vsketch is None:
ax.set_xlim(xmin, xmax) ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax) ax.set_ylim(ymin, ymax)
@@ -219,27 +232,58 @@ def plot(
# Draw layers # Draw layers
for layer, shapes in layers.items(): for layer, shapes in layers.items():
kwargs = drawing_kwargs[layer] if layer in drawing_kwargs else {} kwargs = drawing_kwargs[layer] if layer in drawing_kwargs else {}
if 'hatch_c' in kwargs: if "hatch_c" in kwargs:
# Draw hatched shape # Draw hatched shape
plot_shapes(shapes, ax, vsketch = vsketch, lw = 0, ec = kwargs['hatch_c'], **{k:v for k,v in kwargs.items() if k not in ['lw', 'ec', 'hatch_c']}) plot_shapes(
shapes,
ax,
vsketch=vsketch,
lw=0,
ec=kwargs["hatch_c"],
**{k: v for k, v in kwargs.items() if k not in ["lw", "ec", "hatch_c"]},
)
# Draw shape contour only # Draw shape contour only
plot_shapes(shapes, ax, vsketch = vsketch, fill = False, **{k:v for k,v in kwargs.items() if k not in ['hatch_c', 'hatch', 'fill']}) plot_shapes(
shapes,
ax,
vsketch=vsketch,
fill=False,
**{
k: v
for k, v in kwargs.items()
if k not in ["hatch_c", "hatch", "fill"]
},
)
else: else:
# Draw shape normally # Draw shape normally
plot_shapes(shapes, ax, vsketch = vsketch, **kwargs) plot_shapes(shapes, ax, vsketch=vsketch, **kwargs)
if ((isinstance(osm_credit, dict)) or (osm_credit is True)) and (vsketch is None): if ((isinstance(osm_credit, dict)) or (osm_credit is True)) and (vsketch is None):
x, y = figsize x, y = figsize
d = .8*(x**2+y**2)**.5 d = 0.8 * (x ** 2 + y ** 2) ** 0.5
draw_text( draw_text(
ax, ax,
(osm_credit['text'] if 'text' in osm_credit else 'data © OpenStreetMap contributors\ngithub.com/marceloprates/prettymaps'), (
x = xmin + (osm_credit['x']*dx if 'x' in osm_credit else 0), osm_credit["text"]
y = ymax - 4*d - (osm_credit['y']*dy if 'y' in osm_credit else 0), if "text" in osm_credit
fontfamily = (osm_credit['fontfamily'] if 'fontfamily' in osm_credit else 'Ubuntu Mono'), else "data © OpenStreetMap contributors\ngithub.com/marceloprates/prettymaps"
fontsize = (osm_credit['fontsize']*d if 'fontsize' in osm_credit else d), ),
zorder = (osm_credit['zorder'] if 'zorder' in osm_credit else len(layers)+1), x=xmin + (osm_credit["x"] * dx if "x" in osm_credit else 0),
**{k:v for k,v in osm_credit.items() if k not in ['text', 'x', 'y', 'fontfamily', 'fontsize', 'zorder']} y=ymax - 4 * d - (osm_credit["y"] * dy if "y" in osm_credit else 0),
fontfamily=(
osm_credit["fontfamily"]
if "fontfamily" in osm_credit
else "Ubuntu Mono"
),
fontsize=(osm_credit["fontsize"] * d if "fontsize" in osm_credit else d),
zorder=(
osm_credit["zorder"] if "zorder" in osm_credit else len(layers) + 1
),
**{
k: v
for k, v in osm_credit.items()
if k not in ["text", "x", "y", "fontfamily", "fontsize", "zorder"]
},
) )
# Return perimeter # Return perimeter

View File

@@ -8,39 +8,67 @@ from geopandas import GeoDataFrame
# Compute circular or square boundary given point, radius and crs # Compute circular or square boundary given point, radius and crs
def get_boundary(point, radius, crs, circle = True, dilate = 0): def get_boundary(point, radius, crs, circle=True, dilate=0):
if circle: if circle:
return ox.project_gdf( return (
GeoDataFrame(geometry = [Point(point[::-1])], crs = crs) ox.project_gdf(GeoDataFrame(geometry=[Point(point[::-1])], crs=crs))
).geometry[0].buffer(radius) .geometry[0]
.buffer(radius)
)
else: else:
x, y = np.stack(ox.project_gdf( x, y = np.stack(
GeoDataFrame(geometry = [Point(point[::-1])], crs = crs) ox.project_gdf(GeoDataFrame(geometry=[Point(point[::-1])], crs=crs))
).geometry[0].xy) .geometry[0]
.xy
)
r = radius r = radius
return Polygon([ return Polygon(
(x-r, y-r), (x+r, y-r), (x+r, y+r), (x-r, y+r) [(x - r, y - r), (x + r, y - r), (x + r, y + r), (x - r, y + r)]
]).buffer(dilate) ).buffer(dilate)
# Get perimeter # Get perimeter
def get_perimeter(query, by_osmid = False, **kwargs): def get_perimeter(query, by_osmid=False, **kwargs):
return ox.geocode_to_gdf(query, by_osmid = by_osmid, **kwargs, **{x: kwargs[x] for x in ['circle', 'dilate'] if x in kwargs.keys()}) return ox.geocode_to_gdf(
query,
by_osmid=by_osmid,
**kwargs,
**{x: kwargs[x] for x in ["circle", "dilate"] if x in kwargs.keys()}
)
# Get geometries # Get geometries
def get_geometries(perimeter = None, point = None, radius = None, tags = {}, perimeter_tolerance = 0, union = True, circle = True, dilate = 0): def get_geometries(
perimeter=None,
point=None,
radius=None,
tags={},
perimeter_tolerance=0,
union=True,
circle=True,
dilate=0,
):
if perimeter is not None: if perimeter is not None:
# Boundary defined by polygon (perimeter) # Boundary defined by polygon (perimeter)
geometries = ox.geometries_from_polygon( geometries = ox.geometries_from_polygon(
unary_union(perimeter.geometry).buffer(perimeter_tolerance) if perimeter_tolerance > 0 else unary_union(perimeter.geometry), unary_union(perimeter.geometry).buffer(perimeter_tolerance)
tags = {tags: True} if type(tags) == str else tags if perimeter_tolerance > 0
else unary_union(perimeter.geometry),
tags={tags: True} if type(tags) == str else tags,
) )
perimeter = unary_union(ox.project_gdf(perimeter).geometry) perimeter = unary_union(ox.project_gdf(perimeter).geometry)
elif (point is not None) and (radius is not None): elif (point is not None) and (radius is not None):
# Boundary defined by circle with radius 'radius' around point # Boundary defined by circle with radius 'radius' around point
geometries = ox.geometries_from_point(point, dist = radius+dilate, tags = {tags: True} if type(tags) == str else tags) geometries = ox.geometries_from_point(
perimeter = get_boundary(point, radius, geometries.crs, circle = circle, dilate = dilate) point,
dist=radius + dilate,
tags={tags: True} if type(tags) == str else tags,
)
perimeter = get_boundary(
point, radius, geometries.crs, circle=circle, dilate=dilate
)
# Project GDF # Project GDF
if len(geometries) > 0: if len(geometries) > 0:
@@ -50,82 +78,135 @@ def get_geometries(perimeter = None, point = None, radius = None, tags = {}, per
geometries = geometries.intersection(perimeter) geometries = geometries.intersection(perimeter)
if union: if union:
geometries = unary_union(reduce(lambda x,y: x+y, [ geometries = unary_union(
reduce(
lambda x, y: x + y,
[
[x] if type(x) == Polygon else list(x) [x] if type(x) == Polygon else list(x)
for x in geometries if type(x) in [Polygon, MultiPolygon] for x in geometries
], [])) if type(x) in [Polygon, MultiPolygon]
],
[],
)
)
else: else:
geometries = MultiPolygon(reduce(lambda x,y: x+y, [ geometries = MultiPolygon(
reduce(
lambda x, y: x + y,
[
[x] if type(x) == Polygon else list(x) [x] if type(x) == Polygon else list(x)
for x in geometries if type(x) in [Polygon, MultiPolygon] for x in geometries
], [])) if type(x) in [Polygon, MultiPolygon]
],
[],
)
)
return geometries return geometries
# Get streets
def get_streets(perimeter = None, point = None, radius = None, layer = 'streets', width = 6, custom_filter = None, buffer = 0, retain_all = False, circle = True, dilate = 0):
if layer == 'streets': # Get streets
layer = 'highway' def get_streets(
perimeter=None,
point=None,
radius=None,
layer="streets",
width=6,
custom_filter=None,
buffer=0,
retain_all=False,
circle=True,
dilate=0,
):
if layer == "streets":
layer = "highway"
# Boundary defined by polygon (perimeter) # Boundary defined by polygon (perimeter)
if perimeter is not None: if perimeter is not None:
# Fetch streets data, project & convert to GDF # Fetch streets data, project & convert to GDF
streets = ox.graph_from_polygon(unary_union(perimeter.geometry).buffer(buffer) if buffer > 0 else unary_union(perimeter.geometry), custom_filter = custom_filter) streets = ox.graph_from_polygon(
unary_union(perimeter.geometry).buffer(buffer)
if buffer > 0
else unary_union(perimeter.geometry),
custom_filter=custom_filter,
)
streets = ox.project_graph(streets) streets = ox.project_graph(streets)
streets = ox.graph_to_gdfs(streets, nodes = False) streets = ox.graph_to_gdfs(streets, nodes=False)
# Boundary defined by polygon (perimeter) # Boundary defined by polygon (perimeter)
elif (point is not None) and (radius is not None): elif (point is not None) and (radius is not None):
# Fetch streets data, save CRS & project # Fetch streets data, save CRS & project
streets = ox.graph_from_point(point, dist = radius+dilate+buffer, retain_all = retain_all, custom_filter = custom_filter) streets = ox.graph_from_point(
crs = ox.graph_to_gdfs(streets, nodes = False).crs point,
dist=radius + dilate + buffer,
retain_all=retain_all,
custom_filter=custom_filter,
)
crs = ox.graph_to_gdfs(streets, nodes=False).crs
streets = ox.project_graph(streets) streets = ox.project_graph(streets)
# Compute perimeter from point & CRS # Compute perimeter from point & CRS
perimeter = get_boundary(point, radius, crs, circle = circle, dilate = dilate) perimeter = get_boundary(point, radius, crs, circle=circle, dilate=dilate)
# Convert to GDF # Convert to GDF
streets = ox.graph_to_gdfs(streets, nodes = False) streets = ox.graph_to_gdfs(streets, nodes=False)
# Intersect with perimeter & filter empty elements # Intersect with perimeter & filter empty elements
streets.geometry = streets.geometry.intersection(perimeter) streets.geometry = streets.geometry.intersection(perimeter)
streets = streets[~streets.geometry.is_empty] streets = streets[~streets.geometry.is_empty]
if type(width) == dict: if type(width) == dict:
streets = unary_union([ streets = unary_union(
[
# Dilate streets of each highway type == 'highway' using width 'w' # Dilate streets of each highway type == 'highway' using width 'w'
MultiLineString( MultiLineString(
streets[(streets[layer] == highway) & (streets.geometry.type == 'LineString')].geometry.tolist() + streets[
list(reduce(lambda x, y: x+y, [ (streets[layer] == highway)
& (streets.geometry.type == "LineString")
].geometry.tolist()
+ list(
reduce(
lambda x, y: x + y,
[
list(lines) list(lines)
for lines in streets[(streets[layer] == highway) & (streets.geometry.type == 'MultiLineString')].geometry for lines in streets[
], [])) (streets[layer] == highway)
& (streets.geometry.type == "MultiLineString")
].geometry
],
[],
)
)
).buffer(w) ).buffer(w)
for highway, w in width.items() for highway, w in width.items()
]) ]
)
else: else:
# Dilate all streets by same amount 'width' # Dilate all streets by same amount 'width'
streets = MultiLineString(streets.geometry.tolist()).buffer(width) streets = MultiLineString(streets.geometry.tolist()).buffer(width)
return streets return streets
# Get any layer # Get any layer
def get_layer(layer, **kwargs): def get_layer(layer, **kwargs):
# Fetch perimeter # Fetch perimeter
if layer == 'perimeter': if layer == "perimeter":
# If perimeter is already provided: # If perimeter is already provided:
if 'perimeter' in kwargs: if "perimeter" in kwargs:
return unary_union(ox.project_gdf(kwargs['perimeter']).geometry) return unary_union(ox.project_gdf(kwargs["perimeter"]).geometry)
# If point and radius are provided: # If point and radius are provided:
elif 'point' in kwargs and 'radius' in kwargs: elif "point" in kwargs and "radius" in kwargs:
crs = "EPSG:4326" crs = "EPSG:4326"
perimeter = get_boundary( perimeter = get_boundary(
kwargs['point'], kwargs['radius'], crs, kwargs["point"],
**{x: kwargs[x] for x in ['circle', 'dilate'] if x in kwargs.keys()} kwargs["radius"],
crs,
**{x: kwargs[x] for x in ["circle", "dilate"] if x in kwargs.keys()}
) )
return perimeter return perimeter
else: else:
raise Exception("Either 'perimeter' or 'point' & 'radius' must be provided") raise Exception("Either 'perimeter' or 'point' & 'radius' must be provided")
# Fetch streets or railway # Fetch streets or railway
if layer in ['streets', 'railway', 'waterway']: if layer in ["streets", "railway", "waterway"]:
return get_streets(**kwargs, layer = layer) return get_streets(**kwargs, layer=layer)
# Fetch geometries # Fetch geometries
else: else:
return get_geometries(**kwargs) return get_geometries(**kwargs)

View File

@@ -4,18 +4,18 @@ from pathlib import Path
parent_dir = Path(__file__).resolve().parent parent_dir = Path(__file__).resolve().parent
setup( setup(
name='prettymaps', name="prettymaps",
version='1.0.0', version="1.0.0",
description='A simple python library to draw pretty maps from OpenStreetMap data', description="A simple python library to draw pretty maps from OpenStreetMap data",
long_description=parent_dir.joinpath("README.md").read_text(), long_description=parent_dir.joinpath("README.md").read_text(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url='https://github.com/marceloprates/prettymaps', url="https://github.com/marceloprates/prettymaps",
author='Marcelo Prates', author="Marcelo Prates",
author_email='marceloorp@gmail.com', author_email="marceloorp@gmail.com",
license='MIT License', license="MIT License",
packages=find_packages(exclude=("assets", "notebooks", "prints", "script")), packages=find_packages(exclude=("assets", "notebooks", "prints", "script")),
install_requires=parent_dir.joinpath("requirements.txt").read_text().splitlines(), install_requires=parent_dir.joinpath("requirements.txt").read_text().splitlines(),
classifiers=[ classifiers=[
'Intended Audience :: Science/Research', "Intended Audience :: Science/Research",
], ],
) )