How to Remove Excess Top and Bottom Space in a 3D matplotlib.animation Figure?

Hi everyone,

I’m currently working on a 3D animation using matplotlib.animation.FuncAnimation to visualize the tautochrone problem. In this animation, multiple balls slide down their own cycloidal paths within a transparent cuboid box. While the animation logic is functioning well, I am facing an issue with excess vertical space (padding) in the figure. Specifically, there is too much space at the top and bottom of the rendered animation, making the content appear vertically squished and surrounded by empty areas.

Here’s what I’ve tried so far:

self.fig.subplots_adjust(left=0.05, right=0.95, top=0.85, bottom=0.05)
self.fig.tight_layout()

Despite these, the issue persists. The Axes3D object seems to leave persistent top and bottom padding that tight_layout() doesn’t manage effectively. I’ve seen this happen especially with 3D plots where the axis limits and aspect ratios are heavily customized.

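What I’m looking for: a way to shrink or remove the empty space above and below the 3D axes, so that the cuboid and its paths fill more of the figure vertically instead of looking squished.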

I’m attaching my full code and the animation. Any tips or workarounds would be highly appreciated!

tautochrone_3d

class TautochroneAnimator3D:
    """
    Animates the tautochrone problem in 3D, showing particles sliding down
    cycloidal paths from different starting points with one ball per path.

    Each path is an inverted cycloid in the x-z plane, drawn in its own lane
    (y-offset) inside a translucent cuboid. A ball at 180 degrees (the lowest
    point of the cycloid) is always included as a stationary reference.
    """
    def __init__(self,
                 initial_angles_deg,
                 r=1.0,
                 ball_radius=0.1,
                 g=9.81,
                 fps=30,
                 save_anim=False,
                 filename=None,
                 cuboid_padding=0.2,
                 figsize=(10, 10),
                 zoom=1.0):
        """
        Initializes the 3D Tautochrone Animator with one ball per path.

        Parameters:
        -----------
        initial_angles_deg : float | list[float] | tuple[float, ...] | numpy.ndarray
            Starting angles in degrees for each ball. One ball per path.
            Angle 0 is the left cusp, 180 is the lowest point, 360 is the right cusp.
            A ball at 180 is appended if none is given.
        r : float
            Radius of the generating circle for the cycloid (default: 1.0).
        ball_radius : float
            Radius of the ball in data units (default: 0.1).
        g : float
            Acceleration due to gravity (default: 9.81). Affects the time scale.
        fps : int
            Frames per second for the animation (default: 30).
        save_anim : bool
            Whether to save the animation to a file (default: False).
        filename : str, optional
            Name of the file to save the animation. If None and save_anim is True,
            a default name is generated.
        cuboid_padding : float
            Padding factor for cuboid dimensions (default: 0.2).
        figsize : tuple[float, float]
            Figure size in inches (default: (10, 10)). The cuboid is much wider
            (2*pi*r) than it is tall (2*r*(1 + padding)), so a wide, short figure
            such as (12, 6) typically leaves less empty space above and below it.
        zoom : float
            Forwarded to ``Axes3D.set_box_aspect(..., zoom=zoom)`` to scale the 3D
            box inside its axes (default: 1.0, matplotlib's own default). Values
            greater than 1 enlarge the box and shrink the surrounding margin.

        Raises:
        -------
        TypeError
            If initial_angles_deg is not a real number or a list/tuple/ndarray.
        ValueError
            If any angle lies outside [0, 360] (NaN also fails this check).
        """
        self.r = r
        self.g = g
        self.fps = fps
        self.save_anim = save_anim
        self.filename = filename
        self.ball_radius = ball_radius
        self.cuboid_padding = cuboid_padding
        self.figsize = figsize
        self.zoom = zoom

        # Normalize to a fresh list of floats; the caller's sequence is never mutated.
        if isinstance(initial_angles_deg, (int, float, np.integer, np.floating)):
            self.initial_angles_deg = [float(initial_angles_deg)]
        elif isinstance(initial_angles_deg, (list, tuple, np.ndarray)):
            self.initial_angles_deg = [float(angle) for angle in initial_angles_deg]
        else:
            raise TypeError("initial_angles_deg must be a number or a list of numbers.")

        # Always include a reference ball resting at the lowest point (180 degrees).
        if 180.0 not in self.initial_angles_deg:
            self.initial_angles_deg.append(180.0)
        # After the line above a 180-degree ball always exists.
        self.has_stationary_ball = True

        # Validate angles
        if not all(0 <= angle <= 360 for angle in self.initial_angles_deg):
            raise ValueError("Initial angles must be between 0 and 360 degrees.")

        self.theta0_list = [np.radians(angle) for angle in self.initial_angles_deg]
        self.num_paths = len(self.initial_angles_deg)
        self.num_balls = self.num_paths

        # Motion along the cycloid is simple harmonic with omega = sqrt(g / (4 r)),
        # so every ball reaches the bottom after a quarter period: pi * sqrt(r / g).
        self.omega = np.sqrt(self.g / (4 * self.r))
        self.time_to_bottom = np.pi / (2 * self.omega)

        # Run a little past the arrival time so the final state stays visible.
        self.duration = self.time_to_bottom * 1.2
        self.num_frames = int(self.duration * self.fps)

        # Animation elements (populated by _setup_plot)
        self.fig = None
        self.ax = None
        self.balls = []  # One 3D scatter artist per ball
        self.ball_colors = []
        self.paths = []  # One Line3D per cycloid path

        # Index of the stationary ball (180 degrees)
        self.stationary_idx = self.initial_angles_deg.index(180.0)

    def _cuboid_width(self):
        """Return the cuboid's extent along y: one r-wide lane per path, padded."""
        return self.r * self.num_paths * (1 + self.cuboid_padding)

    def _path_y_offset(self, index):
        """Return the y coordinate of the centre of lane ``index``."""
        return (index + 0.5) * (self._cuboid_width() / self.num_paths)

    def _cycloid_path(self, theta):
        """Parametric equation for the inverted cycloid path in the x-z plane.

        Returns (x, z) with x = r(theta - sin theta), z = -r(1 - cos theta);
        works element-wise on numpy arrays.
        """
        x = self.r * (theta - np.sin(theta))
        z = -self.r * (1 - np.cos(theta))
        return x, z

    def _get_theta_at_time(self, theta0, t, is_stationary=False):
        """Return the cycloid parameter theta(t) of a ball released from rest at theta0.

        The arc length measured from the bottom is s = -4 r cos(theta/2), and the
        motion is simple harmonic, s(t) = s0 cos(omega t). Therefore
        cos(theta/2) = cos(theta0/2) cos(omega t). Since theta/2 lies in [0, pi]
        for theta in [0, 2 pi], a single arccos inverts this for balls on either
        side of the bottom.

        The stationary ball, and any ball once t >= time_to_bottom, is held at pi.
        For t <= 0 the initial angle theta0 is returned.
        """
        if is_stationary or t >= self.time_to_bottom:
            return np.pi
        if t <= 0:
            return theta0
        # Clip guards against round-off pushing the argument just outside [-1, 1].
        arg = np.clip(np.cos(theta0 / 2.0) * np.cos(self.omega * t), -1.0, 1.0)
        return 2.0 * np.arccos(arg)

    def _get_ball_position(self, theta, y_offset):
        """
        Return the (x, y, z) centre of a ball resting ON the cycloid at parameter theta.

        The centre is offset by ball_radius along the unit normal
        (sin theta, 1 - cos theta) / |.|, whose z component is never negative, so
        the ball sits on the upper side of the curve. The cusps (normal +/-x) and
        the bottom (normal +z) are handled explicitly because the formula's norm
        vanishes at the cusps.
        """
        x_curve, z_curve = self._cycloid_path(theta)

        # Calculate normal vector in the x-z plane
        if np.isclose(theta, 0) or np.isclose(theta, 2 * np.pi):  # At cusps
            nx = 1 if theta < np.pi else -1
            nz = 0
        elif np.isclose(theta, np.pi):  # At bottom point
            nx = 0
            nz = 1
        else:
            # Normal vector: (-dz/dtheta, dx/dtheta) = (r*sin(theta), r*(1 - cos(theta)))
            nx = self.r * np.sin(theta)
            nz = self.r * (1 - np.cos(theta))
            norm = np.sqrt(nx**2 + nz**2)
            if norm > 1e-9:
                nx /= norm
                nz /= norm

        x_center = x_curve + self.ball_radius * nx
        z_center = z_curve + self.ball_radius * nz
        y_center = y_offset

        return x_center, y_center, z_center

    def _setup_plot(self):
        """Create the figure and 3D axes, then draw the cuboid, the paths and the balls.

        Safe to call more than once: a previously created figure is closed and the
        per-figure artist lists are reset, so ``_update`` drives the new balls.
        """
        if self.fig is not None:
            plt.close(self.fig)
        self.balls = []
        self.ball_colors = []
        self.paths = []

        # NOTE: plt.style.use changes rcParams globally for the rest of the session.
        plt.style.use('dark_background')
        self.fig = plt.figure(figsize=self.figsize)
        self.fig.subplots_adjust(left=0.05, right=0.95, top=0.85, bottom=0.05)

        self.ax = self.fig.add_subplot(111, projection='3d')
        self.ax.set_title(f"3D Tautochrone Curves (Cycloids)\nTime to bottom: {self.time_to_bottom:.3f}s",
                          fontsize=16)

        # Cuboid dimensions
        cuboid_length = self.r * (2 * np.pi)  # x-extent of one cycloid arch (theta = 0 .. 2*pi)
        cuboid_width = self._cuboid_width()
        cuboid_height = 2 * self.r * (1 + self.cuboid_padding)
        z_floor = -2 * self.r  # lowest point of the cycloid
        z_top = z_floor + cuboid_height

        # Axis limits equal the drawn cuboid, so the box exactly fills the 3D axes
        # box defined by set_box_aspect (no part sticks out, no empty band below).
        self.ax.set_xlim(0, cuboid_length)
        self.ax.set_ylim(0, cuboid_width)
        self.ax.set_zlim(z_floor, z_top)
        self.ax.set_box_aspect([cuboid_length, cuboid_width, cuboid_height], zoom=self.zoom)

        # Remove gridlines and ticks
        self.ax.axis('off')
        for axis in ['x', 'y', 'z']:
            getattr(self.ax, f'{axis}axis').set_pane_color((0, 0, 0, 1))
        self.ax.grid(False)

        # Create cuboid (floor at the cycloid's lowest point)
        vertices = np.array([
            [0, 0, z_floor],
            [cuboid_length, 0, z_floor],
            [cuboid_length, cuboid_width, z_floor],
            [0, cuboid_width, z_floor],
            [0, 0, z_top],
            [cuboid_length, 0, z_top],
            [cuboid_length, cuboid_width, z_top],
            [0, cuboid_width, z_top]
        ])

        faces = [
            [vertices[0], vertices[1], vertices[2], vertices[3]],  # bottom
            [vertices[4], vertices[5], vertices[6], vertices[7]],  # top
            [vertices[0], vertices[1], vertices[5], vertices[4]],  # front
            [vertices[2], vertices[3], vertices[7], vertices[6]],  # back
            [vertices[1], vertices[2], vertices[6], vertices[5]],  # right
            [vertices[0], vertices[3], vertices[7], vertices[4]]   # left
        ]

        cuboid = Poly3DCollection(faces, alpha=0.2, linewidth=1, edgecolor='white')
        cuboid.set_facecolor('gray')
        cuboid.set_zorder(1)
        self.ax.add_collection3d(cuboid)

        # Plot cycloid paths and initialize balls
        theta_path = np.linspace(0, 2 * np.pi, 400)
        x_path, z_path = self._cycloid_path(theta_path)
        color_cycle = itertools.cycle(plt.cm.tab10.colors)

        for i, theta0 in enumerate(self.theta0_list):
            y_offset = self._path_y_offset(i)
            path_line, = self.ax.plot(x_path, [y_offset] * len(x_path), z_path,
                                      color='gray', lw=2.5, zorder=2)
            self.paths.append(path_line)

            color = next(color_cycle)
            self.ball_colors.append(color)

            x0, y0, z0 = self._get_ball_position(theta0, y_offset)

            if i == self.stationary_idx:
                label = "Stationary Ball (180°)"
            else:
                label = f"Ball {i+1}: {np.degrees(theta0):.0f}°"

            ball = self.ax.scatter([x0], [y0], [z0], c=color, s=250, zorder=5, label=label)
            self.balls.append(ball)

        # Equation text (matplotlib mathtext; raw strings keep the backslashes)
        cycloid_eq = (r"Inverted Cycloid: $x = r(\theta - \sin\theta)$, "
                      r"$z = -r(1 - \cos\theta)$ ; $(0 \leq \theta \leq 2\pi)$")
        self.ax.text2D(0.5, 0.85, cycloid_eq, transform=self.ax.transAxes,
                       ha='center', va='top', fontsize=10, color='white',
                       bbox=dict(facecolor='black', alpha=0.8, edgecolor='gray', boxstyle='round,pad=0.3'))

        self.ax.legend(bbox_to_anchor=(0.5, 0.2), loc='lower center', ncols=2, fontsize=8, markerscale=0.5)

        # Set initial view angle
        self.ax.view_init(elev=10, azim=280)

    def _update(self, frame):
        """Move every ball to its position at time ``frame / fps``; return the moved artists."""
        t = frame / self.fps
        artists_to_update = []

        for i, (theta0, ball) in enumerate(zip(self.theta0_list, self.balls)):
            current_theta = self._get_theta_at_time(theta0, t, i == self.stationary_idx)
            x, y, z = self._get_ball_position(current_theta, self._path_y_offset(i))
            # _offsets3d is private matplotlib API; it is the common way to move
            # points of an existing 3D scatter artist.
            ball._offsets3d = ([x], [y], [z])
            artists_to_update.append(ball)

        return artists_to_update

    def animate(self):
        """Create the animation and either save it as a GIF or show it.

        When save_anim is True the animation is written with the Pillow writer to
        ANIMATIONS/TAUTOCHRONE/<filename>. If saving fails, the figure is shown
        instead. Otherwise the figure is shown directly.

        Returns the FuncAnimation. Keep a reference to it: matplotlib stops an
        animation whose object has been garbage-collected.
        """
        self._setup_plot()
        if self.fig is None:
            raise ValueError("Figure was not properly initialized")
        ani = FuncAnimation(self.fig, self._update,
                            frames=self.num_frames,
                            interval=int(1000 / self.fps),
                            blit=False,
                            repeat=False)
        if self.save_anim:
            if not self.filename:
                angles_str = "_".join(map(str, self.initial_angles_deg))
                self.filename = f"tautochrone_3d_{self.num_balls}balls_{angles_str}deg.gif"
            save_dir = "ANIMATIONS/TAUTOCHRONE"
            os.makedirs(save_dir, exist_ok=True)
            filepath = os.path.join(save_dir, self.filename)
            print(f"Saving animation to {os.path.abspath(filepath)}...")
            try:
                ani.save(filepath, writer='pillow', fps=self.fps)
                print("Animation saved successfully!")
            except Exception as e:
                # Deliberate best-effort fallback: report and show on screen instead.
                print(f"Error saving animation: {e}")
                print("Make sure you have necessary writers installed (e.g., Pillow).")
                print("Attempting to show animation instead.")
                plt.show()
            finally:
                plt.close(self.fig)
        else:
            plt.show()
        return ani

if __name__ == "__main__":
    # Demo run: seven starting angles (including the stationary ball at 180°),
    # written to a GIF rather than shown on screen.
    demo_settings = dict(
        initial_angles_deg=[0, 90, 135, 180, 225, 270, 360],
        r=10,
        ball_radius=1.6,
        g=9.8,
        fps=50,
        save_anim=True,
    )
    animator = TautochroneAnimator3D(**demo_settings)
    # Keep the returned animation referenced for as long as it is needed.
    anim = animator.animate()