// ai-creature / agent_sac.js
/**
 * Soft Actor-Critic agent (https://arxiv.org/abs/1812.05905),
 * implemented without a separate value network.
 */
const AgentSac = (() => {
/**
* Validates the shape of a given tensor.
*
* @param {Tensor} tensor - tensor whose shape must be validated
* @param {array} shape - shape to compare with
* @param {string} [msg = ''] - message for the error
*/
const assertShape = (tensor, shape, msg = '') => {
console.assert(
JSON.stringify(tensor.shape) === JSON.stringify(shape),
msg + ' shape ' + tensor.shape + ' is not ' + shape)
}
// const VERSION = 1 // +100 for bump tower
// const VERSION = 2 // balls
// const VERSION = 3 // tests
// const VERSION = 4 // tests
// const VERSION = 5 // exp #1
// const VERSION = 6 // exp #2
// const VERSION = 7 // exp #3
// const VERSION = 8 // exp #4
// const VERSION = 9 // exp #
// const VERSION = 10 // exp # good, doesn't touch
// const VERSION = 11 // exp #
// const VERSION = 12 // exp # 25x25
// const VERSION = 13 // exp # 25x25 single CNN
// const VERSION = 15 // 15.1 stable RB 10^5
// const VERSION = 16 // reward from RL2, rb 10^6, gr/red balls, bad
// const VERSION = 18 // reward from RL2, CNN from SAC paper, works!
// const VERSION = 19 // moving balls, super!
// const VERSION = 20 // moving balls, discrete impulse, bad
// const VERSION = 21 // independent look
// const VERSION = 22 // dqn arch, bad
// const VERSION = 23 // dqn trunc, works! fast learn
// const VERSION = 24 // dqn trunc 3 layers, super and fast
// const VERSION = 25 // dqn trunc 3 layers 2x512, poor
// const VERSION = 26 // rl2 cnn arch, bad, too many weights
// const VERSION = 27 // sac cnn 16x6x3->16x4x2->8x3x1->2x256 and 2 clr frames, 2h, kiss, Excellent!
// const VERSION = 28 // same but 1 frame, works
// const VERSION = 29 // 1fr w/o accel, poor
// const VERSION = 30 // 2fr wide img, poor
// const VERSION = 31 // 2 small imgs, small cnn out, poor
// const VERSION = 32 // 2fr binocular
// const VERSION = 33 // 4fr binocular, Good, but poor after reload on wider cage
// const VERSION = 34 // 4fr binocular, smaller fov=2, angle 0.7, poor
// const VERSION = 35 // 4fr binocular with dist, poor
// const VERSION = 36 // 4fr binocular with dist, works but reload does not
// const VERSION = 37 // BCNN achiasma, good -> reload poor
// const VERSION = 38 // BCNN achiasma, smaller cnn
// const VERSION = 39 // 1fr BCNN achiasma, smaller cnn, works super fast, 30min
// const VERSION = 40 // 2fr BCNN achiasma, 2l smaller cnn, poor
// const VERSION = 41 // 2fr BCNN achiasma, 2l smaller cnn, some performance after 30min
// const VERSION = 41 // 1fr BCNN achiasma, 2l smaller cnn, super kiss, reload poor
// const VERSION = 42 // 2fr BCNN achiasma, 2l smaller cnn, reload poor
// const VERSION = 43 // 1fr BCNN achiasma, 3l, fov 0.8, 1h good, reload not bad
// const VERSION = 44 // 2fr BCNN achiasma, 3l, fov 0.8, slow 1h, reload not bad, a bit better than 1fr, degrade
// const VERSION = 45 // 1fr BCNN achiasma, 2l, fov 0.8, poor
// const VERSION = 46 // 2fr BCNN achiasma, 2l, fov 0.8, fast 30 min but poor on reload
// const VERSION = 47 // 1fr BCNN chiasma, 2l, fov 0.7, poor
// const VERSION = 48 // 2fr BCNN chiasma, 2l, fov 0.7 poor
// const VERSION = 49 // 1fr BCNN chiasma stacked, 3l, poor
// const VERSION = 50 // 2fr 2nets monocular, 1h good, reload poor
// const VERSION = 51 // 1fr 1nets monocular, stuck
// const VERSION = 52 // 2fr 2nets monocular, poor
// const VERSION = 53 // 2fr 2nets monocular,
// const VERSION = 54 // 2fr binocular
// const VERSION = 55 // 2fr binocular
// const VERSION = 56 // 2fr binocular
// const VERSION = 57 // 1fr binocular, sphere vimeo super
// const VERSION = 58 // 2fr binocular, sphere
// const VERSION = 59 // 1fr binocular, sphere
// const VERSION = 61 // 2fr binocular, sphere, 2lay BASELINE!!! cage 55, mass 2, ball mass 1
// const VERSION = 62
// const VERSION = 63 // 1fr 30min! cage 60
// const VERSION = 64 // 2fr nores
// const VERSION = 66 // 1fr 30min slightly slower
// const VERSION = 67 // 2fr 30min as prev
// const VERSION = 65 // 1fr l/r diff, 30min +400
// const VERSION = 68 // 1fr l/r diff, 30min -100 good
// const VERSION = 69 // 1fr l/r diff, 30min -190 good
// const VERSION = 70 // 1fr l/r diff, 30min -420
// const VERSION = 71 // 1fr l/r diff, 30min -480
// const VERSION = 72 // 1fr no diff, 30min
// const VERSION = 73 // 1fr no diff, 30min -400 cage 50
// const VERSION = 74 // 1fr diff, 30min 2.6k!
// const VERSION = 75 // 1fr diff, 30min -300
// const VERSION = 76 // 1fr diff, 20min +300!
// const VERSION = 77 // 1fr diff, 20min +3.5k!
// const VERSION = 78 // 1fr diff, 30min -90
// const VERSION = 79 // 1fr NO diff, 25min +158
// const VERSION = 80 // 1fr NO diff, 30min -200
// const VERSION = 81 // 1fr NO diff, 20min +1200
// const VERSION = 82 // 1fr NO diff, 30min
// const VERSION = 83 // 1fr NO diff, priority 30min -400
const VERSION = 84 // 1fr diff, 30min
const LOG_STD_MIN = -20
const LOG_STD_MAX = 2
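// Bounds for clipping the actor's predicted log σ (so σ stays within [e^-20, e^2]), as in the rlkit reference linked in sampleAction().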
const EPSILON = 1e-8
const NAME = {
ACTOR: 'actor',
Q1: 'q1',
Q2: 'q2',
Q1_TARGET: 'q1-target',
Q2_TARGET: 'q2-target',
ALPHA: 'alpha'
}
return class AgentSac {
constructor({
batchSize = 1,
frameShape = [25, 25, 3],
nFrames = 1, // Number of stacked frames per state
nActions = 3, // 3 - impulse, 3 - RGB color
nTelemetry = 10, // 3 - linear velocity, 3 - acceleration, 3 - collision point, 1 - lidar (tanh of distance)
gamma = 0.99, // Discount factor (γ)
tau = 5e-3, // Target smoothing coefficient (τ)
trainable = true, // Whether the actor is trainable
verbose = false,
forced = false, // force creation of fresh models (ignore checkpoints)
prefix = '', // prefix for tests
sighted = true,
rewardScale = 10
} = {}) {
this._batchSize = batchSize
this._frameShape = frameShape
this._nFrames = nFrames
this._nActions = nActions
this._nTelemetry = nTelemetry
this._gamma = gamma
this._tau = tau
this._trainable = trainable
this._verbose = verbose
this._inited = false
this._prefix = (prefix === '' ? '' : prefix + '-')
this._forced = forced
this._sighted = sighted
this._rewardScale = rewardScale
this._frameStackShape = [...this._frameShape.slice(0, 2), this._frameShape[2] * this._nFrames]
// https://github.com/rail-berkeley/softlearning/blob/13cf187cc93d90f7c217ea2845067491c3c65464/softlearning/algorithms/sac.py#L37
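// Heuristic from the reference implementation: target entropy = -dim(action space) = -nActions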
this._targetEntropy = -nActions
}
/**
* Initialization.
*/
async init() {
if (this._inited) throw new Error('Already initialized щ(゚Д゚щ)')
this._frameInputL = tf.input({batchShape : [null, ...this._frameStackShape]})
this._frameInputR = tf.input({batchShape : [null, ...this._frameStackShape]})
this._telemetryInput = tf.input({batchShape : [null, this._nTelemetry]})
this.actor = await this._getActor(this._prefix + NAME.ACTOR, this._trainable)
if (!this._trainable)
return
this.actorOptimizer = tf.train.adam()
this._actionInput = tf.input({batchShape : [null, this._nActions]})
this.q1 = await this._getCritic(this._prefix + NAME.Q1)
this.q1Optimizer = tf.train.adam()
this.q2 = await this._getCritic(this._prefix + NAME.Q2)
this.q2Optimizer = tf.train.adam()
this.q1Targ = await this._getCritic(this._prefix + NAME.Q1_TARGET, true) // true for batch norm
this.q2Targ = await this._getCritic(this._prefix + NAME.Q2_TARGET, true)
this._logAlpha = await this._getLogAlpha(this._prefix + NAME.ALPHA)
this.alphaOptimizer = tf.train.adam()
this.updateTargets(1)
// console.log('weights actorr', this.actor.getWeights().map(w => w.arraySync()))
// console.log('weights q1q1q1', this.q1.getWeights().map(w => w.arraySync()))
// console.log('weights q2Targ', this.q2Targ.getWeights().map(w => w.arraySync()))
this._inited = true
}
/**
* Trains networks on a batch from the replay buffer.
*
* @param {{ state, action, reward, nextState }} transition - batch of transitions
* @returns {void}
*/
train({ state, action, reward, nextState }) {
if (!this._trainable)
throw new Error('Actor is not trainable')
return tf.tidy(() => {
assertShape(state[0], [this._batchSize, this._nTelemetry], 'telemetry')
assertShape(state[1], [this._batchSize, ...this._frameStackShape], 'frames')
assertShape(action, [this._batchSize, this._nActions], 'action')
assertShape(reward, [this._batchSize, 1], 'reward')
assertShape(nextState[0], [this._batchSize, this._nTelemetry], 'nextState telemetry')
assertShape(nextState[1], [this._batchSize, ...this._frameStackShape], 'nextState frames')
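// One full SAC update: a gradient step on both critics, the actor, and the temperature α, then a Polyak update of the targets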
this._trainCritics({ state, action, reward, nextState })
this._trainActor(state)
this._trainAlpha(state)
this.updateTargets()
})
}
/**
* Train Q-networks.
*
* @param {{ state, action, reward, nextState }} transition - transition
*/
_trainCritics({ state, action, reward, nextState }) {
const getQLossFunction = (() => {
const [nextFreshAction, logPi] = this.sampleAction(nextState, true)
const q1TargValue = this.q1Targ.predict(
this._sighted ? [...nextState, nextFreshAction] : [nextState[0], nextFreshAction],
{batchSize: this._batchSize})
const q2TargValue = this.q2Targ.predict(
this._sighted ? [...nextState, nextFreshAction] : [nextState[0], nextFreshAction],
{batchSize: this._batchSize})
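// Clipped double-Q: take the element-wise minimum of the two target critics to reduce overestimation bias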
const qTargValue = tf.minimum(q1TargValue, q2TargValue)
// y = scale*r + γ*(min(Q1Targ(s', a'), Q2Targ(s', a')) - α*log π(a'|s')); no terminal (1 - d) mask is applied here
const alpha = this._getAlpha()
const target = reward.mul(tf.scalar(this._rewardScale)).add(
tf.scalar(this._gamma).mul(
qTargValue.sub(alpha.mul(logPi))
)
)
assertShape(nextFreshAction, [this._batchSize, this._nActions], 'nextFreshAction')
assertShape(logPi, [this._batchSize, 1], 'logPi')
assertShape(qTargValue, [this._batchSize, 1], 'qTargValue')
assertShape(target, [this._batchSize, 1], 'target')
return (q) => () => {
const qValue = q.predict(
this._sighted ? [...state, action] : [state[0], action],
{batchSize: this._batchSize})
// const loss = tf.scalar(0.5).mul(tf.losses.meanSquaredError(qValue, target))
const loss = tf.scalar(0.5).mul(tf.mean(qValue.sub(target).square()))
assertShape(qValue, [this._batchSize, 1], 'qValue')
return loss
}
})()
for (const [q, optimizer] of [
[this.q1, this.q1Optimizer],
[this.q2, this.q2Optimizer]
]) {
const qLossFunction = getQLossFunction(q)
const { value, grads } = tf.variableGrads(qLossFunction, q.getWeights(true)) // true means trainableOnly
optimizer.applyGradients(grads)
if (this._verbose) console.log(q.name + ' Loss: ' + value.arraySync())
}
}
/**
* Train actor networks.
*
* @param {state} state
*/
_trainActor(state) {
// TODO: consider delayed update of policy and targets (if possible)
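// Actor loss: E[α*log π(a|s) - min(Q1, Q2)(s, a)], with a sampled via the reparameterization trick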
const actorLossFunction = () => {
const [freshAction, logPi] = this.sampleAction(state, true)
const q1Value = this.q1.predict(
this._sighted ? [...state, freshAction] : [state[0], freshAction],
{batchSize: this._batchSize})
const q2Value = this.q2.predict(
this._sighted ? [...state, freshAction] : [state[0], freshAction],
{batchSize: this._batchSize})
const criticValue = tf.minimum(q1Value, q2Value)
const alpha = this._getAlpha()
const loss = alpha.mul(logPi).sub(criticValue)
assertShape(freshAction, [this._batchSize, this._nActions], 'freshAction')
assertShape(logPi, [this._batchSize, 1], 'logPi')
assertShape(q1Value, [this._batchSize, 1], 'q1Value')
assertShape(criticValue, [this._batchSize, 1], 'criticValue')
assertShape(loss, [this._batchSize, 1], 'actor loss')
return tf.mean(loss)
}
const { value, grads } = tf.variableGrads(actorLossFunction, this.actor.getWeights(true)) // true means trainableOnly
this.actorOptimizer.applyGradients(grads)
if (this._verbose) console.log('Actor Loss: ' + value.arraySync())
}
_trainAlpha(state) {
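// Temperature loss from the SAC applications paper: J(α) = E[-α * (log π(a|s) + targetEntropy)]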
const alphaLossFunction = () => {
const [, logPi] = this.sampleAction(state, true)
const alpha = this._getAlpha()
const loss = tf.scalar(-1).mul(
alpha.mul( // TODO: not sure whether this should be alpha or logAlpha
logPi.add(tf.scalar(this._targetEntropy))
)
)
assertShape(loss, [this._batchSize, 1], 'alpha loss')
return tf.mean(loss)
}
const { value, grads } = tf.variableGrads(alphaLossFunction, [this._logAlpha])
this.alphaOptimizer.applyGradients(grads)
if (this._verbose) console.log('Alpha Loss: ' + value.arraySync(), tf.exp(this._logAlpha).arraySync())
}
/**
* Soft update target Q-networks.
*
* @param {number} [tau = this._tau] - smoothing constant τ for the exponential moving average: `wTarg <- wTarg*(1-tau) + w*tau`
*/
updateTargets(tau = this._tau) {
tau = tf.scalar(tau)
const
q1W = this.q1.getWeights(),
q2W = this.q2.getWeights(),
q1WTarg = this.q1Targ.getWeights(),
q2WTarg = this.q2Targ.getWeights(),
len = q1W.length
// console.log('updateTargets q1W', q1W.map(w=>w.arraySync()))
// console.log('updateTargets q1WTarg', q1WTarg.map(w=>w.arraySync()))
const calc = (w, wTarg) => wTarg.mul(tf.scalar(1).sub(tau)).add(w.mul(tau))
const w1 = [], w2 = []
for (let i = 0; i < len; i++) {
w1.push(calc(q1W[i], q1WTarg[i]))
w2.push(calc(q2W[i], q2WTarg[i]))
}
this.q1Targ.setWeights(w1)
this.q2Targ.setWeights(w2)
}
/**
* Returns actions sampled from a normal distribution parameterized by the actor's
* predicted means and log stds, squashed to [-1, 1] with tanh.
*
* @param {Tensor[]} state - state
* @param {boolean} [withLogProbs = false] - whether to also return log probabilities
* @returns {Tensor | Tensor[]} action, or [action, logPi] if withLogProbs is true
*/
sampleAction(state, withLogProbs = false) { // timer ~3ms
return tf.tidy(() => {
let [ mu, logStd ] = this.actor.predict(this._sighted ? state : state[0], {batchSize: this._batchSize})
// https://github.com/rail-berkeley/rlkit/blob/c81509d982b4d52a6239e7bfe7d2540e3d3cd986/rlkit/torch/sac/policies/gaussian_policy.py#L106
logStd = tf.clipByValue(logStd, LOG_STD_MIN, LOG_STD_MAX)
const std = tf.exp(logStd)
// sample normal N(mu = 0, std = 1)
const normal = tf.randomNormal(mu.shape, 0, 1.0)
// reparameterization trick: z = mu + std * epsilon
let pi = mu.add(std.mul(normal))
let logPi = this._gaussianLikelihood(pi, mu, logStd)
;({ pi, logPi } = this._applySquashing(pi, mu, logPi))
if (!withLogProbs)
return pi
return [pi, logPi]
})
}
/**
* Calculates log probability of normal distribution https://en.wikipedia.org/wiki/Log_probability.
* Converted to js from https://github.com/tensorflow/probability/blob/f3777158691787d3658b5e80883fe1a933d48989/tensorflow_probability/python/distributions/normal.py#L183
*
* @param {Tensor} x - sample from normal distribution with mean `mu` and std `std`
* @param {Tensor} mu - mean
* @param {Tensor} std - standard deviation
* @returns {Tensor} log probability
*/
_logProb(x, mu, std) {
const logUnnormalized = tf.scalar(-0.5).mul(
tf.squaredDifference(x.div(std), mu.div(std))
)
const logNormalization = tf.scalar(0.5 * Math.log(2 * Math.PI)).add(tf.log(std))
return logUnnormalized.sub(logNormalization)
}
/**
* Gaussian likelihood.
* Translated from https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/spinup/algos/tf1/sac/core.py#L24
*
* @param {Tensor} x - sample from normal distribution with mean `mu` and std `exp(logStd)`
* @param {Tensor} mu - mean
* @param {Tensor} logStd - log of standard deviation
* @returns {Tensor} log probability
*/
_gaussianLikelihood(x, mu, logStd) {
// pre_sum = -0.5 * (
// ((x-mu)/(tf.exp(log_std)+EPS))**2
// + 2*log_std
// + np.log(2*np.pi)
// )
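// Equivalently: log N(x; μ, σ) = -½*((x - μ)/σ)² - log σ - ½*log(2π), summed over action dimensions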
const preSum = tf.scalar(-0.5).mul(
x.sub(mu).div(
tf.exp(logStd).add(tf.scalar(EPSILON))
).square()
.add(tf.scalar(2).mul(logStd))
.add(tf.scalar(Math.log(2 * Math.PI)))
)
return tf.sum(preSum, 1, true)
}
/**
* Adjustment to log probability when squashing action with tanh
* Enforcing Action Bounds formula derivation https://stats.stackexchange.com/questions/239588/derivation-of-change-of-variables-of-a-probability-density-function
* Translated from https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/spinup/algos/tf1/sac/core.py#L48
*
* @param {Tensor} pi - policy sample
* @param {Tensor} mu - mean
* @param {Tensor} logPi - log probability
* @returns {{ pi, mu, logPi }} squashed and adjusted input
*/
_applySquashing(pi, mu, logPi) {
// logp_pi -= tf.reduce_sum(2*(np.log(2) - pi - tf.nn.softplus(-2*pi)), axis=1)
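// Change of variables for a = tanh(u): log π(a) = log μ(u) - Σ log(1 - tanh²(u)),
// computed here with the numerically stable identity log(1 - tanh²(u)) = 2*(log 2 - u - softplus(-2u))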
const adj = tf.scalar(2).mul(
tf.scalar(Math.log(2))
.sub(pi)
.sub(tf.softplus(
tf.scalar(-2).mul(pi)
))
)
logPi = logPi.sub(tf.sum(adj, 1, true))
mu = tf.tanh(mu)
pi = tf.tanh(pi)
return { pi, mu, logPi }
}
/**
* Builds actor network model.
*
* @param {string} [name = 'actor'] - name of the model
* @param {boolean} trainable - whether the actor is trainable
* @returns {tf.LayersModel} model
*/
async _getActor(name = 'actor', trainable = true) {
const checkpoint = await this._loadCheckpoint(name)
if (checkpoint) return checkpoint
let outputs = this._telemetryInput
// outputs = tf.layers.dense({units: 128, activation: 'relu'}).apply(outputs)
if (this._sighted) {
let convOutputL = this._getConvEncoder(this._frameInputL)
let convOutputR = this._getConvEncoder(this._frameInputR)
// let convOutput = tf.layers.concatenate().apply([convOutputL, convOutputR])
// convOutput = tf.layers.dense({units: 10, activation: 'relu'}).apply(convOutput)
outputs = tf.layers.concatenate().apply([convOutputL, convOutputR, outputs])
}
outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs)
outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs)
const mu = tf.layers.dense({units: this._nActions}).apply(outputs)
const logStd = tf.layers.dense({units: this._nActions}).apply(outputs)
const model = tf.model({inputs: this._sighted ? [this._telemetryInput, this._frameInputL, this._frameInputR] : [this._telemetryInput], outputs: [mu, logStd], name})
model.trainable = trainable
if (this._verbose) {
console.log('==========================')
console.log('==========================')
console.log('Actor ' + name + ': ')
model.summary()
}
return model
}
/**
* Builds a critic network model.
*
* @param {string} [name = 'critic'] - name of the model
* @param {boolean} trainable - whether the critic is trainable
* @returns {tf.LayersModel} model
*/
async _getCritic(name = 'critic', trainable = true) {
const checkpoint = await this._loadCheckpoint(name)
if (checkpoint) return checkpoint
let outputs = tf.layers.concatenate().apply([this._telemetryInput, this._actionInput])
// outputs = tf.layers.dense({units: 128, activation: 'relu'}).apply(outputs)
if (this._sighted) {
let convOutputL = this._getConvEncoder(this._frameInputL)
let convOutputR = this._getConvEncoder(this._frameInputR)
// let convOutput = tf.layers.concatenate().apply([convOutputL, convOutputR])
// convOutput = tf.layers.dense({units: 10, activation: 'relu'}).apply(convOutput)
outputs = tf.layers.concatenate().apply([convOutputL, convOutputR, outputs])
}
outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs)
outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs)
outputs = tf.layers.dense({units: 1}).apply(outputs)
const model = tf.model({
inputs: this._sighted
? [this._telemetryInput, this._frameInputL, this._frameInputR, this._actionInput]
: [this._telemetryInput, this._actionInput],
outputs, name
})
model.trainable = trainable
if (this._verbose) {
console.log('==========================')
console.log('==========================')
console.log('CRITIC ' + name + ': ')
model.summary()
}
return model
}
// _encoder = null
// _getConvEncoder(inputs) {
// if (!this._encoder)
// this._encoder = this.__getConvEncoder(inputs)
// return this._encoder
// }
/**
* Builds convolutional part of a network.
*
* @param {Tensor} inputs - input for the conv layers
* @returns outputs
*/
_getConvEncoder(inputs) {
const kernelSize = 3
const padding = 'valid'
const poolSize = 3
const strides = 1
// const depthwiseInitializer = 'heNormal'
// const pointwiseInitializer = 'heNormal'
const kernelInitializer = 'glorotNormal'
const biasInitializer = 'glorotNormal'
let outputs = inputs
// For 25x25 inputs: conv 16@5x5/2 -> 11x11x16 -> maxpool 2 -> 5x5x16 -> conv 16@3x3/1 -> 3x3x16 -> maxpool 2 -> 1x1x16 -> flatten -> 16
outputs = tf.layers.conv2d({
filters: 16,
kernelSize: 5,
strides: 2,
padding,
kernelInitializer,
biasInitializer,
activation: 'relu',
trainable: true
}).apply(outputs)
outputs = tf.layers.maxPooling2d({poolSize:2}).apply(outputs)
//
// outputs = tf.layers.layerNormalization().apply(outputs)
outputs = tf.layers.conv2d({
filters: 16,
kernelSize: 3,
strides: 1,
padding,
kernelInitializer,
biasInitializer,
activation: 'relu',
trainable: true
}).apply(outputs)
outputs = tf.layers.maxPooling2d({poolSize:2}).apply(outputs)
// outputs = tf.layers.layerNormalization().apply(outputs)
// outputs = tf.layers.conv2d({
// filters: 12,
// kernelSize: 3,
// strides: 1,
// padding,
// kernelInitializer,
// biasInitializer,
// activation: 'relu',
// trainable: true
// }).apply(outputs)
// outputs = tf.layers.conv2d({
// filters: 10,
// kernelSize: 2,
// strides: 1,
// padding,
// kernelInitializer,
// biasInitializer,
// activation: 'relu',
// trainable: true
// }).apply(outputs)
// outputs = tf.layers.conv2d({
// filters: 64,
// kernelSize: 4,
// strides: 1,
// padding,
// kernelInitializer,
// biasInitializer,
// activation: 'relu'
// }).apply(outputs)
// outputs = tf.layers.batchNormalization().apply(outputs)
// outputs = tf.layers.layerNormalization().apply(outputs)
outputs = tf.layers.flatten().apply(outputs)
// convOutputs = tf.layers.dense({units: 96, activation: 'relu'}).apply(convOutputs)
return outputs
}
/**
* Returns the entropy scale α = exp(log α); clipping to a minimum is currently disabled.
*
* @returns {Tensor} entropy scale α
*/
_getAlpha() {
// return tf.maximum(tf.exp(this._logAlpha), tf.scalar(this._minAlpha))
return tf.exp(this._logAlpha)
}
/**
* Builds a log of entropy scale (α) for training.
*
* @param {string} name
* @returns {tf.Variable} trainable variable for log entropy
*/
async _getLogAlpha(name = 'alpha') {
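// log α itself is a tf.variable (so it can receive gradients); the 1x1 dense "placeholder"
// model below only exists so the value can be saved/loaded like the other checkpoints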
let logAlpha = 0.0
const checkpoint = await this._loadCheckpoint(name)
if (checkpoint) {
logAlpha = checkpoint.getWeights()[0].arraySync()[0][0]
if (this._verbose)
console.log('Checkpoint alpha: ', logAlpha)
this._logAlphaPlaceholder = checkpoint
} else {
const model = tf.sequential({ name });
model.add(tf.layers.dense({ units: 1, inputShape: [1], useBias: false }))
model.setWeights([tf.tensor([logAlpha], [1, 1])])
this._logAlphaPlaceholder = model
}
return tf.variable(tf.scalar(logAlpha), true) // true -> trainable
}
/**
* Saves all agent's models to the storage.
*/
async checkpoint() {
if (!this._trainable) throw new Error('Not trainable (╭ರ_ ⊙ )')
this._logAlphaPlaceholder.setWeights([tf.tensor([this._logAlpha.arraySync()], [1, 1])])
await Promise.all([
this._saveCheckpoint(this.actor),
this._saveCheckpoint(this.q1),
this._saveCheckpoint(this.q2),
this._saveCheckpoint(this.q1Targ),
this._saveCheckpoint(this.q2Targ),
this._saveCheckpoint(this._logAlphaPlaceholder)
])
if (this._verbose)
console.log('Checkpoint successfully saved')
}
/**
* Saves a model to the storage.
*
* @param {tf.LayersModel} model
*/
async _saveCheckpoint(model) {
const key = this._getChKey(model.name)
const saveResults = await model.save(key)
if (this._verbose)
console.log('Checkpoint saveResults', model.name, saveResults)
}
/**
* Loads saved checkpoint from the storage.
*
* @param {string} name model name
* @returns {tf.LayersModel} model
*/
async _loadCheckpoint(name) {
// return
if (this._forced) {
console.log('Forced to not load from the checkpoint ' + name)
return
}
const key = this._getChKey(name)
const modelsInfo = await tf.io.listModels()
if (key in modelsInfo) {
const model = await tf.loadLayersModel(key)
if (this._verbose)
console.log('Loaded checkpoint for ' + name)
return model
}
if (this._verbose)
console.log('Checkpoint not found for ' + name)
}
/**
* Builds the storage key for the model weights in IndexedDB.
*
* @param {string} name - model name
* @returns {string} key
*/
_getChKey(name) {
return 'indexeddb://' + name + '-' + VERSION
}
}
})()
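// Usage sketch (a minimal example, assuming a replay buffer that yields tensors with the
// shapes asserted in train(); `sampleBatch` is a hypothetical helper, not part of this file):
//
// const agent = new AgentSac({ batchSize: 64 })
// await agent.init()
// const { state, action, reward, nextState } = sampleBatch() // e.g. from a replay buffer
// agent.train({ state, action, reward, nextState }) // one critic/actor/alpha step + target update
// const nextAction = agent.sampleAction(state) // tanh-squashed actions in [-1, 1]
// await agent.checkpoint() // persists all models under indexeddb://<name>-<VERSION>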
/* TESTS */
;(async () => {
return // remove this line to run the tests below
// https://www.wolframalpha.com/input/?i2d=true&i=y%5C%2840%29x%5C%2844%29+%CE%BC%5C%2844%29+%CF%83%5C%2841%29+%3D+ln%5C%2840%29Divide%5B1%2CSqrt%5B2*%CF%80*Power%5B%CF%83%2C2%5D%5D%5D*Exp%5B-Divide%5B1%2C2%5D*%5C%2840%29Divide%5BPower%5B%5C%2840%29x-%CE%BC%5C%2841%29%2C2%5D%2CPower%5B%CF%83%2C2%5D%5D%5C%2841%29%5D%5C%2841%29
;(() => {
const agent = new AgentSac()
const
mu = tf.tensor([0], [1, 1]), // mu = 0
logStd = tf.tensor([0], [1, 1]), // logStd = 0
std = tf.exp(logStd), // std = 1
normal = tf.tensor([0], [1, 1]), // N = 0
pi = mu.add(std.mul(normal)) // x = 0
const log = agent._gaussianLikelihood(pi, mu, logStd)
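// Closed form check: log N(x = μ; μ, σ) = -log σ - ½*log(2π) ≈ -0.91894 for σ = 1 (and -1.91894 for σ = e in the next test)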
console.assert(log.arraySync()[0][0].toFixed(5) === '-0.91894',
'test Gaussian Likelihood for μ=0, σ=1, x=0')
})()
;(() => {
const agent = new AgentSac()
const
mu = tf.tensor([1], [1, 1]), // mu = 1
logStd = tf.tensor([1], [1, 1]), // logStd = 1
std = tf.exp(logStd), // std = e
normal = tf.tensor([0], [1, 1]), // N = 0
pi = mu.add(std.mul(normal)) // x = 1
const log = agent._gaussianLikelihood(pi, mu, logStd)
console.assert(log.arraySync()[0][0].toFixed(5) === '-1.91894',
'test Gaussian Likelihood for μ=1, σ=e, x=1')
})()
;(() => {
const agent = new AgentSac()
const
mu = tf.tensor([1], [1, 1]), // mu = 1
logStd = tf.tensor([1], [1, 1]), // logStd = 1
std = tf.exp(logStd), // std = e
normal = tf.tensor([0.1], [1, 1]), // N = 0.1
pi = mu.add(std.mul(normal)) // x = 1 + 0.1e ≈ 1.27182818
const logPi = agent._gaussianLikelihood(pi, mu, logStd)
const { pi: piSquashed, logPi: logPiSquashed } = agent._applySquashing(pi, mu, logPi)
const logProbBounded = logPi.sub(
tf.log(
tf.scalar(1)
.sub(tf.tanh(pi).pow(tf.scalar(2)))
// .add(EPSILON)
)
).sum(1, true)
console.assert(logPi.arraySync()[0][0].toFixed(5) === '-1.92394',
'test Gaussian Likelihood for μ=1, σ=e, x≈1.27183')
console.assert(logPiSquashed.arraySync()[0][0].toFixed(5) === logProbBounded.arraySync()[0][0].toFixed(5),
'test logPiSquashed for μ=1, σ=e, x≈1.27183')
console.assert(piSquashed.arraySync()[0][0].toFixed(5) === tf.tanh(pi).arraySync()[0][0].toFixed(5),
'test piSquashed for μ=1, σ=e, x≈1.27183')
})()
await (async () => {
const state = tf.tensor([
0.5, 0.3, -0.9,
0, -0.8, 1,
-0.3, 0.04, 0.02,
0.9
], [1, 10])
const action = tf.tensor([
0.1, -1, -0.4
], [1, 3]) // matches the default nActions = 3
const fresh = new AgentSac({ prefix: 'test', forced: true, sighted: false }) // sighted: false since the test state has no frames
await fresh.init()
await fresh.checkpoint()
const saved = new AgentSac({ prefix: 'test', sighted: false })
await saved.init()
let frPred, saPred
frPred = fresh.actor.predict(state, {batchSize: 1})
saPred = saved.actor.predict(state, {batchSize: 1})
console.assert(
frPred[0].arraySync().length > 0 &&
frPred[1].arraySync().length > 0 &&
frPred[0].arraySync().join(';') === saPred[0].arraySync().join(';') &&
frPred[1].arraySync().join(';') === saPred[1].arraySync().join(';'),
'Models loaded from the checkpoint should be the same')
frPred = fresh.q1.predict([state, action], {batchSize: 1})
saPred = fresh.q1Targ.predict([state, action], {batchSize: 1})
console.assert(
frPred.arraySync()[0][0] !== undefined &&
frPred.arraySync()[0][0] === saPred.arraySync()[0][0],
'Q1 and Q1-target should be the same')
frPred = fresh.q2.predict([state, action], {batchSize: 1})
saPred = saved.q2.predict([state, action], {batchSize: 1})
console.assert(
frPred.arraySync()[0][0] !== undefined &&
frPred.arraySync()[0][0] === saPred.arraySync()[0][0],
'Q and Q restored should be the same')
console.assert(
fresh._logAlpha.arraySync() !== undefined &&
fresh._logAlpha.arraySync() === saved._logAlpha.arraySync(),
'logAlpha and restored logAlpha should be the same')
})()
})()