EDIT:
Even though the problem in the question has been updated, an algebraic approach can still be used to simplify matters. You still don't have to bother with 3-D matrices. Your result is just going to be this:
output = mean(v.^2).*A.^2 + 2.*mean(v.*w).*A.*B + mean(w.^2).*B.^2;
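If your matrices and vectors are large, this solution should perform much better than solutions using BSXFUN or REPMAT. It never builds the m-by-n-by-d intermediate array. Instead, it computes three scalar means from v and w and then works only with m-by-n matrices, so it needs far less memory.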
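A quick way to check that the closed form matches the brute-force 3-D computation is to compare the two on random data. This is a sketch: the sizes m, n, d and the random inputs are made up for illustration, and the 3-D version (built with BSXFUN) serves only as the reference.

```matlab
% Hypothetical sizes and random data, for illustration only
m = 4; n = 5; d = 7;
A = rand(m, n);
B = rand(m, n);
v = rand(1, d);
w = rand(1, d);

% Reference: build the full m-by-n-by-d array, square it, average along dim 3
vv = reshape(v, 1, 1, d);
ww = reshape(w, 1, 1, d);
M = bsxfun(@times, A, vv) + bsxfun(@times, B, ww);
ref = mean(M.^2, 3);

% Closed form: three scalar means plus element-wise m-by-n operations
output = mean(v.^2).*A.^2 + 2.*mean(v.*w).*A.*B + mean(w.^2).*B.^2;

% The two results should agree up to floating-point rounding
maxErr = max(abs(output(:) - ref(:)));
fprintf('max abs difference: %g\n', maxErr);
```

The reference version allocates m*n*d elements for M, while the closed form never allocates anything larger than m-by-n.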
Explanation:
Assuming M is the m-by-n-by-d matrix that you get as a result before taking the mean along the third dimension, this is what a span along the third dimension will contain:
M(i,j,:) = A(i,j).*v + B(i,j).*w;
In other words, it is the vector v scaled by A(i,j) plus the vector w scaled by B(i,j). And this is what you get when you apply an element-wise squaring:
M(i,j,:).^2 = (A(i,j).*v + B(i,j).*w).^2;
            = (A(i,j).*v).^2 + ...
              2.*A(i,j).*B(i,j).*v.*w + ...
              (B(i,j).*w).^2;
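The expansion of the square holds element by element, so it can be checked numerically for a single span. In this sketch the scalars a and b are hypothetical stand-ins for A(i,j) and B(i,j), and v and w are arbitrary example vectors.

```matlab
% Hypothetical example values for a single element (i,j)
a = 0.3;          % stands in for A(i,j)
b = -1.2;         % stands in for B(i,j)
v = [1 2 3 4];
w = [0.5 -1 2 0];

lhs = (a.*v + b.*w).^2;                          % square of the span
rhs = (a.*v).^2 + 2.*a.*b.*v.*w + (b.*w).^2;     % expanded form
disp(max(abs(lhs - rhs)));                       % should be (near) zero
```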
Now, when you take the mean across the third dimension, the result for each element output(i,j) will be the following:
output(i,j) = mean(M(i,j,:).^2);
            = mean((A(i,j).*v).^2 + ...
                   2.*A(i,j).*B(i,j).*v.*w + ...
                   (B(i,j).*w).^2);
            = sum((A(i,j).*v).^2 + ...
                  2.*A(i,j).*B(i,j).*v.*w + ...
                  (B(i,j).*w).^2)/d;
            = sum((A(i,j).*v).^2)/d + ...
              sum(2.*A(i,j).*B(i,j).*v.*w)/d + ...
              sum((B(i,j).*w).^2)/d;
            = A(i,j).^2.*mean(v.^2) + ...
              2.*A(i,j).*B(i,j).*mean(v.*w) + ...
              B(i,j).^2.*mean(w.^2);
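The last line of the derivation depends on (i,j) only through A(i,j) and B(i,j), which is why the three means can be computed once and the whole expression applied to entire matrices. The sketch below, using hypothetical random inputs, evaluates it three ways: the per-element formula in a loop, the per-element mean of the squared span computed directly, and the vectorized one-liner.

```matlab
% Hypothetical random inputs
m = 3; n = 4; d = 6;
A = randn(m, n); B = randn(m, n);
v = randn(1, d); w = randn(1, d);

% The three scalars are computed once and shared by every element
mvv = mean(v.^2);
mvw = mean(v.*w);
mww = mean(w.^2);

loopOut = zeros(m, n);     % last line of the derivation, element by element
directOut = zeros(m, n);   % mean of the squared span, without the identity
for i = 1:m
  for j = 1:n
    loopOut(i,j) = A(i,j).^2.*mvv + 2.*A(i,j).*B(i,j).*mvw + B(i,j).^2.*mww;
    directOut(i,j) = mean((A(i,j).*v + B(i,j).*w).^2);
  end
end

% Vectorized form: the same expression applied to whole matrices
output = mvv.*A.^2 + 2.*mvw.*A.*B + mww.*B.^2;
```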