From 1e49908fc94aafdbd3e5a270df07d775872cf9d0 Mon Sep 17 00:00:00 2001 From: Ben Root Date: Mon, 6 Jun 2011 11:32:21 -0500 Subject: [PATCH 1/2] This should make the hungarian algorithm accept rectangular cost matrices. Also enabled the tests. NOTE: Only tested on rectangular matrices of shape nxm such that m > n. Tests need to be expanded to test m < n. --- scikits/learn/utils/hungarian.py | 11 ++++----- scikits/learn/utils/tests/test_hungarian.py | 27 +++++++++++++-------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/scikits/learn/utils/hungarian.py b/scikits/learn/utils/hungarian.py index 3e5affdcf82cf..a579cf5922ee4 100644 --- a/scikits/learn/utils/hungarian.py +++ b/scikits/learn/utils/hungarian.py @@ -56,12 +56,13 @@ def compute(self, cost_matrix): """ self.C = cost_matrix.copy() self.n = n = self.C.shape[0] + self.m = m = self.C.shape[1] self.row_uncovered = np.ones(n, dtype=np.bool) - self.col_uncovered = np.ones(n, dtype=np.bool) + self.col_uncovered = np.ones(m, dtype=np.bool) self.Z0_r = 0 self.Z0_c = 0 - self.path = np.zeros((2*n, 2), dtype=int) - self.marked = np.zeros((n, n), dtype=int) + self.path = np.zeros((n+m, 2), dtype=int) + self.marked = np.zeros((n, m), dtype=int) done = False step = 1 @@ -131,9 +132,7 @@ def _step4(self): n = self.n while True: # Find an uncovered zero - raveled_idx = np.argmax(covered_C) - col = raveled_idx % n - row = raveled_idx // n + row, col = np.unravel_index(np.argmax(covered_C), (self.n, self.m)) if covered_C[row, col] == 0: return 6 else: diff --git a/scikits/learn/utils/tests/test_hungarian.py b/scikits/learn/utils/tests/test_hungarian.py index 7b6c8230b8802..f3cb5b4de5e07 100644 --- a/scikits/learn/utils/tests/test_hungarian.py +++ b/scikits/learn/utils/tests/test_hungarian.py @@ -15,11 +15,11 @@ def test_hungarian(): ), ## Rectangular variant - #([[400, 150, 400, 1], - # [400, 450, 600, 2], - # [300, 225, 300, 3]], - # 452 # expected cost - #), + ([[400, 150, 400, 1], + [400, 450, 600, 2], + 
[300, 225, 300, 3]], + 452 # expected cost + ), # Square ([[10, 10, 8], @@ -29,11 +29,11 @@ def test_hungarian(): ), ## Rectangular variant - #([[10, 10, 8, 11], - # [ 9, 8, 1, 1], - # [ 9, 7, 4, 10]], - # 15 - #), + ([[10, 10, 8, 11], + [ 9, 8, 1, 1], + [ 9, 7, 4, 10]], + 15 + ), ] m = _Hungarian() @@ -54,3 +54,10 @@ def test_find_permutation(): np.testing.assert_array_equal(find_permutation(B, A), np.arange(10)[::-1]) + +if __name__ == '__main__' : + print "find_permutations test..." + test_find_permutation() + print "Hungarian test..." + test_hungarian() + From 38f9a46dc426499836526585283e1afd5101c972 Mon Sep 17 00:00:00 2001 From: Ben Root Date: Mon, 6 Jun 2011 13:58:42 -0500 Subject: [PATCH 2/2] An additional check is needed in the case where there are fewer columns than rows. All assignments are made, but the algorithm wants to keep going because there are some rows left. --- scikits/learn/utils/hungarian.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/scikits/learn/utils/hungarian.py b/scikits/learn/utils/hungarian.py index a579cf5922ee4..c225e40cc6513 100644 --- a/scikits/learn/utils/hungarian.py +++ b/scikits/learn/utils/hungarian.py @@ -112,7 +112,7 @@ def _step3(self): marked = (self.marked == 1) self.col_uncovered[np.any(marked, axis=0)] = False - if marked.sum() >= self.n: + if marked.sum() >= min(self.m, self.n) : return 7 # done else: return 4 @@ -130,9 +130,10 @@ def _step4(self): covered_C = C*self.row_uncovered[:, np.newaxis] covered_C *= self.col_uncovered.astype(np.int) n = self.n + m = self.m while True: # Find an uncovered zero - row, col = np.unravel_index(np.argmax(covered_C), (self.n, self.m)) + row, col = np.unravel_index(np.argmax(covered_C), (n, m)) if covered_C[row, col] == 0: return 6 else: @@ -211,10 +212,11 @@ def _step6(self): lines. 
""" # the smallest uncovered value in the matrix - minval = np.min(self.C[self.row_uncovered], axis=0) - minval = np.min(minval[self.col_uncovered]) - self.C[np.logical_not(self.row_uncovered)] += minval - self.C[:, self.col_uncovered] -= minval + if np.any(self.row_uncovered) and np.any(self.col_uncovered): + minval = np.min(self.C[self.row_uncovered], axis=0) + minval = np.min(minval[self.col_uncovered]) + self.C[np.logical_not(self.row_uncovered)] += minval + self.C[:, self.col_uncovered] -= minval return 4 def _find_prime_in_row(self, row):