Created
March 28, 2021 18:06
-
-
Save florindumitru/c3936fc062d0f4d6f9b79a634037cfc7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import * as tf from '@tensorflow/tfjs-node-gpu'; | |
/**
 * Location of the training CSV (Boston Housing training data, per the file
 * name). Its `medv` column is treated as the label by `run()`.
 */
const csvUrl = 'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';
/**
 * Trains a single-unit dense (linear) model on a CSV dataset streamed with
 * `tf.data.csv`, logging the training loss after every epoch.
 *
 * Every option is optional. Calling `run()` with no arguments reproduces the
 * original hard-coded setup: the `csvUrl` dataset, `medv` as the label,
 * batches of 10, 30 epochs, and SGD with learning rate 0.000001.
 *
 * @param {object} [options]
 * @param {string} [options.url=csvUrl] - CSV location passed to `tf.data.csv`.
 * @param {string} [options.labelColumn='medv'] - Column marked as the label.
 *   The default `medv` is the median value of a home (in $1000s).
 * @param {number} [options.batchSize=10] - Rows per training batch.
 * @param {number} [options.epochs=30] - Number of passes over the dataset.
 * @param {number} [options.learningRate=0.000001] - SGD learning rate.
 * @returns {Promise<tf.History>} Resolves with the result of `model.fitDataset`.
 */
async function run({
  url = csvUrl,
  labelColumn = 'medv',
  batchSize = 10,
  epochs = 30,
  learningRate = 0.000001,
} = {}) {
  // Marking one column as the label makes each dataset element a
  // `{ xs, ys }` pair of column-name-keyed objects.
  const csvDataset = tf.data.csv(url, {
    columnConfigs: {
      [labelColumn]: { isLabel: true },
    },
  });

  // Every column except the single label column is a feature.
  const numOfFeatures = (await csvDataset.columnNames()).length - 1;

  // Convert xs (features) and ys (labels) from object form (keyed by column
  // name) to array form, then group rows into batches for training.
  const flattenedDataset = csvDataset
    .map(({ xs, ys }) => ({ xs: Object.values(xs), ys: Object.values(ys) }))
    .batch(batchSize);

  // One dense unit over all features is a plain linear regression.
  const model = tf.sequential();
  model.add(tf.layers.dense({ inputShape: [numOfFeatures], units: 1 }));
  model.compile({
    optimizer: tf.train.sgd(learningRate),
    loss: 'meanSquaredError',
  });

  return model.fitDataset(flattenedDataset, {
    epochs,
    callbacks: {
      // Nothing is awaited here, so a plain (non-async) callback suffices.
      onEpochEnd: (epoch, logs) => {
        console.log(`${epoch}:${logs.loss}`);
      },
    },
  });
}
// Entry point: train once with the default configuration. Top-level `await`
// is only valid in an ES module; this file uses `import`, so it qualifies.
await run();
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment