@florindumitru
Created March 28, 2021 18:06
import * as tf from '@tensorflow/tfjs-node-gpu';

const csvUrl =
  'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';

async function run() {
  // We want to predict the column "medv", which represents the median value
  // of a home (in $1000s), so we mark it as a label.
  const csvDataset = tf.data.csv(csvUrl, {
    columnConfigs: {
      medv: {
        isLabel: true
      }
    }
  });

  // The number of features is the number of column names minus one for the
  // label column.
  const numOfFeatures = (await csvDataset.columnNames()).length - 1;

  // Prepare the Dataset for training.
  const flattenedDataset = csvDataset
    .map(({ xs, ys }) => {
      // Convert xs (features) and ys (labels) from object form (keyed by
      // column name) to array form.
      return { xs: Object.values(xs), ys: Object.values(ys) };
    })
    .batch(10);

  // Define the model: a single dense unit, i.e. linear regression.
  const model = tf.sequential();
  model.add(tf.layers.dense({
    inputShape: [numOfFeatures],
    units: 1
  }));
  model.compile({
    optimizer: tf.train.sgd(0.000001),
    loss: 'meanSquaredError'
  });

  // Fit the model using the prepared Dataset, logging the loss per epoch.
  return model.fitDataset(flattenedDataset, {
    epochs: 30,
    callbacks: {
      onEpochEnd: async (epoch, logs) => {
        console.log(epoch + ':' + logs.loss);
      }
    }
  });
}

// Top-level await requires this file to be loaded as an ES module.
await run();
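
The `.map(...)` step in `run()` turns each CSV row from objects keyed by column name into plain arrays. The sketch below illustrates that conversion in isolation, without TensorFlow.js. The `sampleRow` values and the `flattenRow` helper name are hypothetical, chosen only for illustration; the row shape is assumed to match what `tf.data.csv` yields when one column is marked with `isLabel: true`.

```javascript
// Hypothetical row: features under `xs`, the label column under `ys`.
// The numbers are made up for illustration and are not read from the CSV.
const sampleRow = {
  xs: { crim: 0.00632, zn: 18, indus: 2.31 },
  ys: { medv: 24 },
};

// Same transformation as the `.map(...)` callback in the gist:
// drop the column names and keep the values in column order.
function flattenRow({ xs, ys }) {
  return { xs: Object.values(xs), ys: Object.values(ys) };
}

const flat = flattenRow(sampleRow);
console.log(flat.xs); // [ 0.00632, 18, 2.31 ]
console.log(flat.ys); // [ 24 ]
```

`Object.values` preserves the insertion order of string keys, so the feature order stays consistent from row to row as long as each row object is built with the same column order.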