Merge pull request #257 from adrianVmariano/master

Added pivot to qr, quadratic roots, norm_fro, submatrix_set, block_matrix.
This commit is contained in:
Revar Desmera 2020-09-01 18:15:29 -07:00 committed by GitHub
commit 040747d0e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 229 additions and 24 deletions

View File

@ -1199,6 +1199,52 @@ function zip(vecs, v2, v3, fit=false, fill=undef) =
: [for(i=[0:1:minlen-1]) [for(v=vecs) for(x=v[i]) x] ];
// Function: block_matrix()
// Usage:
// block_matrix([[M11, M12,...],[M21, M22,...], ... ])
// Description:
// Create a block matrix by supplying a matrix of matrices, which will
// be combined into one unified matrix. Every matrix in one row
// must have the same height, and the combined width of the matrices
// in each row must be equal.
function block_matrix(M) =
let(
bigM = [for(bigrow = M) each zip(bigrow)],
len0=len(bigM[0]),
badrows = [for(row=bigM) if (len(row)!=len0) 1]
)
assert(badrows==[], "Inconsistent or invalid input")
bigM;
// Function: diagonal_matrix()
// Usage:
// diagonal_matrix(diag, [offdiag])
// Description:
// Creates a square matrix with the items in the list `diag` on
// its diagonal. The off diagonal entries are set to offdiag,
// which is zero by default.
function diagonal_matrix(diag,offdiag=0) =
[for(i=[0:1:len(diag)-1]) [for(j=[0:len(diag)-1]) i==j?diag[i] : offdiag]];
// Function: submatrix_set()
// Usage: submatrix_set(M,A,[m],[n])
// Description:
// Sets a submatrix of M equal to the matrix A. By default the top left corner of M is set to A, but
// you can specify offset coordinates m and n. If A (as adjusted by m and n) extends beyond the bounds
// of M then the extra entries are ignored. You can pass in A=[[]], a null matrix, and M will be
// returned unchanged. Note that the input M need not be rectangular in shape.
function submatrix_set(M,A,m=0,n=0) =
assert(is_list(M))
assert(is_list(A))
let( badrows = [for(i=idx(A)) if (!is_list(A[i])) i])
assert(badrows==[], str("Input submatrix malformed rows: ",badrows))
[for(i=[0:1:len(M)-1])
assert(is_list(M[i]), str("Row ",i," of input matrix is not a list"))
[for(j=[0:1:len(M[i])-1])
i>=m && i <len(A)+m && j>=n && j<len(A[0])+n ? A[i-m][j-n] : M[i][j]]];
// Function: array_group()
// Description:
// Takes a flat array of values, and groups items in sets of `cnt` length.

116
math.scad
View File

@ -680,7 +680,7 @@ function convolve(p,q) =
// then the problem is solved for the matrix valued right hand side and a matrix is returned. Note that if you
// want to solve Ax=b1 and Ax=b2 that you need to form the matrix transpose([b1,b2]) for the right hand side and then
// transpose the returned value.
function linear_solve(A,b) =
function linear_solve(A,b,pivot=true) =
assert(is_matrix(A), "Input should be a matrix.")
let(
m = len(A),
@ -688,19 +688,17 @@ function linear_solve(A,b) =
)
assert(is_vector(b,m) || is_matrix(b,m),"Invalid right hand side or incompatible with the matrix")
let (
qr = m<n? qr_factor(transpose(A)) : qr_factor(A),
qr = m<n? qr_factor(transpose(A),pivot) : qr_factor(A,pivot),
maxdim = max(n,m),
mindim = min(n,m),
Q = submatrix(qr[0],[0:maxdim-1], [0:mindim-1]),
R = submatrix(qr[1],[0:mindim-1], [0:mindim-1]),
P = qr[2],
zeros = [for(i=[0:mindim-1]) if (approx(R[i][i],0)) i]
)
zeros != [] ? [] :
m<n
// avoiding input validation in back_substitute
? let( n = len(R) )
Q*reverse(_back_substitute(transpose(R, reverse=true), reverse(b)))
: _back_substitute(R, transpose(Q)*b);
m<n ? Q*back_substitute(R,transpose(P)*b,transpose=true) // Too messy to avoid input checks here
: P*_back_substitute(R, transpose(Q)*b); // Calling internal version skips input checks
// Function: matrix_inverse()
// Usage:
@ -714,20 +712,40 @@ function matrix_inverse(A) =
assert(is_matrix(A,square=true),"Input to matrix_inverse() must be a square matrix")
linear_solve(A,ident(len(A)));
// Function: null_space()
// Usage:
// null_space(A)
// Description:
// Returns an orthonormal basis for the null space of A, namely the vectors {x} such that Ax=0. If the null space
// is just the origin then returns an empty list.
function null_space(A,eps=1e-12) =
assert(is_matrix(A))
let(
Q_R=qr_factor(transpose(A),pivot=true),
R=Q_R[1],
zrow = [for(i=idx(R)) if (is_zero(R[i],eps)) i]
)
len(zrow)==0
? []
: transpose(subindex(Q_R[0],zrow));
// Function: qr_factor()
// Usage: qr = qr_factor(A)
// Usage: qr = qr_factor(A,[pivot])
// Description:
// Calculates the QR factorization of the input matrix A and returns it as the list [Q,R]. This factorization can be
// used to solve linear systems of equations.
function qr_factor(A) =
// Calculates the QR factorization of the input matrix A and returns it as the list [Q,R,P]. This factorization can be
// used to solve linear systems of equations. The factorization is A = Q*R*transpose(P). If pivot is false (the default)
// then P is the identity matrix and A = Q*R. If pivot is true then column pivoting results in an R matrix where the diagonal
// is non-decreasing. The use of pivoting is supposed to increase accuracy for poorly conditioned problems, and is necessary
// for rank estimation or computation of the null space, but it may be slower.
function qr_factor(A, pivot=false) =
assert(is_matrix(A), "Input must be a matrix." )
let(
m = len(A),
n = len(A[0])
)
let(
qr = _qr_factor(A, Q=ident(m), column=0, m = m, n=n),
qr =_qr_factor(A, Q=ident(m),P=ident(n), pivot=pivot, column=0, m = m, n=n),
Rzero =
let( R = qr[1] )
[ for(i=[0:m-1]) [
@ -735,25 +753,31 @@ function qr_factor(A) =
for(j=[0:n-1]) i>j ? 0 : ri[j]
]
]
) [qr[0],Rzero];
) [qr[0],Rzero,qr[2]];
function _qr_factor(A,Q, column, m, n) =
column >= min(m-1,n) ? [Q,A] :
function _qr_factor(A,Q,P, pivot, column, m, n) =
column >= min(m-1,n) ? [Q,A,P] :
let(
swap = !pivot ? 1
: _swap_matrix(n,column,column+max_index([for(i=[column:n-1]) sum_of_squares([for(j=[column:m-1]) A[j][i]])])),
A = pivot ? A*swap : A,
x = [for(i=[column:1:m-1]) A[i][column]],
alpha = (x[0]<=0 ? 1 : -1) * norm(x),
u = x - concat([alpha],repeat(0,m-1)),
v = alpha==0 ? u : u / norm(u),
Qc = ident(len(x)) - 2*outer_product(v,v),
Qf = [for(i=[0:m-1])
[for(j=[0:m-1])
i<column || j<column
? (i==j ? 1 : 0)
: Qc[i-column][j-column]
]
]
Qf = [for(i=[0:m-1]) [for(j=[0:m-1]) i<column || j<column ? (i==j ? 1 : 0) : Qc[i-column][j-column]]]
)
_qr_factor(Qf*A, Q*Qf, column+1, m, n);
_qr_factor(Qf*A, Q*Qf, P*swap, pivot, column+1, m, n);
// Produces an n x n matrix that swaps column i and j (when multiplied on the right)
function _swap_matrix(n,i,j) =
assert(i<n && j<n && i>=0 && j>=0, "Swap indices out of bounds")
[for(y=[0:n-1]) [for (x=[0:n-1])
x==i ? (y==j ? 1 : 0)
: x==j ? (y==i ? 1 : 0)
: x==y ? 1 : 0]];
// Function: back_substitute()
@ -862,6 +886,17 @@ function is_matrix(A,m,n,square=false) =
&& ( !square || len(A)==len(A[0]));
// Function: norm_fro()
// Usage:
// norm_fro(A)
// Description:
// Computes frobenius norm of input matrix. The frobenius norm is the square root of the sum of the
// squares of all of the entries of the matrix. On vectors it is the same as the usual 2-norm.
// This is an easily computed norm that is convenient for comparing two matrices.
function norm_fro(A) =
sqrt(sum([for(entry=A) sum_of_squares(entry)]));
// Section: Comparisons and Logic
// Function: is_zero()
@ -1309,6 +1344,41 @@ function C_div(z1,z2) =
// Section: Polynomials
// Function: quadratic_roots()
// Usage:
// roots = quadratic_roots(a,b,c,[real])
// Description:
// Computes roots of the quadratic equation a*x^2+b*x+c==0, where the
// coefficients are real numbers. If real is true then returns only the
// real roots. Otherwise returns a pair of complex values. This method
// may be more reliable than the general root finder at distinguishing
// real roots from complex roots.
// https://people.csail.mit.edu/bkph/articles/Quadratics.pdf
function quadratic_roots(a,b,c,real=false) =
real ? [for(root = quadratic_roots(a,b,c,real=false)) if (root.y==0) root.x]
:
is_undef(b) && is_undef(c) && is_vector(a,3) ? quadratic_roots(a[0],a[1],a[2]) :
assert(is_num(a) && is_num(b) && is_num(c))
assert(a!=0 || b!=0 || c!=0, "Quadratic must have a nonzero coefficient")
a==0 && b==0 ? [] : // No solutions
a==0 ? [[-c/b,0]] :
let(
descrim = b*b-4*a*c,
sqrt_des = sqrt(abs(descrim))
)
descrim < 0 ? // Complex case
[[-b, sqrt_des],
[-b, -sqrt_des]]/2/a :
b<0 ? // b positive
[[2*c/(-b+sqrt_des),0],
[(-b+sqrt_des)/a/2,0]]
: // b negative
[[(-b-sqrt_des)/2/a, 0],
[2*c/(-b-sqrt_des),0]];
// Function: polynomial()
// Usage:
// polynomial(p, z)

View File

@ -470,6 +470,40 @@ module test_zip() {
}
test_zip();
module test_block_matrix() {
A = [[1,2],[3,4]];
B = ident(2);
assert_equal(block_matrix([[A,B],[B,A],[A,B]]), [[1,2,1,0],[3,4,0,1],[1,0,1,2],[0,1,3,4],[1,2,1,0],[3,4,0,1]]);
assert_equal(block_matrix([[A,B],ident(4)]), [[1,2,1,0],[3,4,0,1],[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]);
text = [["a","b"],["c","d"]];
assert_equal(block_matrix([[text,B]]), [["a","b",1,0],["c","d",0,1]]);
}
test_block_matrix();
module test_diagonal_matrix() {
assert_equal(diagonal_matrix([1,2,3]), [[1,0,0],[0,2,0],[0,0,3]]);
assert_equal(diagonal_matrix([1,"c",2]), [[1,0,0],[0,"c",0],[0,0,2]]);
assert_equal(diagonal_matrix([1,"c",2],"X"), [[1,"X","X"],["X","c","X"],["X","X",2]]);
assert_equal(diagonal_matrix([[1,1],[2,2],[3,3]], [0,0]), [[ [1,1],[0,0],[0,0]], [[0,0],[2,2],[0,0]], [[0,0],[0,0],[3,3]]]);
}
test_diagonal_matrix();
module test_submatrix_set() {
test = [[1,2,3,4,5],[6,7,8,9,10],[11,12,13,14,15], [16,17,18,19,20]];
ragged = [[1,2,3,4,5],[6,7,8,9,10],[11,12], [16,17]];
assert_equal(submatrix_set(test,[[9,8],[7,6]]), [[9,8,3,4,5],[7,6,8,9,10],[11,12,13,14,15], [16,17,18,19,20]]);
assert_equal(submatrix_set(test,[[9,7],[8,6]],1),[[1,2,3,4,5],[9,7,8,9,10],[8,6,13,14,15], [16,17,18,19,20]]);
assert_equal(submatrix_set(test,[[9,8],[7,6]],n=1), [[1,9,8,4,5],[6,7,6,9,10],[11,12,13,14,15], [16,17,18,19,20]]);
assert_equal(submatrix_set(test,[[9,8],[7,6]],1,2), [[1,2,3,4,5],[6,7,9,8,10],[11,12,7,6,15], [16,17,18,19,20]]);
assert_equal(submatrix_set(test,[[9,8],[7,6]],-1,-1), [[6,2,3,4,5],[6,7,8,9,10],[11,12,13,14,15], [16,17,18,19,20]]);
assert_equal(submatrix_set(test,[[9,8],[7,6]],n=4), [[1,2,3,4,9],[6,7,8,9,7],[11,12,13,14,15], [16,17,18,19,20]]);
assert_equal(submatrix_set(test,[[9,8],[7,6]],7,7), [[1,2,3,4,5],[6,7,8,9,10],[11,12,13,14,15], [16,17,18,19,20]]);
assert_equal(submatrix_set(ragged, [["a","b"],["c","d"]], 1, 1), [[1,2,3,4,5],[6,"a","b",9,10],[11,"c"], [16,17]]);
assert_equal(submatrix_set(test, [[]]), test);
}
test_submatrix_set();
module test_array_group() {
v = [1,2,3,4,5,6];

View File

@ -781,6 +781,12 @@ test_back_substitute();
module test_norm_fro(){
assert_approx(norm_fro([[2,3,4],[4,5,6]]), 10.29563014098700);
} test_norm_fro();
module test_linear_solve(){
M = [[-2,-5,-1,3],
[3,7,6,2],
@ -954,6 +960,38 @@ module test_real_roots(){
test_real_roots();
module test_quadratic_roots(){
assert_approx(quadratic_roots([1,4,4]),[[-2,0],[-2,0]]);
assert_approx(quadratic_roots([1,4,4],real=true),[-2,-2]);
assert_approx(quadratic_roots([1,-5,6],real=true), [2,3]);
assert_approx(quadratic_roots([1,-5,6]), [[2,0],[3,0]]);
}
test_quadratic_roots();
module test_null_space(){
assert_equal(null_space([[3,2,1],[3,6,3],[3,9,-3]]),[]);
function nullcheck(A,dim) =
let(v=null_space(A))
len(v)==dim && is_zero(A*transpose(v),eps=1e-12);
A = [[-1, 2, -5, 2],[-3,-1,3,-3],[5,0,5,0],[3,-4,11,-4]];
assert(nullcheck(A,1));
B = [
[ 4, 1, 8, 6, -2, 3],
[ 10, 5, 10, 10, 0, 5],
[ 8, 1, 8, 8, -6, 1],
[ -8, -8, 6, -1, -8, -1],
[ 2, 2, 0, 1, 2, 1],
[ 2, -3, 10, 6, -8, 1],
];
assert(nullcheck(B,3));
}
test_null_space();
module test_qr_factor() {
// Check that R is upper triangular
function is_ut(R) =
@ -962,7 +1000,15 @@ module test_qr_factor() {
// Test the R is upper trianglar, Q is orthogonal and qr=M
function qrok(qr,M) =
is_ut(qr[1]) && approx(qr[0]*transpose(qr[0]), ident(len(qr[0]))) && approx(qr[0]*qr[1],M);
is_ut(qr[1]) && approx(qr[0]*transpose(qr[0]), ident(len(qr[0]))) && approx(qr[0]*qr[1],M) && qr[2]==ident(len(qr[2]));
// Test the R is upper trianglar, Q is orthogonal, R diagonal non-increasing and qrp=M
function qrokpiv(qr,M) =
is_ut(qr[1])
&& approx(qr[0]*transpose(qr[0]), ident(len(qr[0])))
&& approx(qr[0]*qr[1]*transpose(qr[2]),M)
&& list_decreasing([for(i=[0:1:min(len(qr[1]),len(qr[1][0]))-1]) abs(qr[1][i][i])]);
M = [[1,2,9,4,5],
[6,7,8,19,10],
@ -991,6 +1037,15 @@ module test_qr_factor() {
assert(qrok(qr_factor([[7]]), [[7]]));
assert(qrok(qr_factor([[1,2,3]]), [[1,2,3]]));
assert(qrok(qr_factor([[1],[2],[3]]), [[1],[2],[3]]));
assert(qrokpiv(qr_factor(M,pivot=true),M));
assert(qrokpiv(qr_factor(select(M,0,3),pivot=true),select(M,0,3)));
assert(qrokpiv(qr_factor(transpose(select(M,0,3)),pivot=true),transpose(select(M,0,3))));
assert(qrokpiv(qr_factor(B,pivot=true),B));
assert(qrokpiv(qr_factor([[7]],pivot=true), [[7]]));
assert(qrokpiv(qr_factor([[1,2,3]],pivot=true), [[1,2,3]]));
assert(qrokpiv(qr_factor([[1],[2],[3]],pivot=true), [[1],[2],[3]]));
}
test_qr_factor();