Skip to content

Commit 07408a8

Browse files
committed
Fix vulnerability to evaluating at points with 0.0
1 parent d1c046a commit 07408a8

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

interpolation/complete_poly.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def _complete_poly_der_impl_vec(z, d, der, out):
358358

359359
# Update index and out
360360
ix += 1
361-
out[ix] = c2 * t1*t2 * z[der]**(c2-1)
361+
out[ix] = c2 * t1*t2 * z[der]**(c2-1) if c2>0 else 0.0
362362

363363
return
364364

@@ -369,12 +369,12 @@ def _complete_poly_der_impl_vec(z, d, der, out):
369369
for i2 in range(i1, nvar):
370370
(c2, t2) = (c1+1, 1.0) if i2==der else (c1, z[i2])
371371
ix += 1
372-
out[ix] = c2 * t1*t2 * z[der]**(c2-1)
372+
out[ix] = c2 * t1*t2 * z[der]**(c2-1) if c2>0 else 0.0
373373

374374
for i3 in range(i2, nvar):
375375
(c3, t3) = (c2+1, 1.0) if i3==der else (c2, z[i3])
376376
ix += 1
377-
out[ix] = c3 * t1*t2*t3 * z[der]**(c3-1)
377+
out[ix] = c3 * t1*t2*t3 * z[der]**(c3-1) if c3>0 else 0.0
378378

379379
return
380380

@@ -385,17 +385,17 @@ def _complete_poly_der_impl_vec(z, d, der, out):
385385
for i2 in range(i1, nvar):
386386
(c2, t2) = (c1+1, 1.0) if i2==der else (c1, z[i2])
387387
ix += 1
388-
out[ix] = c2 * t1*t2* z[der]**(c2-1)
388+
out[ix] = c2 * t1*t2* z[der]**(c2-1) if c2>0 else 0.0
389389

390390
for i3 in range(i2, nvar):
391391
(c3, t3) = (c2+1, 1.0) if i3==der else (c2, z[i3])
392392
ix += 1
393-
out[ix] = c3 * t1*t2*t3* z[der]**(c3-1)
393+
out[ix] = c3*t1*t2*t3*z[der]**(c3-1) if c3>0 else 0.0
394394

395395
for i4 in range(i3, nvar):
396396
(c4, t4) = (c3+1, 1.0) if i4==der else (c3, z[i4])
397397
ix += 1
398-
out[ix] = c4 * t1*t2*t3*t4 * z[der]**(c4-1)
398+
out[ix] = c4*t1*t2*t3*t4*z[der]**(c4-1) if c4>0 else 0.0
399399

400400
return
401401

@@ -406,22 +406,22 @@ def _complete_poly_der_impl_vec(z, d, der, out):
406406
for i2 in range(i1, nvar):
407407
(c2, t2) = (c1+1, 1.0) if i2==der else (c1, z[i2])
408408
ix += 1
409-
out[ix] = c2 * t1*t2* z[der]**(c2-1)
409+
out[ix] = c2 * t1*t2* z[der]**(c2-1) if c2>0 else 0.0
410410

411411
for i3 in range(i2, nvar):
412412
(c3, t3) = (c2+1, 1.0) if i3==der else (c2, z[i3])
413413
ix += 1
414-
out[ix] = c3 * t1*t2*t3* z[der]**(c3-1)
414+
out[ix] = c3 * t1*t2*t3* z[der]**(c3-1) if c3>0 else 0.0
415415

416416
for i4 in range(i3, nvar):
417417
(c4, t4) = (c3+1, 1.0) if i4==der else (c3, z[i4])
418418
ix += 1
419-
out[ix] = c4 * t1*t2*t3*t4 * z[der]**(c4-1)
419+
out[ix] = c4*t1*t2*t3*t4*z[der]**(c4-1) if c4>0 else 0.0
420420

421421
for i5 in range(i4, nvar):
422422
(c5, t5) = (c4+1, 1.0) if i5==der else (c4, z[i5])
423423
ix += 1
424-
out[ix] = c5 * t1*t2*t3*t4*t5 * z[der]**(c5-1)
424+
out[ix] = c5*t1*t2*t3*t4*t5*z[der]**(c5-1) if c5>0 else 0.0
425425

426426
return
427427

@@ -458,7 +458,7 @@ def _complete_poly_der_impl(z, d, der, out):
458458
for k in range(nobs):
459459
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
460460
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
461-
out[ix, k] = c2 * t1*t2 * z[der, k]**(c2-1)
461+
out[ix, k] = c2*t1*t2*z[der, k]**(c2-1) if c2>0 else 0.0
462462

463463
return
464464

@@ -469,13 +469,13 @@ def _complete_poly_der_impl(z, d, der, out):
469469
for k in range(nobs):
470470
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
471471
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
472-
out[ix, k] = c2 * t1*t2 * z[der, k]**(c2-1)
472+
out[ix, k] = c2*t1*t2*z[der, k]**(c2-1) if c2>0 else 0.0
473473

474474
for i3 in range(i2, nvar):
475475
ix += 1
476476
for k in range(nobs):
477477
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
478-
out[ix, k] = c3 * t1*t2*t3 * z[der, k]**(c3-1)
478+
out[ix, k] = c3*t1*t2*t3*z[der, k]**(c3-1) if c3>0 else 0.0
479479

480480
return
481481

@@ -486,19 +486,19 @@ def _complete_poly_der_impl(z, d, der, out):
486486
for k in range(nobs):
487487
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
488488
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
489-
out[ix, k] = c2 * t1*t2 * z[der, k]**(c2-1)
489+
out[ix, k] = c2*t1*t2*z[der, k]**(c2-1) if c2>0 else 0.0
490490

491491
for i3 in range(i2, nvar):
492492
ix += 1
493493
for k in range(nobs):
494494
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
495-
out[ix, k] = c3 * t1*t2*t3 * z[der, k]**(c3-1)
495+
out[ix, k] = c3*t1*t2*t3*z[der, k]**(c3-1) if c3>0 else 0.0
496496

497497
for i4 in range(i3, nvar):
498498
ix += 1
499499
for k in range(nobs):
500500
c4, t4 = (c3+1, 1.0) if i4==der else (c3, z[i4, k])
501-
out[ix, k] = c4 * t1*t2*t3*t4 * z[der, k]**(c4-1)
501+
out[ix, k] = c4*t1*t2*t3*t4*z[der, k]**(c4-1) if c4>0 else 0.0
502502

503503
return
504504

@@ -509,25 +509,25 @@ def _complete_poly_der_impl(z, d, der, out):
509509
for k in range(nobs):
510510
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
511511
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
512-
out[ix, k] = c2 * t1*t2 * z[der, k]**(c2-1)
512+
out[ix, k] = c2*t1*t2*z[der, k]**(c2-1) if c2>0 else 0.0
513513

514514
for i3 in range(i2, nvar):
515515
ix += 1
516516
for k in range(nobs):
517517
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
518-
out[ix, k] = c3 * t1*t2*t3 * z[der, k]**(c3-1)
518+
out[ix, k] = c3*t1*t2*t3*z[der, k]**(c3-1) if c3>0 else 0.0
519519

520520
for i4 in range(i3, nvar):
521521
ix += 1
522522
for k in range(nobs):
523523
c4, t4 = (c3+1, 1.0) if i4==der else (c3, z[i4, k])
524-
out[ix, k] = c4 * t1*t2*t3*t4 * z[der, k]**(c4-1)
524+
out[ix, k] = c4*t1*t2*t3*t4*z[der, k]**(c4-1) if c4>0 else 0.0
525525

526526
for i5 in range(i4, nvar):
527527
ix += 1
528528
for k in range(nobs):
529529
c5, t5 = (c4+1, 1.0) if i5==der else (c4, z[i5, k])
530-
out[ix, k] = c5 * t1*t2*t3*t4*t5 * z[der, k]**(c5-1)
530+
out[ix, k] = c5*t1*t2*t3*t4*t5*z[der, k]**(c5-1) if c5>0 else 0.0
531531

532532
return
533533

interpolation/tests/test_complete.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@ def test_complete_derivative():
6666
out_mat = complete_polynomial_der(z, 2, 1)
6767
assert(abs(out_mat[0, :]).max() < 1e-10)
6868
assert(abs(out_mat[2, :] - np.ones(2)).max() < 1e-10)
69-
assert(abs(out_mat[-1, :] - np.array([5.0, 6.0])).max() < 1e-10)
69+
assert(abs(out_mat[-2, :] - np.array([5.0, 6.0])).max() < 1e-10)
7070

7171

7272
if __name__ == '__main__':
7373
test_complete_scalar()
7474
test_complete_vector()
75+
test_complete_derivative()
76+

0 commit comments

Comments
 (0)