Skip to content

Commit

Permalink
BUG: Fix work array construction for various weight shapes. (#18741)
Browse files Browse the repository at this point in the history
[skip ci]

Closes gh-18739
  • Loading branch information
rkern committed Jun 28, 2023
1 parent 5cac3d9 commit 899f4ef
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
12 changes: 9 additions & 3 deletions scipy/odr/_odrpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,10 +899,16 @@ def _gen_work(self):
elif len(self.data.we.shape) == 3:
ld2we, ldwe = self.data.we.shape[1:]
else:
# Okay, this isn't precisely right, but for this calculation,
# it's fine
we = self.data.we
ldwe = 1
ld2we = self.data.we.shape[1]
ld2we = 1
if we.ndim == 1 and q == 1:
ldwe = n
elif we.ndim == 2:
if we.shape == (q, q):
ld2we = q
elif we.shape == (q, n):
ldwe = n

if self.job % 10 < 2:
# ODR not OLS
Expand Down
29 changes: 29 additions & 0 deletions scipy/odr/tests/test_odr.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,3 +531,32 @@ def func(b, x):
p = Model(func)
p.set_meta(name='Sample Model Meta', ref='ODRPACK')
assert_equal(p.meta, {'name': 'Sample Model Meta', 'ref': 'ODRPACK'})

def test_work_array_del_init(self):
"""
Verify fix for gh-18739 where del_init=1 fails.
"""
def func(b, x):
return b[0] + b[1] * x

# generate some data
n_data = 4
x = np.arange(n_data)
y = np.where(x % 2, x + 0.1, x - 0.1)
x_err = np.full(n_data, 0.1)
y_err = np.full(n_data, 0.1)

linear_model = Model(func)
# Try various shapes of the `we` array from various `sy` and `covy`
rd0 = RealData(x, y, sx=x_err, sy=y_err)
rd1 = RealData(x, y, sx=x_err, sy=0.1)
rd2 = RealData(x, y, sx=x_err, sy=[0.1])
rd3 = RealData(x, y, sx=x_err, sy=np.full((1, n_data), 0.1))
rd4 = RealData(x, y, sx=x_err, covy=[[0.01]])
rd5 = RealData(x, y, sx=x_err, covy=np.full((1, 1, n_data), 0.01))
for rd in [rd0, rd1, rd2, rd3, rd4, rd5]:
odr_obj = ODR(rd, linear_model, beta0=[0.4, 0.4],
delta0=np.full(n_data, -0.1))
odr_obj.set_job(fit_type=0, del_init=1)
# Just make sure that it runs without raising an exception.
odr_obj.run()

0 comments on commit 899f4ef

Please sign in to comment.