|
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js') |
|
importScripts('agent_sac.js') |
|
importScripts('reply_buffer.js') |
|
|
|
;(async () => { |
|
// Master switch: when true, `job` parks the training loop and the message
// handler ignores incoming messages.
const DISABLED = false

// Build and initialise the SAC agent before anything else runs.
// NOTE(review): `init` / `checkpoint` are defined in agent_sac.js; presumably
// they create/load the models and persist them — confirm there.
const agent = new AgentSac({ batchSize: 100, verbose: true })
await agent.init()
await agent.checkpoint()
agent.actor.summary()

// Send the initial actor weights (as nested plain arrays) to the owner of this worker.
const initialWeights = await Promise.all(
  agent.actor.getWeights().map(weight => weight.array())
)
self.postMessage({ weights: initialWeights })
|
|
|
// Buffer of decoded transitions that `job` samples training batches from.
// NOTE(review): 50000 is presumably the capacity (see `rb._limit` in the
// message handler) and the callback is presumably invoked when a transition
// is dropped from the buffer — confirm both in reply_buffer.js.
const rb = new ReplyBuffer(50000, transition => {
  const { state: [telemetry, frameL, frameR], action, reward } = transition

  // Release every tensor held by the transition, frames first.
  for (const tensor of [frameL, frameR, telemetry, action, reward]) {
    tensor.dispose()
  }
})
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Runs a single training iteration: samples a batch from the replay buffer,
 * trains the agent on it and posts the updated actor weights.
 *
 * @returns {Promise<number>} Delay in milliseconds that `tick` passes to
 *   `setTimeout` before running the next iteration.
 */
const job = async () => {
  // Training switched off: come back only after a very long pause.
  if (DISABLED) return 99999

  // Hold off until the buffer contains at least ten batches' worth of data.
  const batchSize = agent._batchSize
  if (rb.size < batchSize * 10) return 1000

  const samples = rb.sample(batchSize)
  if (!samples.length) return 1000

  // Per-field lists of tensors, one entry per sampled transition.
  const current = { telemetry: [], frameL: [], frameR: [] }
  const next = { telemetry: [], frameL: [], frameR: [] }
  const actions = []
  const rewards = []

  for (const sample of samples) {
    const {
      state: [telemetry, frameL, frameR],
      action,
      reward,
      nextState: [nextTelemetry, nextFrameL, nextFrameR]
    } = sample

    current.frameL.push(frameL)
    current.frameR.push(frameR)
    current.telemetry.push(telemetry)
    actions.push(action)
    rewards.push(reward)
    next.frameL.push(nextFrameL)
    next.frameR.push(nextFrameR)
    next.telemetry.push(nextTelemetry)
  }

  // The stacked batch tensors are created inside tf.tidy so they are
  // released once the training step is done.
  tf.tidy(() => {
    console.time('train')

    const state = [
      tf.stack(current.telemetry),
      tf.stack(current.frameL),
      tf.stack(current.frameR)
    ]
    const action = tf.stack(actions)
    const reward = tf.stack(rewards)
    const nextState = [
      tf.stack(next.telemetry),
      tf.stack(next.frameL),
      tf.stack(next.frameR)
    ]

    agent.train({ state, action, reward, nextState })

    console.timeEnd('train')
  })

  // Publish the freshly trained actor weights as nested plain arrays.
  console.time('train postMessage')
  const weights = await Promise.all(
    agent.actor.getWeights().map(weight => weight.array())
  )
  self.postMessage({ weights })
  console.timeEnd('train postMessage')

  return 1
}
|
|
|
|
|
|
|
|
|
/**
 * Training loop driver: runs `job`, then re-schedules itself after the delay
 * `job` returned. If `job` fails, the error is logged and the loop retries
 * after five seconds.
 */
const tick = async () => {
  let delayMs

  try {
    delayMs = await job()
  } catch (e) {
    console.error(e)
    delayMs = 5000
  }

  setTimeout(tick, delayMs)
}

// Start the loop one second after the worker has finished setting up.
setTimeout(tick, 1000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Converts a transition received as plain arrays/numbers into one whose
 * state, action and reward are tensors.
 *
 * @param {object} transition - Raw transition with `id`, `state`
 *   (`[telemetry, frameL, frameR]`), `action`, `reward` and `priority`.
 * @returns {object} `{ id, state, action, reward, priority }` where `state`
 *   is `[tensor1d, tensor3d, tensor3d]`, `action` is a 1-D tensor and
 *   `reward` is a 1-D tensor holding the single reward value. `id` and
 *   `priority` are passed through untouched.
 */
const decodeTransition = transition => {
  const { id, state: [telemetry, frameL, frameR], action, reward, priority } = transition

  return tf.tidy(() => {
    // `state` must be declared locally: the destructuring above binds only
    // its elements, so a bare assignment would create an implicit global
    // (or throw a ReferenceError in strict mode).
    const state = [
      tf.tensor1d(telemetry),
      tf.tensor3d(frameL, agent._frameStackShape),
      tf.tensor3d(frameR, agent._frameStackShape)
    ]

    // The tensors are part of the returned object and are handed to the caller.
    return {
      id,
      state,
      action: tf.tensor1d(action),
      reward: tf.tensor1d([reward]),
      priority
    }
  })
}
|
|
|
// Number of messages received so far; drives the periodic size log and the
// periodic checkpoint below.
let i = 0

// Debug switch: set to true to log the target critics' value estimate for
// every incoming transition. Off by default (the original code short-circuited
// this logging with an early `return`).
const LOG_TARGET_VALUE = false

/**
 * Logs min(Q1target, Q2target) for a single decoded transition.
 * Debug helper, only called when LOG_TARGET_VALUE is true.
 *
 * @param {object} transition - Decoded transition (tensors) as returned by `decodeTransition`.
 */
const logTargetValue = ({ state: [telemetry, frameL, frameR], action }) => {
  tf.tidy(() => {
    // Wrap each tensor in a batch of one before feeding the critics.
    const inputs = [
      tf.stack([telemetry]),
      tf.stack([frameL]),
      tf.stack([frameR]),
      tf.stack([action])
    ]
    const q1TargValue = agent.q1Targ.predict(inputs, {batchSize: 1})
    const q2TargValue = agent.q2Targ.predict(inputs, {batchSize: 1})
    console.log('value', Math.min(q1TargValue.arraySync()[0][0], q2TargValue.arraySync()[0][0]).toFixed(5))
  })
}

/**
 * Handles messages sent to this worker. Supported `e.data.action` values:
 *  - 'newTransition': decodes `e.data.transition` and stores it in the replay buffer.
 * Any other action is reported with a warning.
 */
self.addEventListener('message', async e => {
  i++

  if (DISABLED) return
  if (i % 50 === 0) console.log('RBSIZE: ', rb.size)

  switch (e.data.action) {
    // Braces give the `const` its own scope instead of sharing the switch's.
    case 'newTransition': {
      const transition = decodeTransition(e.data.transition)
      rb.add(transition)

      if (LOG_TARGET_VALUE) logTargetValue(transition)
      break
    }
    default:
      console.warn('Unknown action')
      break
  }

  // Checkpoint once every `rb._limit` messages. The call is awaited inside a
  // try/catch so a failed checkpoint is logged instead of surfacing as an
  // unhandled promise rejection.
  if (i % rb._limit === 0) {
    try {
      await agent.checkpoint()
    } catch (err) {
      console.error(err)
    }
  }
})
|
})() |
|
|