00001
00002
00003
00004 #include "GaussianMixturePdf.hpp"
00005
00006
00007 using namespace indii::ml::aux;
00008
00009 GaussianMixturePdf::GaussianMixturePdf() : StandardMixturePdf<GaussianPdf>() {
00010
00011 }
00012
00013 GaussianMixturePdf::GaussianMixturePdf(const unsigned int N) :
00014 StandardMixturePdf<GaussianPdf>(N) {
00015
00016 }
00017
00018 GaussianMixturePdf::GaussianMixturePdf(const unsigned int K,
00019 const DiracMixturePdf& p) :
00020 StandardMixturePdf<GaussianPdf>(p.getDimensions()) {
00021 std::vector<unsigned int> clusters(p.getSize());
00022 unsigned int i, k;
00023
00024 vector ws(K);
00025 std::vector<vector> mus;
00026 std::vector<symmetric_matrix> sigmas;
00027 std::vector<GaussianPdf> xs;
00028
00029
00030 vector mu(getDimensions());
00031 symmetric_matrix sigma(getDimensions());
00032 GaussianPdf x(mu, sigma);
00033 for (k = 0; k < K; k++) {
00034 mus.push_back(mu);
00035 sigmas.push_back(sigma);
00036 xs.push_back(x);
00037 }
00038 for (i = 0; i < p.getSize(); i++) {
00039 clusters[i] = i % K;
00040 }
00041
00042 bool change = false;
00043 unsigned int k_max;
00044 double d, d_max;
00045 do {
00046
00047 ws.clear();
00048 for (k = 0; k < K; k++) {
00049 mus[k].clear();
00050 sigmas[k].clear();
00051 }
00052
00053 for (i = 0; i < p.getSize(); i++) {
00054 k = clusters[i];
00055 ws(k) += p.getWeight(i);
00056 noalias(mus[k]) += p.getWeight(i) * p.get(i);
00057 noalias(sigmas[k]) += p.getWeight(i) * outer_prod(p.get(i), p.get(i));
00058 }
00059 for (k = 0; k < K; k++) {
00060 mus[k] /= ws(k);
00061 sigmas[k] /= ws(k);
00062 sigmas[k] -= outer_prod(mus[k],mus[k]);
00063 }
00064
00065
00066 for (k = 0; k < K; k++) {
00067 xs[k].setExpectation(mus[k]);
00068 xs[k].setCovariance(sigmas[k]);
00069 }
00070
00071 change = false;
00072 for (i = 0; i < p.getSize(); i++) {
00073 k_max = K;
00074 d_max = 0.0;
00075 for (k = 0; k < K; k++) {
00076 d = ws(k) * xs[k].densityAt(p.get(i));
00077 if (d > d_max) {
00078 k_max = k;
00079 d_max = d;
00080 }
00081 }
00082 change = change || clusters[i] != k_max;
00083 clusters[i] = k_max;
00084 }
00085 } while (change);
00086
00087
00088 for (k = 0; k < K; k++) {
00089 add(xs[k], ws(k));
00090 }
00091 }
00092
00093 GaussianMixturePdf::~GaussianMixturePdf() {
00094
00095 }
00096