How to Train an Immortal Bird with Java
This article explains how to train an "immortal bird" with Java: a reinforcement learning agent that learns to play Flappy Bird without dying. It walks through the algorithm, the neural network, and the training pipeline.
Architecture of Reinforcement Learning (RL)
In this section, we introduce the main algorithm and the neural network so that you can better understand how they are trained. The project takes an approach similar to DeepLearningFlappyBird. The overall architecture is Q-Learning combined with a Convolutional Neural Network (CNN): for every frame of the game, the state, the action the bird takes, and the effect of that action are stored and later used as training data for the CNN.
CNN Training Brief
The input to the CNN is a stack of 4 consecutive frames, which serves as the bird's current "observation". The frames are converted to grayscale to reduce the training resources required. The batch of images is stored as a matrix of shape (batchSize, 4 (frames), 80 (width), 80 (height)), whose elements are the pixel values of each frame. This data is fed into the CNN, which outputs a matrix of shape (batchSize, 2); the second dimension holds the expected reward of each of the bird's two actions (flap or do nothing).
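The two actions are encoded as one-hot vectors so that they line up with the network's two outputs; the step code later in this article checks getInt(1) == 1 to detect a flap. Below is a minimal sketch of such an encoding using DJL's ActionSpace; the class, field, and constant names are illustrative, not taken from the project source.
import ai.djl.modality.rl.ActionSpace;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;

// Sketch only: one-hot action encodings matching the two network outputs.
public class FlappyActionEncoding {
    // index 1 set to 1 means "flap"; this is what step() checks via getInt(1) == 1
    private static final int[] DO_NOTHING = {1, 0};
    private static final int[] FLAP = {0, 1};

    private final NDManager manager = NDManager.newBaseManager();

    public ActionSpace getActionSpace() {
        ActionSpace actionSpace = new ActionSpace();
        actionSpace.add(new NDList(manager.create(DO_NOTHING)));
        actionSpace.add(new NDList(manager.create(FLAP)));
        return actionSpace;
    }
}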
Training data
After the bird takes an action, we obtain preObservation and currentObservation: two stacks of four consecutive frames showing the game before and after the action. We then store the five-tuple (preObservation, currentObservation, action, reward, terminal) as one step in replayBuffer, a fixed-size training dataset that is continuously updated with the most recent steps.
public void step(NDList action, boolean training) {
    if (action.singletonOrThrow().getInt(1) == 1) {
        bird.birdFlap();
    }
    stepFrame();
    NDList preObservation = currentObservation;
    currentObservation = createObservation(currentImg);
    FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
            preObservation, currentObservation, action, currentReward, currentTerminal);
    if (training) {
        replayBuffer.addStep(step);
    }
    if (gameState == GAME_OVER) {
        restartGame();
    }
}
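The FlappyBirdStep created above is essentially a container for that five-tuple. Below is a minimal sketch of such a container, inferred from the constructor call in step() and the getReward()/isTerminal() calls used later; the actual class in the project also implements DJL's RlEnv.Step interface, which is not reproduced here.
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;

// Sketch of the five-tuple holder; field names follow the constructor arguments used in step().
public class FlappyBirdStep {
    private final NDManager manager;
    private final NDList preObservation;
    private final NDList postObservation;
    private final NDList action;
    private final NDArray reward;   // shown as NDArray because the training code reads it into one
    private final boolean terminal;

    public FlappyBirdStep(NDManager manager, NDList preObservation, NDList postObservation,
                          NDList action, NDArray reward, boolean terminal) {
        this.manager = manager;
        this.preObservation = preObservation;
        this.postObservation = postObservation;
        this.action = action;
        this.reward = reward;
        this.terminal = terminal;
    }

    public NDList getPreObservation() { return preObservation; }
    public NDList getPostObservation() { return postObservation; }
    public NDList getAction() { return action; }
    public NDArray getReward() { return reward; }
    public boolean isTerminal() { return terminal; }
}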
Three training periods
The training is divided into three periods so that training data is generated more effectively:
Observe period: training data is generated entirely from random actions
Explore period: random actions and model-inferred actions are mixed to update the training data
Training period: model-inferred actions dominate the generation of new data
This staged scheme helps the model reach the desired behavior more reliably.
During the explore period, the bird's action is chosen between a random action and an action inferred by the model, according to a weight. Early in training, random actions carry most of the weight, because the model's decisions are still inaccurate (sometimes even worse than random). Later in training, as the model learns, we increase the weight of model-inferred actions until they eventually dominate. The parameter that controls the proportion of random actions is called epsilon, and it changes over the course of training.
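Below is a minimal sketch of such a decaying epsilon written against DJL's Tracker interface, which the chooseAction implementation shown next queries through getNewValue. The initial value, floor, and decay length are illustrative rather than the project's actual settings; DJL's built-in trackers (for example LinearTracker) can express a similar schedule.
import ai.djl.training.tracker.Tracker;

// Linearly decaying epsilon: random actions dominate early on, model-inferred actions later.
Tracker exploreRate = new Tracker() {
    private final float initialEpsilon = 0.1f;   // illustrative starting epsilon
    private final float finalEpsilon = 0.0001f;  // illustrative floor
    private final int exploreSteps = 3_000_000;  // illustrative decay length

    @Override
    public float getNewValue(int numUpdate) {
        int n = Math.min(numUpdate, exploreSteps);
        return initialEpsilon - (initialEpsilon - finalEpsilon) * n / exploreSteps;
    }
};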
public NDList chooseAction(RlEnv env, boolean training) {
    if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
        return env.getActionSpace().randomAction();
    } else {
        return baseAgent.chooseAction(env, training);
    }
}
Training logic
First, we randomly sample a batch of data from the replayBuffer as the training set. Then preObservation is fed into the neural network to obtain the reward (Q) of every action as the predicted value:
NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
        .mul(actionInput.singletonOrThrow())
        .sum(new int[]{1}));
The postObservation is also fed into the neural network, and the reward of every action (targetQ) is computed as the target value according to the Markov decision process and the Bellman equation:
// Feeding postInput into the neural network yields targetQReward, a (batchSize, 2) matrix.
// According to the Q-learning algorithm, each targetQ must be computed differently depending
// on whether the current environment has terminated, so each step's targetQ is computed
// separately and the results are then stacked into an NDList.
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
    if (batchSteps[i].isTerminal()) {
        targetQValue[i] = batchSteps[i].getReward();
    } else {
        targetQValue[i] = targetQReward.singletonOrThrow().get(i)
                .max()
                .mul(rewardDiscount)
                .add(rewardInput.singletonOrThrow().get(i));
    }
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));
At the end of each training iteration, the loss between Q and targetQ is computed and the weights of the CNN are updated.
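A minimal sketch of that final step is shown below, assuming the trainer was configured with a squared-error loss; Q and targetQ are the NDLists computed above, and GradientCollector and trainer.step() are the standard DJL training calls.
import ai.djl.ndarray.NDArray;
import ai.djl.training.GradientCollector;

try (GradientCollector collector = trainer.newGradientCollector()) {
    // the forward passes that produce Q and targetQ (shown above) must run inside
    // this block so that gradients are recorded
    NDArray loss = trainer.getLoss().evaluate(targetQ, Q); // labels first, then predictions
    collector.backward(loss);
}
trainer.step(); // apply the accumulated gradients to the CNN weights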
Convolutional Neural Network Model (CNN)
We adopted a neural network architecture with 3 convolutional layers, 4 ReLU activations, and 2 fully connected layers.
layer    input shape              output shape
conv2d   (batchSize, 4, 80, 80)   (batchSize, 4, 20, 20)
conv2d   (batchSize, 4, 20, 20)   (batchSize, 32, 9, 9)
conv2d   (batchSize, 32, 9, 9)    (batchSize, 64, 7, 7)
linear   (batchSize, 3136)        (batchSize, 512)
linear   (batchSize, 512)         (batchSize, 2)
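Built with DJL's SequentialBlock, a network with these shapes might look like the sketch below. The filter counts follow the table; the kernel sizes, strides, and padding are assumptions chosen so that the output shapes match, not values taken from the project source.
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;

SequentialBlock block = new SequentialBlock()
        // (batchSize, 4, 80, 80) -> (batchSize, 4, 20, 20)
        .add(Conv2d.builder().setKernelShape(new Shape(8, 8))
                .optStride(new Shape(4, 4)).optPadding(new Shape(3, 3))
                .setFilters(4).build())
        .add(Activation::relu)
        // (batchSize, 4, 20, 20) -> (batchSize, 32, 9, 9)
        .add(Conv2d.builder().setKernelShape(new Shape(4, 4))
                .optStride(new Shape(2, 2)).setFilters(32).build())
        .add(Activation::relu)
        // (batchSize, 32, 9, 9) -> (batchSize, 64, 7, 7)
        .add(Conv2d.builder().setKernelShape(new Shape(3, 3))
                .optStride(new Shape(1, 1)).setFilters(64).build())
        .add(Activation::relu)
        // flatten to (batchSize, 3136), then the two fully connected layers
        .add(Blocks.batchFlattenBlock())
        .add(Linear.builder().setUnits(512).build())
        .add(Activation::relu)
        .add(Linear.builder().setUnits(2).build());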
Training process
DJL's RL library provides convenient interfaces for implementing reinforcement learning: RlEnv, RlAgent, and ReplayBuffer.
By implementing the RlAgent interface, you can construct an agent that can be trained.
Implementing the RlEnv interface on top of an existing game environment generates the data needed for training.
Creating a ReplayBuffer stores and dynamically updates the training data (a sketch of how these pieces fit together follows).
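A hedged sketch of that wiring is shown below. The class names come from DJL's ai.djl.modality.rl package, but the reward discount, batch size, and buffer size are illustrative, and the project wraps part of this logic in its own classes. Here trainer is the DJL Trainer that holds the CNN (its configuration is sketched later), and exploreRate is the decaying epsilon tracker from the previous section.
import ai.djl.modality.rl.LruReplayBuffer;
import ai.djl.modality.rl.ReplayBuffer;
import ai.djl.modality.rl.agent.EpsilonGreedy;
import ai.djl.modality.rl.agent.QAgent;
import ai.djl.modality.rl.agent.RlAgent;

RlAgent baseAgent = new QAgent(trainer, 0.9f);               // 0.9f: assumed reward discount
RlAgent agent = new EpsilonGreedy(baseAgent, exploreRate);   // epsilon-greedy action selection
ReplayBuffer replayBuffer = new LruReplayBuffer(32, 50000);  // batch size 32 is illustrative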
After implementing these interfaces, you only need to call the step method:
RlEnv.step(action, training);
This method feeds the RlAgent's decision into the game environment and collects the feedback. We call step inside the runEnvironment method provided by RlEnv; repeatedly invoking runEnvironment then continuously generates data for training.
public Step[] runEnvironment(RlAgent agent, boolean training) {
    // run the game
    NDList action = agent.chooseAction(this, training);
    step(action, training);
    if (training) {
        batchSteps = this.getBatch();
    }
    return batchSteps;
}
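In the project itself, sample generation and training run in separate threads (see Training Optimization below), but conceptually the driving loop looks like the sketch that follows. Here game and agent are the environment and agent instances, trainBatch is the RlAgent method that performs the Q/targetQ computation described earlier, and the step count is illustrative.
import ai.djl.modality.rl.env.RlEnv.Step;

int totalSteps = 1_000_000; // illustrative
for (int i = 0; i < totalSteps; i++) {
    // generate one step of gameplay and, while training, get a sampled batch back
    Step[] batch = game.runEnvironment(agent, true);
    if (batch != null && batch.length > 0) {
        agent.trainBatch(batch); // compute Q / targetQ and backpropagate
        trainer.step();          // update the CNN weights
    }
}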
We set the capacity of the ReplayBuffer to 50,000 steps. During the observe period, we first store 1,000 steps generated by random actions in the ReplayBuffer, so that the agent can learn from random actions more quickly.
During the explore and training periods, training batches are randomly sampled from the replayBuffer and fed into the model. We optimize the neural network with the Adam optimizer and an MSE loss function.
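A minimal sketch of that setup with DJL might look like the following; the model name and learning rate are illustrative, Loss.l2Loss() plays the role of the squared-error loss, and block is the SequentialBlock sketched in the CNN section.
import ai.djl.Model;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;

Model model = Model.newInstance("flappy-qnet"); // illustrative name
model.setBlock(block);

DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss())
        .optOptimizer(Optimizer.adam()
                .optLearningRateTracker(Tracker.fixed(1e-4f))
                .build());

Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(1, 4, 80, 80)); // one observation: four stacked 80x80 frames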
Neural network input preprocessing
First, the image is resized to 80x80 and converted to grayscale, which improves training speed without losing the information the agent needs.
public static NDArray imgPreprocess(BufferedImage observation) {
    return NDImageUtils.toTensor(
            NDImageUtils.resize(
                    ImageFactory.getInstance().fromImage(observation)
                            .toNDArray(NDManager.newBaseManager(), Image.Flag.GRAYSCALE),
                    80, 80));
}
Then we take four consecutive frames as one input. To obtain them, we maintain a global image queue filled by the game thread, replace the oldest frame after each action, and stack the images in the queue into a single NDArray.
public NDList createObservation(BufferedImage currentImg) {
    NDArray observation = GameUtil.imgPreprocess(currentImg);
    if (imgQueue.isEmpty()) {
        for (int i = 0; i < 4; i++) {
            imgQueue.offer(observation);
        }
        return new NDList(NDArrays.stack(
                new NDList(observation, observation, observation, observation), 1));
    } else {
        imgQueue.remove();
        imgQueue.offer(observation);
        NDArray[] buf = new NDArray[4];
        int i = 0;
        for (NDArray nd : imgQueue) {
            buf[i++] = nd;
        }
        return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
    }
}
Once that is done, we can start training.
Training Optimization
For the best training performance, we turn off the GUI to speed up sample generation, and we use Java multithreading to run the training loop and the sample-generation loop in separate threads.
List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if (training) {
    callables.add(new TrainerCallable(model, agent));
}
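A minimal sketch of launching those callables on a thread pool follows; the shutdown handling is illustrative, and GeneratorCallable and TrainerCallable are the project's own classes.
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Run the sample-generation loop and the training loop concurrently.
ExecutorService executor = Executors.newFixedThreadPool(numOfThreads);
try {
    executor.invokeAll(callables); // blocks until both callables finish
} catch (InterruptedException e) {
    Thread.currentThread().interrupt();
} finally {
    executor.shutdown();
}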
"How to use Java to train an immortal bird" content is introduced here, thank you for reading. If you want to know more about industry-related knowledge, you can pay attention to the website. Xiaobian will output more high-quality practical articles for everyone!