This commit is contained in:
Josh Deprez 2021-10-03 15:34:37 +11:00
parent c74d915542
commit 1586d25f2b
2 changed files with 179 additions and 23 deletions

View file

@ -56,22 +56,26 @@ func (s *LinearSpline) Interpolate(x float64) float64 {
return x <= s.Points[n].X
})
if x == s.Points[i].X {
// Hit a point exactly
// Hit the point i exactly
return s.Points[i].Y
}
// In the interval between point i and point i+1
return s.Points[i].Y + (x-s.Points[i].X)*s.deriv[i]
// In the interval between point i-1 and point i
return s.Points[i-1].Y + (x-s.Points[i-1].X)*s.deriv[i-1]
}
// CubicSpline implements a cubic spline.
// CubicSpline implements a natural cubic spline. A cubic spline interpolates
// the given Points while ensuring first and second derivatives are continuous.
type CubicSpline struct {
// Points on the spline
Points []Float2
deriv, deriv2 []float64
// moments and intervals
m, h []float64
// slope of line before and after spline, for extrapolation
preslope, postslope float64
}
// Prepare
// Prepare sorts the points and computes internal information.
func (s *CubicSpline) Prepare() error {
if len(s.Points) < 1 {
return errors.New("spline needs at least 1 point")
@ -80,21 +84,94 @@ func (s *CubicSpline) Prepare() error {
sort.Slice(s.Points, func(i, j int) bool {
return s.Points[i].X < s.Points[j].X
})
// Check for points with equal X.
// Check for points with equal X, and compute intervals.
N := len(s.Points)
if N == 1 {
return nil
}
s.m = make([]float64, N)
s.h = make([]float64, N-1)
for i := range s.Points[1:] {
if s.Points[i].X == s.Points[i+1].X {
return fmt.Errorf("spline value defined twice [%v, %v]", s.Points[i], s.Points[i+1])
}
s.h[i] = s.Points[i+1].X - s.Points[i].X
}
// TODO: compute deriv and deriv2
// Compute moments. m[0] and m[N-1] are chosen to be 0 (natural cubic spline).
// Given:
// ɣ(i) = 2.0 * (h[i-1] + h[i])
// b(i) = 6.0 * ((Points[i+1].Y-Points[i].Y)/h[i] - (Points[i].Y-Points[i-1].Y)/h[i-1])
// we solve for m[i] in the equations:
// h[i-1]*m[i-1] + ɣ(i)*m[i] + h[i]*m[i+1] = b(i)
// for i = 1...N-2.
//
// Written as a diagonally dominant tridiagonal matrix equation:
//
// [ɣ(1) h[1] 0 0 ... 0 ] [ m[1] ] [ b(1) ]
// [h[1] ɣ(2) h[2] 0 ... 0 ] [ m[2] ] [ b(2) ]
// [0 h[2] ɣ(3) h[3] ... 0 ] [ m[3] ] = [ b(3) ]
// [0 0 h[3] ɣ(4) ... ... ] [ ... ] [ ... ]
// [...................... ... h[N-3] ] [ ... ] [ ... ]
// [0 0 ... 0 h[N-3] ɣ(N-2) ] [ m[N-2] ] [ b(N-2) ]
//
// This is solvable in O(N) using simplified Gaussian elimination
// ("Thomas algorithm").
// Setup:
diag, upper, B := make([]float64, N-1), make([]float64, N-1), make([]float64, N-1)
for i := 1; i < N-1; i++ {
diag[i] = 2.0 * (s.h[i-1] + s.h[i])
upper[i] = s.h[i]
B[i] = 6.0 * ((s.Points[i+1].Y-s.Points[i].Y)/s.h[i] - (s.Points[i].Y-s.Points[i-1].Y)/s.h[i-1])
}
// Forward elimination:
for i := 2; i < N-1; i++ {
t := s.h[i-1] / diag[i-1] // lower[i] / diag[i-1]
diag[i] -= t * upper[i-1]
B[i] -= t * B[i-1]
}
// Back substitution:
for i := N - 2; i > 0; i-- {
s.m[i] = (B[i] - s.h[i]*s.m[i+1]) / diag[i]
}
// Divide all the moments by 6, since all the terms with moments in them
// from this point onwards are divided by six.
for i := range s.m {
s.m[i] /= 6.0
}
// Pre- and post-slope:
s.preslope = -s.m[1]*s.h[0] + (s.Points[1].Y-s.Points[0].Y)/s.h[0]
s.postslope = s.m[N-2]*s.h[N-2] + (s.Points[N-1].Y-s.Points[N-2].Y)/s.h[N-2]
return nil
}
// Interpolate, given x, returns y where (x,y) is a point on the spline.
// If x is outside the spline, it extrapolates from either the first or
// last segments of the spline.
func (s *CubicSpline) Interpolate(x float64) float64 {
N := len(s.Points)
if N == 1 {
return s.Points[0].Y
}
// TODO
return 0
if x < s.Points[0].X {
// Comes before the start of the spline, extrapolate
return s.Points[0].Y + (x-s.Points[0].X)*s.preslope
}
if x > s.Points[N-1].X {
// Comes after the end of the spline, extrapolate
return s.Points[N-1].Y + (x-s.Points[N-1].X)*s.postslope
}
// Somewhere in the middle
i := sort.Search(N, func(n int) bool {
return x <= s.Points[n].X
})
if x == s.Points[i].X {
// Hit the point i exactly
return s.Points[i].Y
}
// In the interval between point i-1 and point i
x0, x1 := x-s.Points[i-1].X, s.Points[i].X-x
return (s.m[i-1]*(x1*x1*x1)+s.m[i]*(x0*x0*x0))/s.h[i-1] -
(s.m[i-1]*x1+s.m[i]*x0)*s.h[i-1] +
(s.Points[i-1].Y*x1+s.Points[i].Y*x0)/s.h[i-1]
}

View file

@ -1,6 +1,9 @@
package geom
import "testing"
import (
"math"
"testing"
)
func TestLinearSplineNoPoints(t *testing.T) {
s := &LinearSpline{}
@ -49,21 +52,21 @@ func TestLinearSpline(t *testing.T) {
{x: -6, want: -0.5},
{x: -5.5, want: 0.25},
{x: -5, want: 1},
{x: -4.5, want: 4.5},
{x: -4, want: 3},
{x: -3.5, want: 1.5},
{x: -4.5, want: 0.75},
{x: -4, want: 0.5},
{x: -3.5, want: 0.25},
{x: -3, want: 0},
{x: -2.5, want: -4.25},
{x: -2.5, want: -1.5},
{x: -2, want: -3},
{x: -1.5, want: 12.5},
{x: -1, want: 9},
{x: -0.5, want: 5.5},
{x: -1.5, want: -1.75},
{x: -1, want: -0.5},
{x: -0.5, want: 0.75},
{x: 0, want: 2},
{x: 0.5, want: -5.75},
{x: 0.5, want: -1.5},
{x: 1, want: -5},
{x: 1.5, want: -11},
{x: 2, want: -8},
{x: 2.5, want: -5},
{x: 1.5, want: -4.25},
{x: 2, want: -3.5},
{x: 2.5, want: -2.75},
{x: 3, want: -2},
{x: 3.5, want: 1},
{x: 4, want: 4},
@ -77,3 +80,79 @@ func TestLinearSpline(t *testing.T) {
}
}
}
func TestCubicSplineNoPoints(t *testing.T) {
s := &CubicSpline{}
if err := s.Prepare(); err == nil {
t.Errorf("s.Prepare() = %v, want error", err)
}
}
func TestCubicSplineEqualXPoints(t *testing.T) {
s := &CubicSpline{
Points: []Float2{{-5, 1}, {-2, 7}, {-2, -3}, {0, 2}, {3, -2}},
}
if err := s.Prepare(); err == nil {
t.Errorf("s.Prepare() = %v, want error", err)
}
}
func TestCubicSplineOnePoint(t *testing.T) {
s := &CubicSpline{
Points: []Float2{{-2, -3}},
}
if err := s.Prepare(); err != nil {
t.Errorf("s.Prepare() = %v, want nil", err)
}
for _, x := range []float64{-5, -4, -2, 0, 1, 7} {
if got, want := s.Interpolate(x), float64(-3); got != want {
t.Errorf("s.Interpolate(%v) = %v, want %v", x, got, want)
}
}
}
func TestCubicSpline(t *testing.T) {
s := &CubicSpline{
Points: []Float2{{-7, -2}, {-5, 1}, {-3, 0}, {-2, -3}, {0, 2}, {1, -5}, {3, -2}, {4, 4}},
}
if err := s.Prepare(); err != nil {
t.Errorf("s.Prepare() = %v, want nil", err)
}
tests := []struct {
x, want float64
}{
{x: -8, want: -3.648342225609756},
{x: -7.5, want: -2.824171112804878},
{x: -7, want: -2},
{x: -6.5, want: -1.180464581745427},
{x: -6, want: -0.3887433307926829},
{x: -5.5, want: 0.3473495855564025},
{x: -5, want: 1},
{x: -4.5, want: 1.5067079125381098},
{x: -4, want: 1.6662299923780488},
{x: -3.5, want: 1.2426370760289636},
{x: -3, want: 0},
{x: -2.5, want: -1.9368449885670733},
{x: -2, want: -3},
{x: -1.5, want: -1.855450886051829},
{x: -1, want: 0.45221989329268286},
{x: -0.5, want: 2.2837807259908534},
{x: 0, want: 2},
{x: 0.5, want: -1.229539824695122},
{x: 1, want: -5},
{x: 1.5, want: -6.734946646341463},
{x: 2, want: -6.406821646341463},
{x: 2.5, want: -4.6252858231707314},
{x: 3, want: -2},
{x: 3.5, want: 0.941477705792683},
{x: 4, want: 4},
{x: 4.5, want: 7.078029725609756},
{x: 5, want: 10.156059451219512},
{x: 5.5, want: 13.234089176829269},
}
for _, test := range tests {
if got := s.Interpolate(test.x); math.Abs(got-test.want) > 0.0000001 {
t.Errorf("s.Interpolate(%v) = %v, want %v", test.x, got, test.want)
}
}
}