|
|
|
|
|
|
|
|
|
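/**
 * Soft Actor-Critic (SAC) agent for TensorFlow.js (expects `tf` as a global,
 * e.g. loaded via a script tag). Combines a tanh-squashed Gaussian policy,
 * twin Q critics with Polyak-averaged target networks, and automatic tuning
 * of the entropy temperature α. Checkpoints are persisted to IndexedDB.
 *
 * Typical flow (a sketch; constructor arguments are illustrative):
 *
 *   const agent = new AgentSac({ batchSize: 64 })
 *   await agent.init()
 *   agent.train(batch)                        // one gradient step per replay batch
 *   const action = agent.sampleAction(state)
 */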
const AgentSac = (() => { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
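// Dev-time guard: logs a console assertion when a tensor's shape differs from the expected one.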
const assertShape = (tensor, shape, msg = '') => { |
|
console.assert( |
|
JSON.stringify(tensor.shape) === JSON.stringify(shape), |
|
msg + ' shape ' + tensor.shape + ' is not ' + shape) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
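// Baked into the IndexedDB storage key, so bumping VERSION starts from fresh checkpoints.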
const VERSION = 84 |
|
|
|
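// Clamp bounds for the policy's log-σ head and a small constant for numerical
// stability, as in common SAC implementations.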
const LOG_STD_MIN = -20 |
|
const LOG_STD_MAX = 2 |
|
const EPSILON = 1e-8 |
|
const NAME = { |
|
ACTOR: 'actor', |
|
Q1: 'q1', |
|
Q2: 'q2', |
|
Q1_TARGET: 'q1-target', |
|
Q2_TARGET: 'q2-target', |
|
ALPHA: 'alpha' |
|
} |
|
|
|
return class AgentSac { |
|
constructor({ |
|
batchSize = 1, |
|
frameShape = [25, 25, 3], |
|
nFrames = 1, // Number of stacked frames per state |
|
nActions = 3, // 3 - impulse, 3 - RGB color
|
nTelemetry = 10, // 3 - linear velocity, 3 - acceleration, 3 - collision point, 1 - lidar (tanh of distance)
|
gamma = 0.99, // Discount factor (γ) |
|
tau = 5e-3, // Target smoothing coefficient (τ) |
|
trainable = true, // Whether the actor is trainable |
|
verbose = false, |
|
forced = false, // create fresh models instead of loading a checkpoint
|
prefix = '', // name prefix used by tests
|
sighted = true, |
|
rewardScale = 10 |
|
} = {}) { |
|
this._batchSize = batchSize |
|
this._frameShape = frameShape |
|
this._nFrames = nFrames |
|
this._nActions = nActions |
|
this._nTelemetry = nTelemetry |
|
this._gamma = gamma |
|
this._tau = tau |
|
this._trainable = trainable |
|
this._verbose = verbose |
|
this._inited = false |
|
this._prefix = (prefix === '' ? '' : prefix + '-') |
|
this._forced = forced |
|
this._sighted = sighted |
|
this._rewardScale = rewardScale |
|
|
|
this._frameStackShape = [...this._frameShape.slice(0, 2), this._frameShape[2] * this._nFrames] |
|
|
|
|
|
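// Target entropy heuristic from the SAC paper: H_target = -dim(A).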
this._targetEntropy = -nActions |
|
} |
|
|
|
|
|
|
|
|
|
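/**
 * Builds (or restores from checkpoints) the actor and, when trainable,
 * the twin critics, their target copies, and the temperature variable.
 */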
async init() { |
|
if (this._inited) throw Error('щ(゚Д゚щ)') |
|
|
|
this._frameInputL = tf.input({batchShape : [null, ...this._frameStackShape]}) |
|
this._frameInputR = tf.input({batchShape : [null, ...this._frameStackShape]}) |
|
|
|
this._telemetryInput = tf.input({batchShape : [null, this._nTelemetry]}) |
|
|
|
this.actor = await this._getActor(this._prefix + NAME.ACTOR, this._trainable)
|
|
|
if (!this._trainable) |
|
return |
|
|
|
this.actorOptimizer = tf.train.adam() |
|
|
|
this._actionInput = tf.input({batchShape : [null, this._nActions]}) |
|
|
|
this.q1 = await this._getCritic(this._prefix + NAME.Q1) |
|
this.q1Optimizer = tf.train.adam() |
|
|
|
this.q2 = await this._getCritic(this._prefix + NAME.Q2) |
|
this.q2Optimizer = tf.train.adam() |
|
|
|
this.q1Targ = await this._getCritic(this._prefix + NAME.Q1_TARGET, true) |
|
this.q2Targ = await this._getCritic(this._prefix + NAME.Q2_TARGET, true) |
|
|
|
this._logAlpha = await this._getLogAlpha(this._prefix + NAME.ALPHA) |
|
this.alphaOptimizer = tf.train.adam() |
|
|
|
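// Hard-copy the online critic weights into the targets (τ = 1) before training begins.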
this.updateTargets(1) |
|
|
|
|
|
|
|
|
|
|
|
this._inited = true |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
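/**
 * Runs one SAC update on a replay batch: critics first, then the actor,
 * then the temperature α, followed by a soft target-network update.
 *
 * Usage sketch (shapes must match the assertions below; `replayBuffer`
 * and its `sample()` are hypothetical, not part of this file):
 *
 *   const { state, action, reward, nextState } = replayBuffer.sample()
 *   agent.train({ state, action, reward, nextState })
 */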
train({ state, action, reward, nextState }) { |
|
if (!this._trainable) |
|
throw new Error('Actor is not trainable') |
|
|
|
return tf.tidy(() => { |
|
assertShape(state[0], [this._batchSize, this._nTelemetry], 'telemetry') |
|
assertShape(state[1], [this._batchSize, ...this._frameStackShape], 'frames') |
|
assertShape(action, [this._batchSize, this._nActions], 'action') |
|
assertShape(reward, [this._batchSize, 1], 'reward') |
|
assertShape(nextState[0], [this._batchSize, this._nTelemetry], 'nextState telemetry') |
|
assertShape(nextState[1], [this._batchSize, ...this._frameStackShape], 'nextState frames') |
|
|
|
this._trainCritics({ state, action, reward, nextState }) |
|
this._trainActor(state) |
|
this._trainAlpha(state) |
|
|
|
this.updateTargets() |
|
}) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
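/**
 * Updates both critics toward the clipped double-Q target
 *
 *   y = scale·r + γ·(min(Q1', Q2')(s', ã') − α·log π(ã'|s'))
 *
 * where ã' is a fresh action sampled from the current policy for the next
 * state. Note there is no terminal-state mask in this implementation.
 */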
_trainCritics({ state, action, reward, nextState }) { |
|
const getQLossFunction = (() => { |
|
const [nextFreshAction, logPi] = this.sampleAction(nextState, true) |
|
|
|
const q1TargValue = this.q1Targ.predict( |
|
this._sighted ? [...nextState, nextFreshAction] : [nextState[0], nextFreshAction], |
|
{batchSize: this._batchSize}) |
|
const q2TargValue = this.q2Targ.predict( |
|
this._sighted ? [...nextState, nextFreshAction] : [nextState[0], nextFreshAction], |
|
{batchSize: this._batchSize}) |
|
|
|
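// Clipped double-Q: element-wise minimum of the two target critics curbs overestimation bias.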
const qTargValue = tf.minimum(q1TargValue, q2TargValue) |
|
|
|
|
|
const alpha = this._getAlpha() |
|
const target = reward.mul(tf.scalar(this._rewardScale)).add( |
|
tf.scalar(this._gamma).mul( |
|
qTargValue.sub(alpha.mul(logPi)) |
|
) |
|
) |
|
|
|
assertShape(nextFreshAction, [this._batchSize, this._nActions], 'nextFreshAction') |
|
assertShape(logPi, [this._batchSize, 1], 'logPi') |
|
assertShape(qTargValue, [this._batchSize, 1], 'qTargValue') |
|
assertShape(target, [this._batchSize, 1], 'target') |
|
|
|
return (q) => () => { |
|
const qValue = q.predict( |
|
this._sighted ? [...state, action] : [state[0], action], |
|
{batchSize: this._batchSize}) |
|
|
|
|
|
const loss = tf.scalar(0.5).mul(tf.mean(qValue.sub(target).square())) |
|
|
|
assertShape(qValue, [this._batchSize, 1], 'qValue') |
|
|
|
return loss |
|
} |
|
})() |
|
|
|
for (const [q, optimizer] of [ |
|
[this.q1, this.q1Optimizer], |
|
[this.q2, this.q2Optimizer] |
|
]) { |
|
const qLossFunction = getQLossFunction(q) |
|
|
|
const { value, grads } = tf.variableGrads(qLossFunction, q.getWeights(true)) |
|
|
|
optimizer.applyGradients(grads) |
|
|
|
if (this._verbose) console.log(q.name + ' Loss: ' + value.arraySync()) |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
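/**
 * Updates the policy by minimizing E[α·log π(ã|s) − min(Q1, Q2)(s, ã)],
 * with ã reparameterized so gradients flow through the sampled action.
 */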
_trainActor(state) { |
|
|
|
const actorLossFunction = () => { |
|
const [freshAction, logPi] = this.sampleAction(state, true) |
|
|
|
const q1Value = this.q1.predict( |
|
this._sighted ? [...state, freshAction] : [state[0], freshAction], |
|
{batchSize: this._batchSize}) |
|
const q2Value = this.q2.predict( |
|
this._sighted ? [...state, freshAction] : [state[0], freshAction], |
|
{batchSize: this._batchSize}) |
|
|
|
const criticValue = tf.minimum(q1Value, q2Value) |
|
|
|
const alpha = this._getAlpha() |
|
const loss = alpha.mul(logPi).sub(criticValue) |
|
|
|
assertShape(freshAction, [this._batchSize, this._nActions], 'freshAction') |
|
assertShape(logPi, [this._batchSize, 1], 'logPi') |
|
assertShape(q1Value, [this._batchSize, 1], 'q1Value') |
|
assertShape(criticValue, [this._batchSize, 1], 'criticValue') |
|
assertShape(loss, [this._batchSize, 1], 'actor loss')
|
|
|
return tf.mean(loss) |
|
} |
|
|
|
const { value, grads } = tf.variableGrads(actorLossFunction, this.actor.getWeights(true)) |
|
|
|
this.actorOptimizer.applyGradients(grads) |
|
|
|
if (this._verbose) console.log('Actor Loss: ' + value.arraySync()) |
|
} |
|
|
|
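/**
 * Adjusts the temperature by minimizing J(α) = E[−α·(log π(a|s) + H_target)],
 * nudging the policy entropy toward the target entropy.
 */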
_trainAlpha(state) { |
|
const alphaLossFunction = () => { |
|
const [, logPi] = this.sampleAction(state, true) |
|
|
|
const alpha = this._getAlpha() |
|
const loss = tf.scalar(-1).mul( |
|
alpha.mul( |
|
logPi.add(tf.scalar(this._targetEntropy)) |
|
) |
|
) |
|
|
|
assertShape(loss, [this._batchSize, 1], 'alpha loss') |
|
|
|
return tf.mean(loss) |
|
} |
|
|
|
const { value, grads } = tf.variableGrads(alphaLossFunction, [this._logAlpha]) |
|
|
|
this.alphaOptimizer.applyGradients(grads) |
|
|
|
if (this._verbose) console.log('Alpha Loss: ' + value.arraySync(), tf.exp(this._logAlpha).arraySync()) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
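/**
 * Polyak-averages the online critic weights into the target networks:
 * θ' ← (1 − τ)·θ' + τ·θ. Called with τ = 1 at init for a hard copy.
 */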
updateTargets(tau = this._tau) { |
|
tau = tf.scalar(tau) |
|
|
|
const |
|
q1W = this.q1.getWeights(), |
|
q2W = this.q2.getWeights(), |
|
q1WTarg = this.q1Targ.getWeights(), |
|
q2WTarg = this.q2Targ.getWeights(), |
|
len = q1W.length |
|
|
|
|
|
|
|
|
|
const calc = (w, wTarg) => wTarg.mul(tf.scalar(1).sub(tau)).add(w.mul(tau)) |
|
|
|
const w1 = [], w2 = [] |
|
for (let i = 0; i < len; i++) { |
|
w1.push(calc(q1W[i], q1WTarg[i])) |
|
w2.push(calc(q2W[i], q2WTarg[i])) |
|
} |
|
|
|
this.q1Targ.setWeights(w1) |
|
this.q2Targ.setWeights(w2) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
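/**
 * Samples a tanh-squashed action with the reparameterization trick:
 * a = tanh(μ + σ·ε), ε ~ N(0, 1).
 *
 * Inference sketch (for a sighted agent the state is the full input list):
 *
 *   const action = agent.sampleAction([telemetry, framesL, framesR])
 *
 * @returns the action tensor, or [action, logProb] when withLogProbs is true
 */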
sampleAction(state, withLogProbs = false) { |
|
return tf.tidy(() => { |
|
let [ mu, logStd ] = this.actor.predict(this._sighted ? state : state[0], {batchSize: this._batchSize}) |
|
|
|
|
|
logStd = tf.clipByValue(logStd, LOG_STD_MIN, LOG_STD_MAX) |
|
|
|
const std = tf.exp(logStd) |
|
|
|
|
|
const normal = tf.randomNormal(mu.shape, 0, 1.0) |
|
|
|
|
|
let pi = mu.add(std.mul(normal)) |
|
|
|
let logPi = this._gaussianLikelihood(pi, mu, logStd) |
|
|
|
;({ pi, logPi } = this._applySquashing(pi, mu, logPi)) |
|
|
|
if (!withLogProbs) |
|
return pi |
|
|
|
return [pi, logPi] |
|
}) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
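/**
 * Element-wise log-density of a diagonal Gaussian:
 * log p(x) = −½·((x − μ)/σ)² − (½·log 2π + log σ).
 * Not referenced by the training path; kept for reference.
 */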
_logProb(x, mu, std) { |
|
const logUnnormalized = tf.scalar(-0.5).mul( |
|
tf.squaredDifference(x.div(std), mu.div(std)) |
|
) |
|
const logNormalization = tf.scalar(0.5 * Math.log(2 * Math.PI)).add(tf.log(std)) |
|
|
|
return logUnnormalized.sub(logNormalization) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
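/**
 * Log-likelihood of x under a diagonal Gaussian, summed over action dims:
 * log p(x) = Σᵢ −½·[((xᵢ − μᵢ)/σᵢ)² + 2·log σᵢ + log 2π]
 * Returns shape [batchSize, 1].
 */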
_gaussianLikelihood(x, mu, logStd) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
const preSum = tf.scalar(-0.5).mul( |
|
x.sub(mu).div( |
|
tf.exp(logStd).add(tf.scalar(EPSILON)) |
|
).square() |
|
.add(tf.scalar(2).mul(logStd)) |
|
.add(tf.scalar(Math.log(2 * Math.PI))) |
|
) |
|
|
|
return tf.sum(preSum, 1, true) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
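/**
 * Applies the tanh squashing correction to the log-probability:
 * log π(a) = log p(u) − Σᵢ log(1 − tanh(uᵢ)²), where a = tanh(u).
 */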
_applySquashing(pi, mu, logPi) { |
|
|
|
|
|
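// Numerically stable identity: log(1 − tanh(u)²) = 2·(log 2 − u − softplus(−2u)).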
const adj = tf.scalar(2).mul( |
|
tf.scalar(Math.log(2)) |
|
.sub(pi) |
|
.sub(tf.softplus( |
|
tf.scalar(-2).mul(pi) |
|
)) |
|
) |
|
|
|
logPi = logPi.sub(tf.sum(adj, 1, true)) |
|
mu = tf.tanh(mu) |
|
pi = tf.tanh(pi) |
|
|
|
return { pi, mu, logPi } |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
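/**
 * Builds (or loads) the Gaussian policy network: telemetry, optionally
 * concatenated with the two frame encoders, feeds two dense layers that
 * emit the μ and log σ heads, one unit per action.
 */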
async _getActor(name = 'actor', trainable = true) { |
|
const checkpoint = await this._loadCheckpoint(name) |
|
if (checkpoint) return checkpoint |
|
|
|
let outputs = this._telemetryInput |
|
|
|
|
|
if (this._sighted) { |
|
let convOutputL = this._getConvEncoder(this._frameInputL) |
|
let convOutputR = this._getConvEncoder(this._frameInputR) |
|
|
|
|
|
|
|
outputs = tf.layers.concatenate().apply([convOutputL, convOutputR, outputs]) |
|
} |
|
|
|
outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs) |
|
outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs) |
|
|
|
const mu = tf.layers.dense({units: this._nActions}).apply(outputs) |
|
const logStd = tf.layers.dense({units: this._nActions}).apply(outputs) |
|
|
|
const model = tf.model({
inputs: this._sighted
? [this._telemetryInput, this._frameInputL, this._frameInputR]
: [this._telemetryInput],
outputs: [mu, logStd],
name
})
|
model.trainable = trainable |
|
|
|
if (this._verbose) { |
|
console.log('==========================') |
|
console.log('==========================') |
|
console.log('Actor ' + name + ': ') |
|
|
|
model.summary() |
|
} |
|
|
|
return model |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
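/**
 * Builds (or loads) a Q network: telemetry and action (plus the frame
 * encoders when sighted) pass through two dense layers to a scalar Q value.
 */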
async _getCritic(name = 'critic', trainable = true) { |
|
const checkpoint = await this._loadCheckpoint(name) |
|
if (checkpoint) return checkpoint |
|
|
|
let outputs = tf.layers.concatenate().apply([this._telemetryInput, this._actionInput]) |
|
|
|
|
|
if (this._sighted) { |
|
let convOutputL = this._getConvEncoder(this._frameInputL) |
|
let convOutputR = this._getConvEncoder(this._frameInputR) |
|
|
|
|
|
|
|
outputs = tf.layers.concatenate().apply([convOutputL, convOutputR, outputs]) |
|
} |
|
|
|
outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs) |
|
outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs) |
|
|
|
outputs = tf.layers.dense({units: 1}).apply(outputs) |
|
|
|
const model = tf.model({ |
|
inputs: this._sighted |
|
? [this._telemetryInput, this._frameInputL, this._frameInputR, this._actionInput] |
|
: [this._telemetryInput, this._actionInput], |
|
outputs, name |
|
}) |
|
|
|
model.trainable = trainable |
|
|
|
if (this._verbose) { |
|
console.log('==========================') |
|
console.log('==========================') |
|
console.log('CRITIC ' + name + ': ') |
|
|
|
model.summary() |
|
} |
|
|
|
return model |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
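/**
 * Small convolutional encoder applied to a stacked-frame input:
 * two conv + max-pool blocks followed by a flatten.
 */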
_getConvEncoder(inputs) { |
|
const padding = 'valid'
|
|
|
|
|
const kernelInitializer = 'glorotNormal' |
|
const biasInitializer = 'glorotNormal' |
|
|
|
let outputs = inputs |
|
|
|
|
|
outputs = tf.layers.conv2d({ |
|
filters: 16, |
|
kernelSize: 5, |
|
strides: 2, |
|
padding, |
|
kernelInitializer, |
|
biasInitializer, |
|
activation: 'relu', |
|
trainable: true |
|
}).apply(outputs) |
|
outputs = tf.layers.maxPooling2d({poolSize:2}).apply(outputs) |
|
|
|
|
|
|
|
outputs = tf.layers.conv2d({ |
|
filters: 16, |
|
kernelSize: 3, |
|
strides: 1, |
|
padding, |
|
kernelInitializer, |
|
biasInitializer, |
|
activation: 'relu', |
|
trainable: true |
|
}).apply(outputs) |
|
outputs = tf.layers.maxPooling2d({poolSize:2}).apply(outputs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
outputs = tf.layers.flatten().apply(outputs) |
|
|
|
|
|
|
|
return outputs |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
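// α = exp(log α) keeps the temperature strictly positive while log α is optimized unconstrained.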
_getAlpha() { |
|
|
|
return tf.exp(this._logAlpha) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
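/**
 * Restores log α from a checkpoint or initializes it to 0 (α = 1).
 * A 1×1 bias-free dense model acts as a serializable placeholder, since
 * tf.io persists layers models rather than bare variables.
 */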
async _getLogAlpha(name = 'alpha') { |
|
let logAlpha = 0.0 |
|
|
|
const checkpoint = await this._loadCheckpoint(name) |
|
if (checkpoint) { |
|
logAlpha = checkpoint.getWeights()[0].arraySync()[0][0] |
|
|
|
if (this._verbose) |
|
console.log('Checkpoint alpha: ', logAlpha) |
|
|
|
this._logAlphaPlaceholder = checkpoint |
|
} else { |
|
const model = tf.sequential({ name }); |
|
model.add(tf.layers.dense({ units: 1, inputShape: [1], useBias: false })) |
|
model.setWeights([tf.tensor([logAlpha], [1, 1])]) |
|
|
|
this._logAlphaPlaceholder = model |
|
} |
|
|
|
return tf.variable(tf.scalar(logAlpha), true) |
|
} |
|
|
|
|
|
|
|
|
|
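/**
 * Persists all models (and log α via its placeholder model) to IndexedDB.
 */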
async checkpoint() { |
|
if (!this._trainable) throw new Error('(╭ರ_ ⊙ )') |
|
|
|
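// Copy the current log α value into the placeholder model so it gets serialized.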
this._logAlphaPlaceholder.setWeights([tf.tensor([this._logAlpha.arraySync()], [1, 1])]) |
|
|
|
await Promise.all([ |
|
this._saveCheckpoint(this.actor), |
|
this._saveCheckpoint(this.q1), |
|
this._saveCheckpoint(this.q2), |
|
this._saveCheckpoint(this.q1Targ), |
|
this._saveCheckpoint(this.q2Targ), |
|
this._saveCheckpoint(this._logAlphaPlaceholder) |
|
]) |
|
|
|
if (this._verbose) |
|
console.log('Checkpoint successfully saved')
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
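// Saves a single model under its versioned IndexedDB key.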
async _saveCheckpoint(model) { |
|
const key = this._getChKey(model.name) |
|
const saveResults = await model.save(key) |
|
|
|
if (this._verbose) |
|
console.log('Checkpoint saveResults', model.name, saveResults) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
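/**
 * Loads a model from IndexedDB by name; returns undefined when the
 * checkpoint is missing or fresh models are forced.
 */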
async _loadCheckpoint(name) { |
|
|
|
if (this._forced) { |
|
console.log('Forced fresh model, skipping checkpoint ' + name)
|
return |
|
} |
|
|
|
const key = this._getChKey(name) |
|
const modelsInfo = await tf.io.listModels() |
|
|
|
if (key in modelsInfo) { |
|
const model = await tf.loadLayersModel(key) |
|
|
|
if (this._verbose) |
|
console.log('Loaded checkpoint for ' + name) |
|
|
|
return model |
|
} |
|
|
|
if (this._verbose) |
|
console.log('Checkpoint not found for ' + name) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
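// Versioned IndexedDB key, e.g. 'indexeddb://actor-84' for VERSION 84.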
_getChKey(name) { |
|
return 'indexeddb://' + name + '-' + VERSION |
|
} |
|
} |
|
})() |
|
|
|
|
|
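// Smoke tests. The early return below keeps them disabled during normal runs.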
;(async () => { |
|
return // early return keeps the smoke tests below disabled; remove it to run them
|
|
|
|
|
;(() => { |
|
const agent = new AgentSac() |
|
|
|
const |
|
mu = tf.tensor([0], [1, 1]), |
|
logStd = tf.tensor([0], [1, 1]), |
|
std = tf.exp(logStd), |
|
normal = tf.tensor([0], [1, 1]), |
|
pi = mu.add(std.mul(normal)) |
|
|
|
const log = agent._gaussianLikelihood(pi, mu, logStd) |
|
|
|
console.assert(log.arraySync()[0][0].toFixed(5) === '-0.91894', |
|
'test Gaussian Likelihood for μ=0, σ=1, x=0') |
|
})() |
|
|
|
;(() => { |
|
const agent = new AgentSac() |
|
|
|
const |
|
mu = tf.tensor([1], [1, 1]), |
|
logStd = tf.tensor([1], [1, 1]), |
|
std = tf.exp(logStd), |
|
normal = tf.tensor([0], [1, 1]), |
|
pi = mu.add(std.mul(normal)) |
|
|
|
const log = agent._gaussianLikelihood(pi, mu, logStd) |
|
|
|
console.assert(log.arraySync()[0][0].toFixed(5) === '-1.91894', |
|
'test Gaussian Likelihood for μ=1, σ=e, x=1')
|
})() |
|
|
|
;(() => { |
|
const agent = new AgentSac() |
|
|
|
const |
|
mu = tf.tensor([1], [1, 1]), |
|
logStd = tf.tensor([1], [1, 1]), |
|
std = tf.exp(logStd), |
|
normal = tf.tensor([0.1], [1, 1]), |
|
pi = mu.add(std.mul(normal)) |
|
|
|
const logPi = agent._gaussianLikelihood(pi, mu, logStd) |
|
const { pi: piSquashed, logPi: logPiSquashed } = agent._applySquashing(pi, mu, logPi) |
|
|
|
const logProbBounded = logPi.sub( |
|
tf.log( |
|
tf.scalar(1) |
|
.sub(tf.tanh(pi).pow(tf.scalar(2))) |
|
|
|
) |
|
).sum(1, true) |
|
|
|
console.assert(logPi.arraySync()[0][0].toFixed(5) === '-1.92394', |
|
'test Gaussian Likelihood for μ=1, σ=e, x≈1.27182818')
|
|
|
console.assert(logPiSquashed.arraySync()[0][0].toFixed(5) === logProbBounded.arraySync()[0][0].toFixed(5), |
|
'test logPiSquashed for μ=1, σ=e, x≈1.27182818')
|
|
|
console.assert(piSquashed.arraySync()[0][0].toFixed(5) === tf.tanh(pi).arraySync()[0][0].toFixed(5), |
|
'test piSquashed for μ=1, σ=e, x≈1.27182818')
|
})() |
|
|
|
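// Checkpoint round-trip: save a fresh agent, reload it, and compare predictions.
// The tensors below are telemetry-only with 10 action dims, so the agents are
// constructed with sighted: false and nActions: 10 to match (an assumption;
// the original constructors used the sighted 3-action defaults).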
await (async () => { |
|
const state = tf.tensor([ |
|
0.5, 0.3, -0.9, |
|
0, -0.8, 1, |
|
-0.3, 0.04, 0.02, |
|
0.9 |
|
], [1, 10]) |
|
|
|
const action = tf.tensor([ |
|
0.1, -1, -0.4, |
|
1, -0.8, -0.8, -0.2, |
|
0.04, 0.02, 0.001 |
|
], [1, 10]) |
|
|
|
const fresh = new AgentSac({ prefix: 'test', forced: true, sighted: false, nActions: 10 })
|
await fresh.init() |
|
await fresh.checkpoint() |
|
|
|
const saved = new AgentSac({ prefix: 'test', sighted: false, nActions: 10 })
|
await saved.init() |
|
|
|
let frPred, saPred |
|
|
|
frPred = fresh.actor.predict(state, {batchSize: 1}) |
|
saPred = saved.actor.predict(state, {batchSize: 1}) |
|
console.assert( |
|
frPred[0].arraySync().length > 0 && |
|
frPred[1].arraySync().length > 0 && |
|
frPred[0].arraySync().join(';') === saPred[0].arraySync().join(';') && |
|
frPred[1].arraySync().join(';') === saPred[1].arraySync().join(';'), |
|
'Models loaded from the checkpoint should be the same') |
|
|
|
frPred = fresh.q1.predict([state, action], {batchSize: 1}) |
|
saPred = fresh.q1Targ.predict([state, action], {batchSize: 1}) |
|
console.assert( |
|
frPred.arraySync()[0][0] !== undefined && |
|
frPred.arraySync()[0][0] === saPred.arraySync()[0][0], |
|
'Q1 and Q1-target should be the same') |
|
|
|
frPred = fresh.q2.predict([state, action], {batchSize: 1}) |
|
saPred = saved.q2.predict([state, action], {batchSize: 1}) |
|
console.assert( |
|
frPred.arraySync()[0][0] !== undefined && |
|
frPred.arraySync()[0][0] === saPred.arraySync()[0][0], |
|
'Q2 and restored Q2 should be the same')
|
|
|
console.assert( |
|
fresh._logAlpha.arraySync() !== undefined &&
fresh._logAlpha.arraySync() === saved._logAlpha.arraySync(),
'logAlpha and restored logAlpha should be the same')
|
})() |
|
})() |
|
|