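// train.c
// Training routine: fits the encoder/decoder model on emojis.bin using a
// squared-error (MSE) reconstruction loss and plain backpropagation
// (stochastic gradient descent, one randomly chosen image per step).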
#include "includes.c"
/*
 * Train the encoder/decoder model on the images in emojis.bin.
 *
 * Each step draws one random image, runs encode -> decode, computes the
 * sum-of-squared-errors reconstruction loss, and applies a plain gradient
 * step (learning rate LR) to w2/b2 (decoder) and w1/b1 (encoder).
 * Every `epoch_to_print` steps the mean per-pixel loss of the window since
 * the previous report is printed and a reconstruction of the current sample
 * is written to "<step>.bmp". The loss of any trailing steps after the last
 * report is printed once training ends. The model is saved to model.bin.
 *
 * Returns 0 on success, 1 if the dataset contains no images.
 */
int main(void) {
    srand((unsigned)time(NULL));
    Model m;
    // Zeroed so that, if load_dataset leaves it untouched, count reads as 0
    // and is caught by the check below.
    Dataset d = {0};
    init_model(&m);
    load_dataset("emojis.bin", &d);
    if (d.count <= 0) {
        // rand() % d.count below would be a division by zero.
        fprintf(stderr, "no images loaded from emojis.bin\n");
        return 1;
    }

    int steps = 3000;          // training steps
    int epoch_to_print = 1000; // print loss and save an image every this many steps
    float latent[LATENT];
    float output[IMAGESIZE];
    float d_output[IMAGESIZE]; // dL/dsum for each decoder output unit
    float d_latent[LATENT];    // dL/dlatent, accumulated through the decoder
    float dz[LATENT];          // dL/dsum for each encoder unit
    float avg_loss = 0;        // summed loss over the current reporting window
    int window = 0;            // number of steps in the current reporting window

    for (int step = 0; step < steps; step++) {
        // NOTE(review): rand() % d.count is slightly biased and can never pick
        // an index above RAND_MAX (which may be as small as 32767).
        int idx = rand() % d.count;
        float *img = &d.images[(size_t)idx * IMAGESIZE];

        // ---- forward ----
        encode(&m, img, latent);
        decode(&m, latent, output);

        // ---- loss + gradient ----
        // The derivative below assumes decode produces
        // output = (tanh(sum) + 1) / 2, so tanh(sum) = 2*output - 1 and
        // dout/dsum = 0.5 * (1 - tanh^2). NOTE(review): confirm in decode.
        float loss = 0;
        for (int i = 0; i < IMAGESIZE; i++) {
            float diff = output[i] - img[i];
            loss += diff * diff;                  // sum of squared errors
            float t = 2.0f * output[i] - 1.0f;    // tanh(sum)
            float dout_dsum = 0.5f * (1.0f - t * t);
            float dL_dout = 2.0f * diff;          // d(diff^2)/d(output)
            d_output[i] = dL_dout * dout_dsum;    // dL/dsum
        }

        // ---- backprop decoder ----
        // Each w2 entry is read (to propagate the gradient into d_latent)
        // before it is updated, so d_latent uses the pre-update weights.
        for (int j = 0; j < LATENT; j++) {
            d_latent[j] = 0;
            for (int i = 0; i < IMAGESIZE; i++) {
                int w2_idx = j * IMAGESIZE + i;
                float grad = d_output[i] * latent[j]; // dL/dw2
                d_latent[j] += d_output[i] * m.w2[w2_idx];
                m.w2[w2_idx] -= LR * grad;
            }
        }
        for (int i = 0; i < IMAGESIZE; i++) {
            m.b2[i] -= LR * d_output[i]; // decoder biases: dL/db2 = dL/dsum
        }

        // ---- backprop encoder ----
        // Assumes latent = tanh(sum), hence dlatent/dsum = 1 - latent^2.
        for (int j = 0; j < LATENT; j++) {
            dz[j] = d_latent[j] * (1 - latent[j] * latent[j]);
            m.b1[j] -= LR * dz[j]; // encoder biases: dL/db1 = dL/dsum
        }
        // i outer / j inner walks w1 (indexed i * LATENT + j) contiguously.
        for (int i = 0; i < IMAGESIZE; i++) {
            for (int j = 0; j < LATENT; j++) {
                float grad = dz[j] * img[i]; // dL/dw1
                m.w1[i * LATENT + j] -= LR * grad;
            }
        }

        // ---- reporting ----
        avg_loss += loss;
        window++;
        if (step % epoch_to_print == 0) {
            // Step 0's window holds a single sample: the untrained model's loss.
            if (step != 0) {
                printf("step %d loss %f\n", step, avg_loss / window / IMAGESIZE);
            } else {
                printf("initial loss %f\n", loss / IMAGESIZE);
            }
            avg_loss = 0;
            window = 0;
            // Save a reconstruction so progress can be inspected visually.
            float image[IMAGESIZE];
            char filename[16];
            snprintf(filename, sizeof filename, "%d.bmp", step);
            reconstruct(&m, img, image);
            save_bmp(filename, image);
        }
    }
    // Steps after the last checkpoint would otherwise never be reported.
    if (window > 0) {
        printf("final loss %f (last %d steps)\n", avg_loss / window / IMAGESIZE, window);
    }
    save_model("model.bin", &m);
    return 0;
}