Admittedly, it gets more complicated when summing two things at the same time. Here is the Haskell version, followed by the equivalent C loop:
-- Compute two means over the indices 0 .. cc - 1 with one combined fold:
--   dnormMean     = (sum of dnormI)           / cc
--   dnormNormMean = (sum of normBti * dnormI) / cc
-- where, for each index i,
--   normBti = (inp ! (off + i) - meanBt) * rstdBt
--   dnormI  = (weight ! i) * (dout ! (off + i))
let Pair dnormMean dnormNormMean =
      let n = fromIntegral cc
          -- Per-index inputs, packed as (normBti, dnormI).
          entry i =
            Pair
              (((inp ! (off + i)) - meanBt) * rstdBt)
              ((weight ! i) * (dout ! (off + i)))
          -- Sum the projection 'f' of every entry, then divide by the count.
          meanOf f = dimap f (/ n) sum
       in fold
            ( Pair
                <$> meanOf (\(Pair _ dnormI) -> dnormI)
                <*> meanOf (\(Pair normBti dnormI) -> normBti * dnormI)
            )
            (map entry [0 .. cc - 1])
float dnorm_mean = 0.0f;
float dnorm_norm_mean = 0.0f;
for (int i = 0; i < C; i++) {
float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
float dnorm_i = weight[i] * dout_bt[i];
dnorm_mean += dnorm_i;
dnorm_norm_mean += dnorm_i * norm_bti;
}
dnorm_mean = dnorm_mean / C;
dnorm_norm_mean = dnorm_norm_mean / C;