Here is an updated script that matches the structure of the original figure.
python
import numpy as np
import matplotlib.pyplot as plt


# ---- weight generators ------------------------------------------------------
def linear(n):
    """Return ``n`` weights decreasing linearly from 1 towards 0.

    The values are ``1 - i / n`` for ``i = 0..n-1``; the last weight is
    ``1 / n``, not exactly 0.
    """
    return 1 - np.arange(n) / n


def uniform(n):
    """Return ``n`` constant (flat) weights equal to 1."""
    return np.ones(n)


def slow_linear(n, k=0.5):
    """Return ``n`` linearly decreasing weights with a gentler slope.

    The values are ``1 - k * i / n`` for ``i = 0..n-1``; with the default
    ``k=0.5`` they fall from 1 towards 0.5 rather than towards 0.
    """
    return 1 - k * np.arange(n) / n


# ---- single-panel plot ------------------------------------------------------
def _bars(ax, w, color, steps, total):
    """Draw weights ``w`` as left-aligned bars of width ``1 / total``."""
    w = np.atleast_1d(np.asarray(w))
    if steps:
        # Keep the original behaviour for an explicit ``steps``: a scalar /
        # length-1 weight fills every step, any other length must match.
        if w.size not in (1, steps):
            raise ValueError(
                f'steps={steps} does not match weight length {w.size}')
        w = np.broadcast_to(w, (steps,))
    # Bar i sits at i / total, so everything right of len(w) / total stays
    # blank -- that is the "not yet trained" part of the panel.
    ax.bar(np.arange(len(w)) / total, w, width=1 / total,
           color=color, align='edge', edgecolor='none')


def draw(ax, w_x=None, w_y=None, steps=0, total=30):
    """Draw one panel: y weights in blue, x weights in orange on top.

    Parameters
    ----------
    ax : matplotlib ``Axes`` to draw into.
    w_x, w_y : weight arrays for the x and y sequences; ``None`` skips one.
    steps : number of bars. When 0 (the default) it is taken from the
        length of each weight array; when non-zero, each weight array must
        have that length or be a scalar/length-1 value broadcast to it.
    total : total number of training steps; sets bar width and x scale.

    Raises
    ------
    ValueError
        If ``steps`` is non-zero and a weight array has a different length.
    """
    # y is drawn first so the later x bars overlay it where they overlap.
    if w_y is not None:
        _bars(ax, w_y, 'tab:blue', steps, total)
    if w_x is not None:
        _bars(ax, w_x, 'tab:orange', steps, total)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['0 %', '100 %'])
    ax.set_yticks([0, .5, 1])
    ax.tick_params(axis='both', length=0)
    ax.spines[['top', 'right']].set_visible(False)


# ---- figure -----------------------------------------------------------------
T = 30
cols = [T // 3, 2*T // 3, T]   # steps completed in each column

fig, ax = plt.subplots(4, 3, figsize=(10, 10), constrained_layout=True)

col_lbl = ['1/3 through Training', '2/3 through Training', 'End of Training']
# ASCII '-' in 'Schedule-Free': the non-breaking hyphen (U+2011) is missing
# from matplotlib's default font and renders as a box with a warning.
row_lbl = ['Linear Decay Schedule',
           'Polyak Averaging\n$y$: constant   $x$: linear',
           'Primal Averaging\n$x,y$: linear',
           'Schedule-Free\n$y$: slow linear   $x$: linear']

for j, t in enumerate(col_lbl):
    ax[0, j].set_title(t, fontsize=10)
for i, lab in enumerate(row_lbl):
    ax[i, 0].set_ylabel(lab, rotation=0, ha='right',
                        va='center', fontsize=9, labelpad=70)

for j, n in enumerate(cols):
    # row 0: the learning-rate schedule itself
    draw(ax[0, j], w_y=linear(n), steps=n, total=T)
    # row 1
    draw(ax[1, j], w_x=linear(n), w_y=uniform(n), steps=n, total=T)
    # row 2: only x is drawn; per the label y uses the same linear weights
    draw(ax[2, j], w_x=linear(n), steps=n, total=T)
    # row 3
    draw(ax[3, j], w_x=linear(n),
         w_y=slow_linear(n, 0.5), steps=n, total=T)

for a in ax[-1, :]:
    a.set_xlabel('T')

plt.show()
- Blue bars: gradient-location sequence $y$
- Orange bars: evaluation sequence $x$

The panels now reproduce the piecewise-linear wedges (including the blank region after the current step count) and the colour/height pattern of the reference figure.