matplotlib savefig performance, saving multiple pngs within loop

The bottleneck in this case is the call to plt.savefig(), so there is not much room for optimization. Internally the figure is rendered from scratch on every call, and that takes a while. Reducing the number of vertices to be drawn might reduce the time a bit.
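As a rough illustration of reducing the number of vertices, one could thin out the data before handing it to matplotlib. This is only a sketch: `thin_points` is a hypothetical helper (not part of matplotlib or numpy), it simply keeps every k-th sample, and whether it helps noticeably depends on how many points your plot really contains.

```python
import numpy as np

def thin_points(x, y, max_points=1000):
    """Hypothetical helper: keep at most about max_points samples
    by taking every k-th point of x and y."""
    x = np.asarray(x)
    y = np.asarray(y)
    # stride chosen so that the result has no more than max_points entries
    step = max(1, int(np.ceil(len(x) / max_points)))
    return x[::step], y[::step]

x = np.linspace(0, 10, 100000)
y = np.sin(x)
xs, ys = thin_points(x, y, max_points=1000)
print(len(x), "->", len(xs))
```

Plain striding can skip narrow peaks, so for spiky data a min/max reduction per chunk may be a better choice.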

Your code runs in 2.5 seconds on my machine (Win 8, i5 with 4 cores at 3.5 GHz), which does not seem too bad. One can get some improvement by using multiprocessing.
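To get a serial baseline like the one quoted above on your own machine, a minimal timing sketch could look as follows. It assumes the Agg backend (so it also runs without a display) and saves into an in-memory buffer instead of files; the figure content and the number of frames are placeholders, not your actual data.

```python
import io
import time

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax = fig.add_axes([0., 0., 1., 1.])
ax.imshow(np.random.rand(25, 25), interpolation="nearest")

timings = []
for i in range(5):
    ax.set_xlim(0, 5 + i)          # change the view between frames
    buf = io.BytesIO()             # in-memory target, no files left behind
    t0 = time.perf_counter()
    fig.savefig(buf, format="png")
    timings.append(time.perf_counter() - t0)

plt.close(fig)
print("mean time per savefig: {:.3f} s".format(sum(timings) / len(timings)))
```

If the mean time per call multiplied by the number of frames is close to the total run time, the loop is indeed dominated by savefig.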

A note about multiprocessing: it may seem surprising that using the pyplot state machine inside multiprocessing works at all. But it does, because each worker process has its own copy of the pyplot state.
And in this case, since every image is based on the same figure and axes object, one does not even have to create new figures and axes.

I adapted an answer I gave here a while ago to your case; the total time is roughly halved using multiprocessing with 5 processes on 4 cores. I appended a bar plot that shows the effect of multiprocessing.

import multiprocessing
import os
import time

import numpy as np
#import matplotlib as mpl
#mpl.use('agg') # use of agg seems to slow things down a bit
import matplotlib.pyplot as plt


def make_plot(d):
    # time.clock() was removed in Python 3.8; time.time() is wall-clock
    # time, so timestamps from different worker processes are comparable
    start = time.time()
    x, y = d
    #using aspect in this way causes a warning for me
    #aspect = np.random.random(1)+y/2.0-x
    xrand = np.random.random(2)*x
    xlim = [min(xrand), max(xrand)]
    yrand = np.random.random(2)*y
    ylim = [min(yrand), max(yrand)]
    filename = "{:d}_{:d}.png".format(x, y)
    # reuse the figure and axes that already exist in this process
    ax = plt.gca()
    #ax.set_aspect(abs(aspect[0]))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    plt.savefig(os.path.join('figs', filename))
    stop = time.time()
    return np.array([x, y, start, stop])


# Module-level setup: with the "spawn" start method (e.g. on Windows) every
# worker re-imports this file and builds its own copy of the figure;
# with "fork" the workers inherit the figure from the parent.
if not os.path.isdir('figs'):
    os.makedirs('figs')
data = np.random.rand(25, 25)

fig = plt.figure()
ax = fig.add_axes([0., 0., 1., 1.])
ax.imshow(data, interpolation='nearest')

# all (i, j) combinations, one per image
some_list = [(i, j) for i in range(1, 8) for j in range(3, 13)]


if __name__ == "__main__":
    multiprocessing.freeze_support()
    tstart = time.time()
    num_proc = 5
    p = multiprocessing.Pool(num_proc)

    nu = p.map(make_plot, some_list)
    p.close()
    p.join()

    tooktime = "Plotting of {} frames took {:.2f} seconds"
    tooktime = tooktime.format(len(some_list), time.time()-tstart)
    print(tooktime)
    nu = np.array(nu)
    # express start/stop times relative to the start of the run
    nu[:, 2:] -= tstart

    plt.close("all")
    fig, ax = plt.subplots(figsize=(8,5))
    plt.suptitle(tooktime)
    # one horizontal bar per image, from its start time to its stop time
    ax.barh(np.arange(len(some_list)), nu[:,3]-nu[:,2],
            height=np.ones(len(some_list)), left=nu[:,2], align="center")
    ax.set_xlabel("time [s]")
    ax.set_ylabel("image number")
    ax.set_ylim([-1,70])
    plt.tight_layout()
    plt.savefig(__file__+".png")
    plt.show()

[Figure: bar plot of the start and stop time of each saved image; x-axis "time [s]", y-axis "image number"]
