How to Remove Excess Top and Bottom space in 3D matplotlib.animation Figure? (original) (raw)
Hi everyone,
I’m currently working on a 3D animation using matplotlib.animation.FuncAnimation
to visualize the tautochrone problem. In this animation, multiple balls slide down their own cycloidal paths within a transparent cuboid box. While the animation logic is functioning well, I am facing an issue with excess vertical space (padding) in the figure. Specifically, there is too much space at the top and bottom of the rendered animation, making the content appear vertically squished and surrounded by empty areas.
Here’s what I’ve tried so far:
- Adjusting the
figsize
(various combinations of width and height). - Using
fig.tight_layout()
right after creating the figure. - Manually adjusting subplot margins using
fig.subplots_adjust()
, e.g.:
self.fig.subplots_adjust(left=0.05, right=0.95, top=0.85, bottom=0.05)
Despite these, the issue persists. The Axes3D
object seems to leave persistent top and bottom padding that tight_layout()
doesn’t manage effectively. I’ve seen this happen especially with 3D plots where the axis limits and aspect ratios are heavily customized.
What I’m looking for:
- A reliable way to reduce or eliminate the vertical (top/bottom) white/empty space around my 3D plot in the final figure or animation.
- Best practices for making
tight_layout()
orsubplots_adjust()
work properly with 3D axes.
I’m attaching my full code and the animation. Any tips or workarounds would be highly appreciated!
class TautochroneAnimator3D:
"""
Animates the tautochrone problem in 3D, showing particles sliding down
cycloidal paths from different starting points with one ball per path.
"""
def __init__(self,
initial_angles_deg,
r=1.0,
ball_radius=0.1,
g=9.81,
fps=30,
save_anim=False,
filename=None,
cuboid_padding=0.2):
"""
Initializes the 3D Tautochrone Animator with one ball per path.
Parameters:
-----------
initial_angles_deg : list[float]
Starting angles in degrees for each ball. One ball per path.
Angle 0 is the left cusp, 180 is the lowest point, 360 is the right cusp.
r : float
Radius of the generating circle for the cycloid (default: 1.0).
ball_radius : float
Radius of the ball in data units (default: 0.1).
g : float
Acceleration due to gravity (default: 9.81). Affects the time scale.
fps : int
Frames per second for the animation (default: 30).
save_anim : bool
Whether to save the animation to a file (default: False).
filename : str, optional
Name of the file to save the animation. If None and save_anim is True,
a default name is generated.
cuboid_padding : float
Padding factor for cuboid dimensions (default: 0.2).
"""
self.r = r
self.g = g
self.fps = fps
self.save_anim = save_anim
self.filename = filename
self.ball_radius = ball_radius
self.cuboid_padding = cuboid_padding
# Always expect a list of angles
if isinstance(initial_angles_deg, (int, float)):
self.initial_angles_deg = [float(initial_angles_deg)]
elif isinstance(initial_angles_deg, list):
self.initial_angles_deg = [float(angle) for angle in initial_angles_deg]
else:
raise TypeError("initial_angles_deg must be a number or a list of numbers.")
# Add stationary ball at 180 degrees if not already present
if 180.0 not in self.initial_angles_deg:
self.initial_angles_deg.append(180.0)
self.has_stationary_ball = True
else:
self.has_stationary_ball = True
# Validate angles
if not all(0 <= angle <= 360 for angle in self.initial_angles_deg):
raise ValueError("Initial angles must be between 0 and 360 degrees.")
self.theta0_list = [np.radians(angle) for angle in self.initial_angles_deg]
self.num_paths = len(self.initial_angles_deg)
self.num_balls = self.num_paths
# Calculate physics constants and timing
self.omega = np.sqrt(self.g / (4 * self.r))
self.time_to_bottom = np.pi / (2 * self.omega) # Time = pi * sqrt(r/g)
# Set animation duration
self.duration = self.time_to_bottom * 1.2
self.num_frames = int(self.duration * self.fps)
# Animation elements
self.fig = None
self.ax = None
self.balls = [] # List to hold ball Scatter plots
self.ball_colors = []
self.paths = [] # List to hold cycloid path plots
# Index of the stationary ball (180 degrees)
self.stationary_idx = self.initial_angles_deg.index(180.0)
def _cycloid_path(self, theta):
"""Parametric equation for the inverted cycloid path in x-z plane."""
x = self.r * (theta - np.sin(theta))
z = -self.r * (1 - np.cos(theta))
return x, z
def _get_theta_at_time(self, theta0, t, is_stationary=False):
"""Calculates the angular position theta(t) of a ball.
The stationary ball always stays at pi (180 degrees)."""
if is_stationary:
return np.pi # Always at the bottom
if t >= self.time_to_bottom:
return np.pi # Reached the bottom
elif t <= 0:
return theta0 # Initial position
else:
omega = self.omega
cos_val = np.cos(omega * t)
if theta0 < np.pi:
# Ball rolls right from theta0 < pi to pi
arg = np.cos(theta0 / 2.0) * cos_val
arg = np.clip(arg, -1.0, 1.0)
return 2.0 * np.arccos(arg)
else:
# Ball rolls left from theta0 > pi to pi
k = -np.cos(theta0 / 2.0)
arg = k * cos_val
arg = np.clip(arg, -1.0, 1.0)
return 2.0 * np.pi - 2.0 * np.arccos(arg)
def _get_ball_position(self, theta, y_offset):
"""
Calculates the center position for the ball in 3D so it rests ON the cycloid curve.
"""
x_curve, z_curve = self._cycloid_path(theta)
# Calculate normal vector in the x-z plane
if np.isclose(theta, 0) or np.isclose(theta, 2 * np.pi): # At cusps
nx = 1 if theta < np.pi else -1
nz = 0
elif np.isclose(theta, np.pi): # At bottom point
nx = 0
nz = 1
else:
# Normal vector: (-dz/dtheta, dx/dtheta) = (r*sin(theta), r*(1 - cos(theta)))
nx = self.r * np.sin(theta)
nz = self.r * (1 - np.cos(theta))
norm = np.sqrt(nx**2 + nz**2)
if norm > 1e-9:
nx /= norm
nz /= norm
# Offset ball center
x_center = x_curve + self.ball_radius * nx
z_center = z_curve + self.ball_radius * nz
y_center = y_offset
return x_center, y_center, z_center
def _setup_plot(self):
"""Sets up the 3D matplotlib figure and axes."""
plt.style.use('dark_background')
self.fig = plt.figure(figsize=(10, 10))
self.fig.subplots_adjust(left=0.05, right=0.95, top=0.85, bottom=0.05)
self.fig.tight_layout()
self.ax = self.fig.add_subplot(111, projection='3d')
self.ax.set_title(f"3D Tautochrone Curves (Cycloids)\nTime to bottom: {self.time_to_bottom:.3f}s",
fontsize=16)
# Calculate cuboid dimensions
max_x = self.r * (2 * np.pi) # Cycloid length from theta=0 to 2pi
cuboid_length = max_x
cuboid_width = self.r * self.num_paths * (1 + self.cuboid_padding)
cuboid_height = 2 * self.r * (1 + self.cuboid_padding)
# Set plot limits
self.ax.set_xlim(0, cuboid_length)
self.ax.set_ylim(0, cuboid_width)
self.ax.set_zlim(-2 * self.r - self.r * self.cuboid_padding, self.r * self.cuboid_padding)
self.ax.set_box_aspect([cuboid_length, cuboid_width, cuboid_height])
# Remove gridlines and ticks
self.ax.axis('off')
for axis in ['x', 'y', 'z']:
getattr(self.ax, f'{axis}axis').set_pane_color((0,0,0,1))
self.ax.grid(False)
# Create cuboid
vertices = np.array([
[0, 0, -2 * self.r], # Adjusted to match cycloid z-level
[cuboid_length, 0, -2 * self.r],
[cuboid_length, cuboid_width, -2 * self.r],
[0, cuboid_width, -2 * self.r],
[0, 0, -2 * self.r + cuboid_height],
[cuboid_length, 0, -2 * self.r + cuboid_height],
[cuboid_length, cuboid_width, -2 * self.r + cuboid_height],
[0, cuboid_width, -2 * self.r + cuboid_height]
])
faces = [
[vertices[0], vertices[1], vertices[2], vertices[3]], # bottom
[vertices[4], vertices[5], vertices[6], vertices[7]], # top
[vertices[0], vertices[1], vertices[5], vertices[4]], # front
[vertices[2], vertices[3], vertices[7], vertices[6]], # back
[vertices[1], vertices[2], vertices[6], vertices[5]], # right
[vertices[0], vertices[3], vertices[7], vertices[4]] # left
]
cuboid = Poly3DCollection(faces, alpha=0.2, linewidth=1, edgecolor='white')
cuboid.set_facecolor('gray')
cuboid.set_zorder(1)
self.ax.add_collection3d(cuboid)
# Plot cycloid paths and initialize balls
theta_path = np.linspace(0, 2 * np.pi, 400)
x_path, z_path = self._cycloid_path(theta_path)
color_cycle = itertools.cycle(plt.cm.tab10.colors)
for i in range(self.num_paths):
y_offset = (i + 0.5) * (cuboid_width / self.num_paths)
self.ax.plot(x_path, [y_offset] * len(x_path), z_path, color='gray', lw=2.5, zorder=2)
# Initial angle for this path's ball
theta0 = self.theta0_list[i]
color = next(color_cycle)
self.ball_colors.append(color)
x0, y0, z0 = self._get_ball_position(theta0, y_offset)
if i == self.stationary_idx:
label = f"Stationary Ball (180°)"
else:
label = f"Ball {i+1}: {np.degrees(theta0):.0f}°"
# Plot ball as a scatter point
ball = self.ax.scatter([x0], [y0], [z0], c=color, s=250, zorder=5, label=label)
self.balls.append(ball)
# Add equation text
cycloid_eq = "Inverted Cycloid: <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>x</mi><mo>=</mo><mi>r</mi><mo stretchy="false">(</mo><mspace linebreak="newline"></mspace><mi>t</mi><mi>h</mi><mi>e</mi><mi>t</mi><mi>a</mi><mo>−</mo><mspace linebreak="newline"></mspace><mi>s</mi><mi>i</mi><mi>n</mi><mspace linebreak="newline"></mspace><mi>t</mi><mi>h</mi><mi>e</mi><mi>t</mi><mi>a</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">x = r(\\theta - \\sin\\theta)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mopen">(</span></span><span class="mspace newline"></span><span class="base"><span class="strut" style="height:0.7778em;vertical-align:-0.0833em;"></span><span class="mord mathnormal">t</span><span class="mord mathnormal">h</span><span class="mord mathnormal">e</span><span class="mord mathnormal">t</span><span class="mord mathnormal">a</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span></span><span class="mspace newline"></span><span class="base"><span class="strut" style="height:0.6595em;"></span><span class="mord mathnormal">s</span><span class="mord mathnormal">in</span></span><span class="mspace newline"></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">t</span><span class="mord mathnormal">h</span><span class="mord mathnormal">e</span><span class="mord mathnormal">t</span><span class="mord mathnormal">a</span><span class="mclose">)</span></span></span></span>, <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi><mo>=</mo><mo>−</mo><mi>r</mi><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><mspace linebreak="newline"></mspace><mi>c</mi><mi>o</mi><mi>s</mi><mspace linebreak="newline"></mspace><mi>t</mi><mi>h</mi><mi>e</mi><mi>t</mi><mi>a</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">z = -r(1 - \\cos\\theta)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal" style="margin-right:0.04398em;">z</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord">−</span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin">−</span></span><span class="mspace newline"></span><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">cos</span></span><span class="mspace newline"></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">t</span><span class="mord mathnormal">h</span><span class="mord mathnormal">e</span><span class="mord mathnormal">t</span><span class="mord mathnormal">a</span><span class="mclose">)</span></span></span></span> ; <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><mn>0</mn><mspace linebreak="newline"></mspace><mi>l</mi><mi>e</mi><mi>q</mi><mspace linebreak="newline"></mspace><mi>t</mi><mi>h</mi><mi>e</mi><mi>t</mi><mi>a</mi><mspace linebreak="newline"></mspace><mi>l</mi><mi>e</mi><mi>q</mi><mn>2</mn><mspace linebreak="newline"></mspace><mi>p</mi><mi>i</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">(0 \\leq \\theta \\leq 2\\pi)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord">0</span></span><span class="mspace newline"></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.01968em;">l</span><span class="mord mathnormal">e</span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span></span><span class="mspace newline"></span><span class="base"><span class="strut" style="height:0.6944em;"></span><span class="mord mathnormal">t</span><span class="mord mathnormal">h</span><span class="mord mathnormal">e</span><span class="mord mathnormal">t</span><span class="mord mathnormal">a</span></span><span class="mspace newline"></span><span class="base"><span class="strut" style="height:0.8889em;vertical-align:-0.1944em;"></span><span class="mord mathnormal" style="margin-right:0.01968em;">l</span><span class="mord mathnormal">e</span><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="mord">2</span></span><span class="mspace newline"></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mord mathnormal">i</span><span class="mclose">)</span></span></span></span>"
self.ax.text2D(0.5, 0.85, cycloid_eq, transform=self.ax.transAxes,
ha='center', va='top', fontsize=10, color='white',
bbox=dict(facecolor='black', alpha=0.8, edgecolor='gray', boxstyle='round,pad=0.3'))
self.ax.legend(bbox_to_anchor=(0.5, 0.2), loc='lower center', ncols=2, fontsize=8, markerscale=0.5)
# Set initial view angle
self.ax.view_init(elev=10, azim=280)
def _update(self, frame):
"""Updates the animation elements for each frame."""
t = frame / self.fps
artists_to_update = []
for i in range(self.num_paths):
y_offset = (i + 0.5) * (self.r * self.num_paths * (1 + self.cuboid_padding) / self.num_paths)
# Determine if this is the stationary ball
is_stationary = (i == self.stationary_idx)
theta0 = self.theta0_list[i]
# Get current position
current_theta = self._get_theta_at_time(theta0, t, is_stationary)
x, y, z = self._get_ball_position(current_theta, y_offset)
# Update ball position
self.balls[i]._offsets3d = ([x], [y], [z])
artists_to_update.append(self.balls[i])
return artists_to_update
def animate(self):
"""Creates and runs (or saves) the animation."""
self._setup_plot()
if self.fig is None:
raise ValueError("Figure was not properly initialized")
ani = FuncAnimation(self.fig, self._update,
frames=self.num_frames,
interval=int(1000 / self.fps),
blit=False,
repeat=False)
if self.save_anim:
if not self.filename:
angles_str = "_".join(map(str, self.initial_angles_deg))
self.filename = f"tautochrone_3d_{self.num_balls}balls_{angles_str}deg.gif"
save_dir = "ANIMATIONS/TAUTOCHRONE"
os.makedirs(save_dir, exist_ok=True)
filepath = os.path.join(save_dir, self.filename)
print(f"Saving animation to {os.path.abspath(filepath)}...")
try:
ani.save(filepath, writer='pillow', fps=self.fps)
print("Animation saved successfully!")
except Exception as e:
print(f"Error saving animation: {e}")
print("Make sure you have necessary writers installed (e.g., Pillow).")
print("Attempting to show animation instead.")
plt.show()
finally:
plt.close(self.fig)
else:
plt.show()
return ani
if __name__ == "__main__":
animator = TautochroneAnimator3D(
initial_angles_deg=[0, 90, 135, 180, 225, 270, 360],
r=10,
ball_radius=1.6,
g=9.8,
fps=50,
save_anim=True
)
anim = animator.animate()