[Numpy-discussion] Distance Matrix speed

Alan G Isaac aisaac at american.edu
Mon Jun 19 00:30:12 EDT 2006


On Sun, 18 Jun 2006, Tim Hochberg apparently wrote:

Alan G Isaac wrote:

On Sun, 18 Jun 2006, Sebastian Beca apparently wrote:

def dist():
    d = zeros([N, C], dtype=float)
    if N < C:
        for i in range(N):
            xy = A[i] - B
            d[i,:] = sqrt(sum(xy**2, axis=1))
        return d
    else:
        for j in range(C):
            xy = A - B[j]
            d[:,j] = sqrt(sum(xy**2, axis=1))
        return d

But that is 50% slower than Johannes's version:

def distloehner1():
    d = A[:, newaxis, :] - B[newaxis, :, :]
    d = sqrt((d**2).sum(axis=2))
    return d

Are you sure about that? I just ran it through timeit, using Sebastian's array sizes, and I get Sebastian's version being 150% faster. This could well be cache-size dependent, so it may vary from box to box, but I'd expect Sebastian's current version to scale better in general.
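One way to see why cache size could matter here is to compare the temporaries the two versions allocate. The following is only a rough sketch, assuming the array sizes from the script in this thread and float64 data; the `np.` prefix, the seeded generator, and the variable names are choices made for this illustration, not part of the original code. It says nothing about which version is actually faster on a given machine.

```python
import numpy as np

# Sizes taken from the script in this thread.
K, C, N = 10, 2500, 3

# Seeded generator is an assumption for reproducibility only.
rng = np.random.default_rng(0)
A = rng.random((N, K))
B = rng.random((C, K))

# Broadcasting version: a single temporary holding every pairwise difference.
full_diff = A[:, np.newaxis, :] - B[np.newaxis, :, :]

# Loop version (N < C branch): one smaller temporary per row of A.
row_diff = A[0] - B

print("broadcast temporary:", full_diff.shape, full_diff.nbytes, "bytes")
print("per-row temporary:  ", row_diff.shape, row_diff.nbytes, "bytes")
```

With these sizes the broadcast temporary is N times larger than the per-row one, and that gap grows with N, which is consistent with the expectation that the loop version scales better.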

No, I'm not sure. Script attached at the bottom. Most recent output follows; for reasons I have not determined, it doesn't match my previous runs ...
Alan

execfile(r'c:\temp\temp.py')
dist_beca : 3.042277
dist_loehner1: 3.170026

#################################
#THE SCRIPT
import sys
sys.path.append(r"c:\temp")  # raw string, so that \t is not read as a tab
import numpy
from numpy import *
import timeit

K = 10
C = 2500
N = 3 # One could switch around C and N now.
A = numpy.random.random( [N, K] )
B = numpy.random.random( [C, K] )

#beca
def dist_beca():
    d = zeros([N, C], dtype=float)
    if N < C:
        for i in range(N):
            xy = A[i] - B
            d[i,:] = sqrt(sum(xy**2, axis=1))
        return d
    else:
        for j in range(C):
            xy = A - B[j]
            d[:,j] = sqrt(sum(xy**2, axis=1))
        return d

#loehnert
def dist_loehner1():
    # drawback: memory usage temporarily doubled
    # solution see below
    d = A[:, newaxis, :] - B[newaxis, :, :]
    # written as 3 expressions for more clarity
    d = sqrt((d**2).sum(axis=2))
    return d

if __name__ == "__main__":
    t1 = timeit.Timer('dist_beca()', 'from temp import dist_beca').timeit(100)
    t8 = timeit.Timer('dist_loehner1()', 'from temp import dist_loehner1').timeit(100)
    fmt="%-10s:\t"+"%10.6f"
    print fmt%('dist_beca', t1)
    print fmt%('dist_loehner1', t8)
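The script above is written for Python 2 (print statements, execfile). For anyone wanting to repeat the comparison today, here is a possible Python 3 port. It is a sketch rather than the original script: it uses explicit `np.` names instead of `from numpy import *`, passes the functions to `timeit.timeit` directly instead of importing them from a `temp` module, and adds an `np.allclose` check that the two versions return the same matrix. The seeded generator and the `timings` dict are assumptions made for this example. Timings will differ from machine to machine, so no expected numbers are given.

```python
import timeit

import numpy as np

# Sizes taken from the original script; one could switch around C and N.
K, C, N = 10, 2500, 3

# Seeded generator is an assumption for reproducibility only.
rng = np.random.default_rng(0)
A = rng.random((N, K))
B = rng.random((C, K))


def dist_beca():
    """Loop over the shorter axis, one row (or column) of d at a time."""
    d = np.zeros((N, C), dtype=float)
    if N < C:
        for i in range(N):
            xy = A[i] - B
            d[i, :] = np.sqrt(np.sum(xy**2, axis=1))
    else:
        for j in range(C):
            xy = A - B[j]
            d[:, j] = np.sqrt(np.sum(xy**2, axis=1))
    return d


def dist_loehner1():
    """Broadcast to an (N, C, K) array of differences, then reduce over K."""
    d = A[:, np.newaxis, :] - B[np.newaxis, :, :]
    return np.sqrt((d**2).sum(axis=2))


# Sanity check: both versions should compute the same distance matrix.
same = np.allclose(dist_beca(), dist_loehner1())
print("results agree:", same)

# Time 100 calls of each function, as in the original script.
timings = {
    "dist_beca": timeit.timeit(dist_beca, number=100),
    "dist_loehner1": timeit.timeit(dist_loehner1, number=100),
}
for name, seconds in timings.items():
    print("%-14s: %10.6f" % (name, seconds))
```

Swapping the values of C and N exercises the other branch of `dist_beca`, which is worth timing separately since the loop then runs over a different axis.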


