"""
This module contains functions for plotting magnetic field data.
"""
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
from mtflib import mtf
from .solvers import calculate_b_field
def _get_3d_axes(ax=None):
"""
Helper to get a 3D matplotlib axis. Creates a new one if not provided.
Args:
ax (matplotlib.axes.Axes3D): The axis to use.
Returns:
matplotlib.axes.Axes3D: The 3D axis.
"""
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
return ax
def plot_1d_field(
    coil_instance,
    field_component: str,
    axis: str = "x",
    start_point: Optional[np.ndarray] = None,
    end_point: Optional[np.ndarray] = None,
    num_points: int = 100,
    plot_type: str = "line",
    log_scale: bool = False,
    ax=None,
    title: str = "",
    xlabel: str = "",
    ylabel: str = "",
    **kwargs,
):
    """
    Plots a vector field component along a 1D line.

    The line is either an axis-aligned segment through the coil's center
    point (when ``start_point`` and ``end_point`` are both None) or the
    straight segment from ``start_point`` to ``end_point``.

    Args:
        coil_instance: The coil object to calculate the field from. In
            auto-sized mode it must provide ``get_max_size()`` and
            ``get_center_point()``.
        field_component (str): Component to plot ('x', 'y', 'z', or 'norm').
        axis (str): Axis to plot along ('x', 'y', or 'z'). Only used when
            no start/end points are given; the segment is centered on the
            coil center along this axis and spans 1.25 times the largest
            entry of ``coil_instance.get_max_size()``.
        start_point (array-like, optional): Start of a custom line.
        end_point (array-like, optional): End of a custom line. Must be
            given together with ``start_point``.
        num_points (int): Number of sample points along the line.
        plot_type (str): 'line' or 'scatter'.
        log_scale (bool): Use a logarithmic y-axis.
        ax (matplotlib.axes.Axes, optional): Axis to draw on. When None a
            new figure is created and shown (unless the backend is Agg).
        title (str): Plot title; auto-generated when empty.
        xlabel (str): x-axis label; auto-generated when empty.
        ylabel (str): y-axis label; auto-generated when empty.
        **kwargs: Passed to ``ax.plot`` or ``ax.scatter``.

    Raises:
        ValueError: If ``field_component``, ``axis`` or ``plot_type`` is
            invalid, or only one of ``start_point``/``end_point`` is given.
    """
    if field_component not in ["x", "y", "z", "norm"]:
        raise ValueError("field_component must be 'x', 'y', 'z', or 'norm'.")
    # Validate up front so a typo fails before the (costly) field evaluation.
    if plot_type not in ["line", "scatter"]:
        raise ValueError("plot_type must be 'line' or 'scatter'.")
    if (start_point is None) != (end_point is None):
        raise ValueError(
            "start_point and end_point must both be provided or both be None."
        )

    if start_point is None:
        if axis not in ["x", "y", "z"]:
            raise ValueError("axis must be 'x', 'y', or 'z' if start_point is None.")
        axis_index = "xyz".index(axis)
        # Auto-size the plot based on coil dimensions
        max_size = np.max(coil_instance.get_max_size()).item()
        center = coil_instance.get_center_point()
        # Center the segment on the coil center along the *chosen* axis.
        min_val = center[axis_index] - 1.25 * max_size / 2
        max_val = center[axis_index] + 1.25 * max_size / 2
        line_points = np.linspace(min_val, max_val, num_points)
        # The other two coordinates stay fixed at the coil center.
        columns = [np.full(num_points, center[k]) for k in range(3)]
        columns[axis_index] = line_points
        field_points = np.vstack(columns).T
        plot_axis_label = f"{axis}-axis"
    else:
        start = np.asarray(start_point)
        end = np.asarray(end_point)
        # Parameter t in [0, 1] along the segment; also used as the x data.
        line_points = np.linspace(0, 1, num_points)
        field_points = start + np.outer(line_points, end - start)
        plot_axis_label = "line"

    # Calculate the B-field
    vector_field = calculate_b_field(coil_instance, field_points=field_points)

    # Extract the requested component
    if field_component == "norm":
        field_values = vector_field.get_magnitude()
    else:
        field_values = np.array(
            [getattr(vec, field_component) for vec in vector_field._vectors_mtf]
        )

    # MTF values are reduced to their coefficient at the all-zero multi-index.
    if len(field_values) > 0 and isinstance(field_values[0], mtf):
        field_values = np.array([
            val.extract_coefficient(tuple([0] * val.dimension)).item()
            for val in field_values
        ])

    # Explicitly cast to float/real to avoid ComplexWarning
    field_values = np.real(field_values).astype(float)
    line_points = np.real(line_points).astype(float)

    # Remember whether we own the figure: ``ax`` is rebound below.
    created_ax = ax is None
    if created_ax:
        _, ax = plt.subplots()

    plot_fn = ax.plot if plot_type == "line" else ax.scatter
    plot_fn(line_points, field_values, **kwargs)

    # Customize plot
    if log_scale:
        ax.set_yscale("log")
    if not title:
        title = f"Field component {field_component} along {plot_axis_label}"
    if not xlabel:
        xlabel = plot_axis_label
    if not ylabel:
        ylabel = f"Vector field component ({field_component})"
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True)

    # matplotlib reports the backend name in lowercase ("agg"); compare
    # case-insensitively so non-interactive runs do not warn.
    if created_ax and plt.get_backend().lower() != "agg":
        plt.show()
def plot_2d_field(
    coil_instance,
    field_component: str = "norm",
    plane: str = "xy",
    center: Optional[np.ndarray] = None,
    normal: Optional[np.ndarray] = None,
    size_a: Optional[float] = None,
    size_b: Optional[float] = None,
    num_points_a: int = 50,
    num_points_b: int = 50,
    plot_type: str = "heatmap",
    ax=None,
    title: str = "",
    offset_from_center: float = 0.0,
    **kwargs,
):
    """
    Plots a vector field on a 2D plane.

    Args:
        coil_instance: The coil object to calculate the field from.
        field_component (str): Component to plot ('x', 'y', 'z', or 'norm').
            Only used by the 'heatmap' plot type.
        plane (str): Plane to plot on ('xy', 'yz', or 'xz'). Ignored when
            ``normal`` is given.
        center (np.ndarray, optional): Center point of the plot. Defaults to
            coil center.
        normal (array-like, optional): Normal vector for a custom plane. It
            is normalized internally and must be non-zero.
        size_a (float, optional): Size along the first axis of the plane.
            Defaults to 1.25 times the largest coil dimension.
        size_b (float, optional): Size along the second axis of the plane.
            Defaults to 1.25 times the largest coil dimension.
        num_points_a (int): Grid points along first axis.
        num_points_b (int): Grid points along second axis.
        plot_type (str): Type of plot ('heatmap', 'quiver', 'streamline').
        ax (matplotlib.axes.Axes, optional): Matplotlib axis to plot on. When
            None a new figure is created and shown (unless the backend is
            Agg).
        title (str): Plot title.
        offset_from_center (float): Offset from the center along the normal.
        **kwargs: Additional arguments passed to the plotting function.

    Raises:
        ValueError: On an invalid ``field_component``, ``plot_type`` or
            ``plane`` (when ``normal`` is None), or a zero ``normal``.
    """
    if field_component not in ["x", "y", "z", "norm"]:
        raise ValueError("field_component must be 'x', 'y', 'z', or 'norm'.")
    if plot_type not in ["quiver", "streamline", "heatmap"]:
        raise ValueError("plot_type must be 'quiver', 'streamline', or 'heatmap'.")

    # For each axis-aligned plane: indices into (x, y, z) of the plot's
    # a-axis, b-axis, and the plane normal.
    plane_axes = {"xy": (0, 1, 2), "yz": (1, 2, 0), "xz": (0, 2, 1)}
    if normal is None:
        if plane not in plane_axes:
            raise ValueError(
                "plane must be 'xy', 'yz', or 'xz' if normal is not provided."
            )
    else:
        normal = np.asarray(normal, dtype=float)
        normal_length = np.linalg.norm(normal)
        if normal_length == 0:
            raise ValueError("normal must be a non-zero vector.")
        n = normal / normal_length

    # Determine the default center and plot size from the coil.
    if center is None:
        center = coil_instance.get_center_point()
    if size_a is None or size_b is None:
        max_size = np.max(coil_instance.get_max_size()).item()
        size_a = 1.25 * max_size if size_a is None else size_a
        size_b = 1.25 * max_size if size_b is None else size_b

    # Grid generation
    if normal is None:
        ia, ib, ic = plane_axes[plane]
        axis_labels = ("xyz"[ia], "xyz"[ib])
        a_coords = np.linspace(
            center[ia] - size_a / 2, center[ia] + size_a / 2, num_points_a
        )
        b_coords = np.linspace(
            center[ib] - size_b / 2, center[ib] + size_b / 2, num_points_b
        )
        A, B = np.meshgrid(a_coords, b_coords)
        # Ensure coordinates are real
        A = np.real(A)
        B = np.real(B)
        C = np.full(A.shape, center[ic]) + offset_from_center
        columns = [None, None, None]
        columns[ia], columns[ib], columns[ic] = A.ravel(), B.ravel(), C.ravel()
        field_points = np.vstack(columns).T
    else:
        # Build an orthonormal in-plane basis (u, v) perpendicular to n.
        if np.allclose(n, [0, 0, 1]) or np.allclose(n, [0, 0, -1]):
            u = np.array([1.0, 0.0, 0.0])
        else:
            u = np.cross(n, [0.0, 0.0, 1.0])
            u = u / np.linalg.norm(u)
        v = np.cross(n, u)
        a_coords = np.linspace(-size_a / 2, size_a / 2, num_points_a)
        b_coords = np.linspace(-size_b / 2, size_b / 2, num_points_b)
        A, B = np.meshgrid(a_coords, b_coords)
        # Row-major flattening of A/B, so the later reshape(A.shape)
        # restores the grid layout for any num_points_a/num_points_b.
        origin = np.asarray(center) + offset_from_center * n
        field_points = origin + np.outer(A.ravel(), u) + np.outer(B.ravel(), v)

    vector_field = calculate_b_field(coil_instance, field_points=field_points)
    b_vectors = np.array([b.to_numpy_array() for b in vector_field._vectors_mtf])

    # Remember whether we own the figure: ``ax`` is rebound below.
    created_ax = ax is None
    if created_ax:
        _, ax = plt.subplots()

    if plot_type in ("quiver", "streamline"):
        if normal is None:
            U, V = b_vectors[:, ia], b_vectors[:, ib]
        else:
            # Remove the normal component (using the unit normal), then
            # express the in-plane part in the (u, v) basis.
            projected_b = b_vectors - np.outer(np.dot(b_vectors, n), n)
            U, V = np.dot(projected_b, u), np.dot(projected_b, v)
        plot_fn = ax.quiver if plot_type == "quiver" else ax.streamplot
        # Explicitly cast to float to avoid ComplexWarning
        plot_fn(
            A,
            B,
            np.real(U).reshape(A.shape).astype(float),
            np.real(V).reshape(A.shape).astype(float),
            **kwargs,
        )
    else:  # heatmap
        if field_component == "norm":
            field_data = vector_field.get_magnitude()
        else:
            components = [
                getattr(b, field_component) for b in vector_field._vectors_mtf
            ]
            field_data = np.array([
                comp.extract_coefficient(tuple([0] * comp.dimension)).item()
                for comp in components
            ])
        # Explicitly cast to float/real
        field_data = np.real(field_data).astype(float)
        mesh = ax.pcolormesh(A, B, field_data.reshape(A.shape), **kwargs)
        # Attach the colorbar to ax's own figure, not whichever is current.
        ax.figure.colorbar(mesh, ax=ax)

    # Set titles and labels
    if not title:
        title = f"Field {field_component} on {plane}-plane"
    ax.set_title(title)
    if normal is None:
        ax.set_xlabel(f"{axis_labels[0]}-axis")
        ax.set_ylabel(f"{axis_labels[1]}-axis")
    else:
        ax.set_xlabel("a-axis")
        ax.set_ylabel("b-axis")
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True)

    # matplotlib reports the backend name in lowercase ("agg"); compare
    # case-insensitively so non-interactive runs (e.g. tests) do not warn.
    if created_ax and plt.get_backend().lower() != "agg":
        plt.show()
def plot_field_vectors_3d(
    coil_instance,
    num_points_a: int = 10,
    num_points_b: int = 10,
    num_points_c: int = 10,
    title: str = "",
    ax=None,
    **kwargs,
):
    """
    Generates a 3D quiver plot of the magnetic field vectors on a grid
    around the coil.

    The grid is centered on ``coil_instance.get_center_point()``. Along each
    axis k it spans 1.25 times ``coil_instance.get_max_size()[k]``. The
    field is evaluated at every grid point and drawn as arrows.

    Args:
        coil_instance (Coil): An instance of a Coil subclass.
        num_points_a (int): Number of grid points along the x-dimension.
        num_points_b (int): Number of grid points along the y-dimension.
        num_points_c (int): Number of grid points along the z-dimension.
        title (str, optional): The title of the plot. Defaults to an
            auto-generated title.
        ax (mpl_toolkits.mplot3d.axes3d.Axes3D, optional): The 3D axis to
            plot on. If None, a new figure is created and shown (unless the
            backend is Agg).
        **kwargs: Additional keyword arguments for the `ax.quiver` function.
    """
    # Get the coil's bounding box to create a reasonable grid
    max_size = coil_instance.get_max_size()
    center = coil_instance.get_center_point()
    counts = (num_points_a, num_points_b, num_points_c)
    x_range, y_range, z_range = (
        np.linspace(
            center[k] - 1.25 * max_size[k] / 2,
            center[k] + 1.25 * max_size[k] / 2,
            counts[k],
        )
        for k in range(3)
    )

    # Create the grid of field points (coordinates forced real)
    X, Y, Z = (np.real(grid) for grid in np.meshgrid(x_range, y_range, z_range))
    field_points = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T

    # Calculate the magnetic field at these points
    vector_field = calculate_b_field(coil_instance, field_points)
    b_vectors = np.array([b.to_numpy_array() for b in vector_field._vectors_mtf])

    # Reshape each component back onto the 3D meshgrid; explicit real/float
    # cast avoids ComplexWarning.
    U, V, W = (
        np.real(b_vectors[:, k]).reshape(X.shape).astype(float) for k in range(3)
    )

    # Get or create the 3D axes and plot the vectors
    plot_ax = _get_3d_axes(ax)
    plot_ax.quiver(X, Y, Z, U, V, W, **kwargs)

    # Set title and labels
    if not title:
        title = f"3D Magnetic Field Vectors from {coil_instance.__class__.__name__}"
    plot_ax.set_title(title)
    plot_ax.set_xlabel("X-axis")
    plot_ax.set_ylabel("Y-axis")
    plot_ax.set_zlabel("Z-axis")
    plot_ax.set_aspect("equal", "box")

    # matplotlib reports the backend name in lowercase ("agg"); compare
    # case-insensitively so non-interactive runs (e.g. tests) do not warn.
    if ax is None and plt.get_backend().lower() != "agg":
        plt.show()