potter.py (original) (raw)

#!/usr/bin/env python

# -*- coding: utf-8 -*-

# Price multiplier for one bundle of distinct titles, indexed by how many
# distinct titles the bundle holds (see ``discount``, which reads
# ``DISCOUNTS[n]`` for a bundle of n different books).
DISCOUNTS = [
    0,     # 0 titles: empty bundle
    1,     # 1 title: full price
    0.95,  # 2 titles: 5% off
    0.90,  # 3 titles: 10% off
    0.80,  # 4 titles: 20% off
    0.75,  # 5 titles: 25% off
]

def remain(books, n, m):
    """
    Return a copy of ``books`` with ``m`` taken away from each of the
    first ``n`` entries.

    ``books`` is a sequence of per-title copy counts; it is not modified.
    The result is always a new list.  Raises IndexError if ``n`` is larger
    than ``len(books)``.
    """
    # list() (rather than slicing) also accepts tuples and other sequences.
    new_books = list(books)
    # range (not xrange) keeps this runnable on both Python 2 and Python 3.
    for i in range(n):
        new_books[i] -= m
    return new_books

def discount(n, m):
    """Return the total price of m bundles, each of n distinct books."""
    # Price of one bundle: n books at 8 each, scaled by the bundle multiplier.
    per_bundle = 8 * n * DISCOUNTS[n]
    return per_bundle * m

def potter(books):
    """
    Calculate the best (lowest) price of a collection of books.

    ``books`` lists how many copies of each title are bought; order does
    not matter and zero entries are ignored.  For example ``[3, 2, 2, 8]``
    means 3 copies of the first title, 2 of the second, 2 of the third and
    8 of the fourth.

    Intermediate results are memoised on the normalised (sorted, zero-free)
    counts for the duration of one call.  Each distinct state is priced
    only once; without this, the two-way branching used for 4+ titles grows
    exponentially with the number of copies.
    """
    memo = {}

    def best(counts):
        # Normalise: largest piles first, then drop exhausted titles.
        counts = sorted(counts, reverse=True)
        if 0 in counts:
            counts = counts[:counts.index(0)]

        key = tuple(counts)
        if key in memo:
            return memo[key]

        kinds = len(counts)
        if not counts:
            # No books: the price is 0.
            result = 0
        elif kinds == 1:
            # A single title gets no discount.
            result = 8 * counts[0]
        elif kinds in (2, 3):
            # Take as many full bundles of every title as the smallest pile
            # allows; no alternative split is tried for 2 or 3 titles.
            smallest = min(counts)
            result = (discount(kinds, smallest)
                      + best(remain(counts, kinds, smallest)))
        else:
            # 4 or 5 titles: try one bundle of each size from ``kinds`` down
            # to 4 (taken from the largest piles) and keep the cheapest.
            result = min([discount(n, 1) + best(remain(counts, n, 1))
                          for n in range(kinds, 3, -1)])

        memo[key] = result
        return result

    return best(books)

def price(set):
    """
    Return the best price for a basket of books given by title index.

    ``set`` is an iterable of title indices in ``range(5)``; each element is
    one copy of that title.  For example ``[0, 0, 3, 3, 3]`` means 2 copies
    of the first title and 3 copies of the fourth.  The per-title copy
    counts are then passed to ``potter`` to compute the price.

    Raises IndexError for an index outside ``0..4``.
    """
    # NOTE: the parameter name shadows the builtin ``set``; it is kept so
    # existing keyword callers (``price(set=...)``) keep working.
    books = [0] * 5
    for n in set:
        # Check the range explicitly: a negative index would otherwise wrap
        # around (``-1`` -> last title) and silently miscount the basket.
        if not 0 <= n < len(books):
            raise IndexError("book index out of range: %r" % (n,))
        books[n] += 1
    return potter(books)

def tests():
    """Run the kata's regression checks, printing each (actual, expected) pair."""

    def assert_equals(a, b):
        # Show both values so a failing case is easy to spot; the formatted
        # string prints identically on Python 2 and Python 3.
        print("%s %s" % (a, b))
        # Expected prices below are floats multiplied/summed in a different
        # order than ``potter`` computes them, so exact equality is fragile;
        # compare with a small tolerance instead.
        assert abs(a - b) < 1e-9, "%r != %r" % (a, b)

    assert_equals(price([]), 0)
    assert_equals(price([0]), 8)
    assert_equals(price([0, 1]), 8 * 2 * 0.95)
    assert_equals(price([0, 1, 2]), 8 * 3 * .9)
    assert_equals(price([0, 1, 2, 3]), 8 * 4 * .8)
    assert_equals(price([0, 1, 2, 3, 4]), 8 * 5 * .75)
    assert_equals(price([2, 2]), 8 * 2)
    assert_equals(price([0, 0, 1]), 8 * 2 * 0.95 + 8 * 1)
    assert_equals(price([1, 0, 0]), 8 * 2 * 0.95 + 8 * 1)
    assert_equals(price([0, 0, 1, 1]), 2 * (8 * 2 * 0.95))
    assert_equals(price([0, 0, 1, 1, 2, 2, 3, 4]), 2 * (4 * 8 * .80))
    assert_equals(price([0, 0, 1, 1, 2, 3, 3, 4]), 2 * (8 * 4 * .8))
    assert_equals(price([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]), 2 * (8 * 5 * .75))
    assert_equals(price([0, 0, 0, 0, 0,
                         1, 1, 1, 1, 1,
                         2, 2, 2, 2,
                         3, 3, 3, 3, 3,
                         4, 4, 4, 4]),
                  3 * (8 * 5 * 0.75) + 2 * (8 * 4 * 0.8))

if __name__ == "__main__":
    # Demo a larger basket, then run the self-checks.  Guarded so that
    # importing this module neither prints nor runs the tests.
    print(potter([20, 20, 20, 20, 4]))
    tests()