How to use a genetic neural network for handwritten digit recognition in dl4j

2025-01-30 Update From: SLTechnology News&Howtos


This article explains how to use a genetic algorithm to evolve a neural network in dl4j (Deeplearning4j) so that it recognizes handwritten digits. The approach is summarized first, followed by the code.

Implementation steps

1. Randomly initialize several agents (neural networks), let each agent recognize the training data, and rank the agents by recognition rate.

2. Randomly select one agent from the ranking as the female parent, then randomly select one of the agents ranked at or above the female parent as the male parent.

3. For each network parameter (weight), randomly take the value from either the female or the male parent to form a new agent.

4. Adjust the new agent's parameters by an amount that depends on the rank of the female parent: the lower the rank, the greater the adjustment (1%-10%).

5. Let the new agent recognize the training set, insert it into the ranking, and remove the last-placed agent from the ranking.

6. Repeat steps 2-5 so that the recognition rate keeps rising.

This process resembles survival of the fittest in nature: the parameters of the neural network play the role of DNA, and adjusting those parameters plays the role of mutation. Neural networks with different hidden layers could be treated as different species, which would make the competition more diverse, but this article deals with only one network architecture.
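The six steps above can be sketched without any neural network at all. The following is a minimal, self-contained illustration, not the article's implementation: the genome is a plain `double[]`, and the class `SteadyStateGaSketch`, its toy `fitness` function, and the linear mapping of rank to a 1%-10% mutation rate are all assumptions made for the example.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Random;

/** Toy steady-state genetic loop; a double[] stands in for the network weights. */
final class SteadyStateGaSketch {

    /** Hypothetical fitness: higher is better, with a maximum of 0 when every gene is 0.5. */
    static double fitness(double[] genome) {
        double sum = 0;
        for (double g : genome) {
            sum -= (g - 0.5) * (g - 0.5);
        }
        return sum;
    }

    /** Runs the loop and returns the population sorted from worst to best. */
    static List<double[]> evolve(int populationSize, int genomeLength, int rounds, Random rnd) {
        Comparator<double[]> byFitness = Comparator.comparingDouble(SteadyStateGaSketch::fitness);

        // Step 1: random initial agents, ranked by fitness (index 0 is the worst).
        List<double[]> population = new ArrayList<>();
        for (int i = 0; i < populationSize; i++) {
            double[] genome = new double[genomeLength];
            for (int g = 0; g < genomeLength; g++) {
                genome[g] = rnd.nextDouble();
            }
            population.add(genome);
        }
        population.sort(byFitness);

        for (int round = 0; round < rounds; round++) {
            // Step 2: any agent may be the mother; the father is ranked at or above her.
            int m = rnd.nextInt(populationSize);
            int f = m + rnd.nextInt(populationSize - m);
            double[] mother = population.get(m);
            double[] father = population.get(f);

            // Step 4 (assumed mapping): the worst-ranked mother gives a 10% rate, the best about 1%.
            double rate = 0.01 + 0.09 * (1.0 - (double) m / populationSize);

            double[] child = new double[genomeLength];
            for (int i = 0; i < genomeLength; i++) {
                child[i] = rnd.nextBoolean() ? mother[i] : father[i]; // step 3: gene from either parent
                if (rnd.nextDouble() < rate) {
                    child[i] += (rnd.nextDouble() - 0.5) * 0.2;       // step 4: small random shift
                }
            }

            // Step 5: insert the child into the ranking and drop the last-placed agent.
            population.add(child);
            population.sort(byFitness);
            population.remove(0);
        }
        return population;
    }

    public static void main(String[] args) {
        List<double[]> result = evolve(20, 8, 2000, new Random(1));
        System.out.println("best fitness: " + fitness(result.get(result.size() - 1)));
    }
}
```

Because the worst agent is the one removed in each round, the best fitness in the population can never decrease, which is the property step 6 relies on.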

Advantage: it can tackle problems for which there is no obvious line of attack. Disadvantage: training efficiency is extremely low.

Gitee address:

https://gitee.com/ichiva/gnn.git

Implementation code

Evolution interface

    import org.nd4j.linalg.api.ndarray.INDArray;

    public interface Evolution {

        /**
         * Inheritance (crossover of two parents)
         * @param mDna DNA of the female parent
         * @param fDna DNA of the male parent
         * @return DNA of the child
         */
        INDArray inheritance(INDArray mDna, INDArray fDna);

        /**
         * Mutation
         * @param dna
         * @param v
         * @param r mutation range
         * @return
         */
        INDArray mutation(INDArray dna, double v, double r);

        /**
         * Substitution
         * @param dna
         * @param v
         * @return
         */
        INDArray substitution(INDArray dna, double v);

        /**
         * Exogenous (foreign) genes
         * @param dna
         * @param v
         * @return
         */
        INDArray other(INDArray dna, double v);

        /**
         * Whether the two DNA arrays are of the same origin
         * @param mDna
         * @param fDna
         * @return
         */
        boolean iSogeny(INDArray mDna, INDArray fDna);
    }
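To make the contract of the four operators concrete, here is a small sketch on plain arrays. It is only an illustration under stated assumptions: `double[]` replaces `INDArray`, the class name `ArrayEvolutionSketch` and the explicit `Random` argument are invented for the example, and `v` is treated uniformly as the per-gene probability of change, which is a simplification rather than a description of the article's code.

```java
import java.util.Random;

/** Array-based stand-ins for the Evolution operators; every method returns a new array. */
final class ArrayEvolutionSketch {

    /** In this sketch, "same origin" simply means the same length. */
    static boolean sameOrigin(double[] m, double[] f) {
        return m.length == f.length;
    }

    /** Uniform crossover: each gene is copied from one parent chosen by a coin flip. */
    static double[] inheritance(double[] m, double[] f, Random rnd) {
        if (!sameOrigin(m, f)) {
            throw new IllegalArgumentException("non-homologous dna");
        }
        double[] child = new double[m.length];
        for (int i = 0; i < child.length; i++) {
            child[i] = rnd.nextBoolean() ? f[i] : m[i];
        }
        return child;
    }

    /** With probability v, shifts a gene by a uniform offset in [-r, r). */
    static double[] mutation(double[] dna, double v, double r, Random rnd) {
        double[] out = dna.clone();
        for (int i = 0; i < out.length; i++) {
            if (rnd.nextDouble() < v) {
                out[i] += (rnd.nextDouble() - 0.5) * r * 2;
            }
        }
        return out;
    }

    /** With probability v, overwrites a gene with the value of a randomly chosen gene. */
    static double[] substitution(double[] dna, double v, Random rnd) {
        double[] out = dna.clone();
        for (int i = 0; i < out.length; i++) {
            if (rnd.nextDouble() < v) {
                out[i] = dna[rnd.nextInt(dna.length)];
            }
        }
        return out;
    }

    /** With probability v, overwrites a gene with a fresh random value in [0, 1). */
    static double[] other(double[] dna, double v, Random rnd) {
        double[] out = dna.clone();
        for (int i = 0; i < out.length; i++) {
            if (rnd.nextDouble() < v) {
                out[i] = rnd.nextDouble();
            }
        }
        return out;
    }
}
```

Each operator works on a copy, so a parent's genes are never modified while it is still part of the ranking.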

A more general implementation

    import org.nd4j.linalg.api.iter.NdIndexIterator;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class MnistEvolution implements Evolution {

        private static final MnistEvolution instance = new MnistEvolution();

        public static MnistEvolution getInstance() {
            return instance;
        }

        @Override
        public INDArray inheritance(INDArray mDna, INDArray fDna) {
            if (mDna == fDna) return mDna;

            long[] mShape = mDna.shape();
            if (!iSogeny(mDna, fDna)) {
                throw new RuntimeException("non-homologous dna");
            }

            INDArray nDna = Nd4j.create(mShape);
            NdIndexIterator it = new NdIndexIterator(mShape);
            while (it.hasNext()) {
                long[] next = it.next();

                // Take each gene from the male or the female parent with equal probability.
                double val;
                if (Math.random() > 0.5) {
                    val = fDna.getDouble(next);
                } else {
                    val = mDna.getDouble(next);
                }
                nDna.putScalar(next, val);
            }
            return nDna;
        }

        @Override
        public INDArray mutation(INDArray dna, double v, double r) {
            long[] shape = dna.shape();
            INDArray nDna = Nd4j.create(shape);
            NdIndexIterator it = new NdIndexIterator(shape);
            while (it.hasNext()) {
                long[] next = it.next();
                if (Math.random() < v) {
                    // Write the shifted value to the copy. The original listing wrote it
                    // back into dna, which left this position of nDna at zero.
                    nDna.putScalar(next, dna.getDouble(next) + ((Math.random() - 0.5) * r * 2));
                } else {
                    nDna.putScalar(next, dna.getDouble(next));
                }
            }
            return nDna;
        }

        @Override
        public INDArray substitution(INDArray dna, double v) {
            long[] shape = dna.shape();
            INDArray nDna = Nd4j.create(shape);
            NdIndexIterator it = new NdIndexIterator(shape);
            while (it.hasNext()) {
                long[] next = it.next();
                // As written, the test is "> v", so this branch runs with probability 1 - v.
                if (Math.random() > v) {
                    // Copy the gene from a random position of the same array.
                    long[] tag = new long[shape.length];
                    for (int i = 0; i < shape.length; i++) {
                        tag[i] = (long) (Math.random() * shape[i]);
                    }
                    nDna.putScalar(next, dna.getDouble(tag));
                } else {
                    nDna.putScalar(next, dna.getDouble(next));
                }
            }
            return nDna;
        }

        @Override
        public INDArray other(INDArray dna, double v) {
            long[] shape = dna.shape();
            INDArray nDna = Nd4j.create(shape);
            NdIndexIterator it = new NdIndexIterator(shape);
            while (it.hasNext()) {
                long[] next = it.next();
                // As written, the test is "> v", so this branch runs with probability 1 - v.
                if (Math.random() > v) {
                    nDna.putScalar(next, Math.random());
                } else {
                    nDna.putScalar(next, dna.getDouble(next));
                }
            }
            return nDna;
        }

        @Override
        public boolean iSogeny(INDArray mDna, INDArray fDna) {
            long[] mShape = mDna.shape();
            long[] fShape = fDna.shape();

            if (mShape.length == fShape.length) {
                for (int i = 0; i < mShape.length; i++) {
                    if (mShape[i] != fShape[i]) {
                        return false;
                    }
                }
                return true;
            }
            return false;
        }
    }

Agent configuration interface

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

    public interface AgentConfig {

        /** Number of inputs */
        int getInput();

        /** Number of outputs */
        int getOutput();

        /** Neural network configuration */
        MultiLayerConfiguration getMultiLayerConfiguration();
    }

Configuration for handwritten digit recognition

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.learning.config.Nesterovs;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class MnistConfig implements AgentConfig {

        @Override
        public int getInput() {
            return 28 * 28;
        }

        @Override
        public int getOutput() {
            return 10;
        }

        @Override
        public MultiLayerConfiguration getMultiLayerConfiguration() {
            return new NeuralNetConfiguration.Builder()
                    .seed((long) (Math.random() * Long.MAX_VALUE))
                    .updater(new Nesterovs(0.006, 0.9))
                    .l2(1e-4)
                    .list()
                    // hidden layer
                    .layer(0, new DenseLayer.Builder()
                            .nIn(getInput())
                            .nOut(1000)
                            .activation(Activation.RELU)
                            .weightInit(WeightInit.XAVIER)
                            .build())
                    // output layer
                    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                            .nIn(1000)
                            .nOut(getOutput())
                            .activation(Activation.SOFTMAX)
                            .weightInit(WeightInit.XAVIER)
                            .build())
                    .pretrain(false).backprop(true)
                    .build();
        }
    }

Agent base class

    import lombok.Getter;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    @Getter
    public class Agent {

        private final AgentConfig config;
        private final INDArray dna;
        private final MultiLayerNetwork multiLayerNetwork;

        /**
         * Initialize the parameters with the default method
         * @param config
         */
        public Agent(AgentConfig config) {
            this(config, null);
        }

        /**
         * @param config
         * @param dna network parameters to start from, or null to use the default initialization
         */
        public Agent(AgentConfig config, INDArray dna) {
            this.config = config;
            this.multiLayerNetwork = new MultiLayerNetwork(config.getMultiLayerConfiguration());
            if (dna == null) {
                multiLayerNetwork.init();
                this.dna = multiLayerNetwork.params();
            } else {
                multiLayerNetwork.init(dna, true);
                this.dna = dna;
            }
        }
    }

Handwritten digit agent class

    import java.util.concurrent.atomic.AtomicInteger;
    import lombok.Getter;
    import lombok.Setter;
    import org.nd4j.linalg.api.ndarray.INDArray;

    @Getter
    @Setter
    public class MnistAgent extends Agent {

        private static final AtomicInteger index = new AtomicInteger(0);

        private String name;

        /** Environment adaptation score */
        private double score;

        /** Validation score */
        private double validScore;

        public MnistAgent(AgentConfig config) {
            this(config, null);
        }

        public MnistAgent(AgentConfig config, INDArray dna) {
            super(config, dna);
            name = "agent-" + index.incrementAndGet();
        }

        public static MnistConfig mnistConfig = new MnistConfig();

        public static MnistAgent newInstance() {
            return new MnistAgent(mnistConfig);
        }

        public static MnistAgent create(INDArray dna) {
            return new MnistAgent(mnistConfig, dna);
        }
    }

Building the handwritten digit recognition environment

    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.Map;
    import java.util.TreeMap;
    import lombok.extern.slf4j.Slf4j;
    import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
    import org.deeplearning4j.eval.Evaluation;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;

    // BlockThreadPool, UI and Warehouse are helper classes from the project
    // repository; they are not listed in this article.
    @Slf4j
    public class MnistEnv {

        /** Environment data (training set) */
        private static final ThreadLocal<MnistDataSetIterator> tLocal = ThreadLocal.withInitial(() -> {
            try {
                return new MnistDataSetIterator(128, true, 0);
            } catch (IOException e) {
                throw new RuntimeException("failed to read mnist file");
            }
        });

        /** Validation data (test set) */
        private static final ThreadLocal<MnistDataSetIterator> testLocal = ThreadLocal.withInitial(() -> {
            try {
                return new MnistDataSetIterator(128, false, 0);
            } catch (IOException e) {
                throw new RuntimeException("failed to read mnist file");
            }
        });

        private static final MnistEvolution evolution = MnistEvolution.getInstance();

        /**
         * Capacity of the environment.
         *
         * Once the number of AIs exceeds this limit, they compete for a place.
         */
        private final int max;
        private Double maxScore, minScore;

        /**
         * Life in the environment.
         *
         * New and earlier generations are ranked together to select the
         * individuals best suited to the environment.
         */
        // Ordered by score; every access is synchronized on the map itself.
        private final TreeMap<Double, MnistAgent> lives = new TreeMap<>();

        /**
         * Initialize the environment.
         *
         * 1. Add the initial AIs to the environment.
         * 2. Test how well each AI is adapted to the environment and rank them.
         * @param max
         */
        public MnistEnv(int max) {
            this.max = max;
            for (int i = 0; i < max; i++) {
                MnistAgent agent = MnistAgent.newInstance();
                test(agent);
                synchronized (lives) {
                    lives.put(agent.getScore(), agent);
                }
                log.info("initialized agent name = {} , score = {}", i, agent.getScore());
            }
            synchronized (lives) {
                minScore = lives.firstKey();
                maxScore = lives.lastKey();
            }
        }

        /**
         * Environment adaptation test (accuracy on the training set)
         * @param ai
         */
        public void test(MnistAgent ai) {
            MultiLayerNetwork network = ai.getMultiLayerNetwork();
            MnistDataSetIterator dataIterator = tLocal.get();
            Evaluation eval = new Evaluation(ai.getConfig().getOutput());
            try {
                while (dataIterator.hasNext()) {
                    DataSet data = dataIterator.next();
                    INDArray output = network.output(data.getFeatures(), false);
                    eval.eval(data.getLabels(), output);
                }
            } finally {
                dataIterator.reset();
            }
            ai.setScore(eval.accuracy());
        }

        /**
         * Transfer test (accuracy on the test set)
         * @param ai
         */
        public void validation(MnistAgent ai) {
            MultiLayerNetwork network = ai.getMultiLayerNetwork();
            MnistDataSetIterator dataIterator = testLocal.get();
            Evaluation eval = new Evaluation(ai.getConfig().getOutput());
            try {
                while (dataIterator.hasNext()) {
                    DataSet data = dataIterator.next();
                    INDArray output = network.output(data.getFeatures(), false);
                    eval.eval(data.getLabels(), output);
                }
            } finally {
                dataIterator.reset();
            }
            ai.setValidScore(eval.accuracy());
        }

        /**
         * Evolution.
         *
         * Each round creates a new AI at random and puts it into the
         * environment, where the fittest survive.
         * @param n number of evolution rounds
         */
        public void evolution(int n) {
            BlockThreadPool blockThreadPool = new BlockThreadPool(2);
            for (int i = 0; i < n; i++) {
                blockThreadPool.execute(() ->

                        contend(newLive()));
            }
            // Single-threaded alternative:
            // for (int i = 0; i < n; i++) {
            //     contend(newLive());
            // }
        }

        /**
         * Competition
         * @param ai
         */
        public void contend(MnistAgent ai) {
            test(ai);
            quality(ai);
            double score = ai.getScore();

            // The published listing is truncated between "if (score <" and "> max)".
            // The lines marked (assumed) are reconstructed from the surrounding code
            // and may differ from the project source.
            if (score < minScore) {                        // (assumed) too weak to enter the ranking
                return;                                    // (assumed)
            }
            Map.Entry<Double, MnistAgent> lastEntry;       // (assumed)
            synchronized (lives) {                         // (assumed)
                lives.put(score, ai);                      // (assumed)
                if (lives.size() > max) {                  // (assumed left-hand side)
                    MnistAgent lastAI = lives.remove(lives.firstKey());
                    UI.put("Eliminated", String.format("name = %s, score = %s", lastAI.getName(), lastAI.getScore()));
                }
                lastEntry = lives.lastEntry();
                minScore = lives.firstKey();
            }

            Double lastScore = lastEntry.getKey();
            if (lastScore > maxScore) {
                maxScore = lastScore;
                MnistAgent agent = lastEntry.getValue();
                validation(agent);
                UI.put("Max validation", String.format("score = %s, validScore = %s", lastScore, agent.getValidScore()));

                try {
                    Warehouse.write(agent);
                } catch (IOException ex) {
                    log.error("failed to save object", ex);
                }
            }
        }

        ArrayList<Double> scoreList = new ArrayList<>(100);
        ArrayList<Integer> avgList = new ArrayList<>();

        /** Records each score and reports the running average once per 100 agents. */
        private void quality(MnistAgent ai) {
            synchronized (scoreList) {
                scoreList.add(ai.getScore());
                if (scoreList.size() >= 100) {
                    double avg = scoreList.stream().mapToDouble(e -> e).average().getAsDouble();
                    avgList.add((int) (avg * 1000));
                    StringBuffer buffer = new StringBuffer();
                    avgList.forEach(e -> buffer.append(e).append('\t'));
                    UI.put("Average score", String.format("aix100 avg = %s", buffer.toString()));
                    scoreList.clear();
                }
            }
        }

        /**
         * Randomly generate a new agent.
         *
         * The female parent is chosen completely at random.
         * The male parent is chosen at random from the agents that score the
         * same as or higher than the female parent.
         *
         * Genetic change is applied at a rate between 1% and 10%; the higher
         * the score, the more stable the genes.
         */
        public MnistAgent newLive() {
            double r = Math.random();
            // Gene mutation rate
            double v = r / 11 + 0.01;
            // Female parent
            MnistAgent mAgent = getMother(r);
            // Male parent
            MnistAgent fAgent = getFather(r);

            int i = (int) (Math.random() * 3);
            INDArray newDNA = evolution.inheritance(mAgent.getDna(), fAgent.getDna());
            switch (i) {
                case 0:
                    newDNA = evolution.other(newDNA, v);
                    break;
                case 1:
                    newDNA = evolution.mutation(newDNA, v, 0.1);
                    break;
                case 2:
                    newDNA = evolution.substitution(newDNA, v);
                    break;
            }
            return MnistAgent.create(newDNA);
        }

        /**
         * The male parent is chosen only from agents that score at least as
         * high as the female parent.
         * @param r
         * @return
         */
        private MnistAgent getFather(double r) {
            r += (Math.random() * (1 - r));
            return getMother(r);
        }

        private MnistAgent getMother(double r) {
            int index = (int) (r * max);
            return getMnistAgent(index);
        }

        /** Returns the agent at the given position of the ranking (0 = lowest score). */
        private MnistAgent getMnistAgent(int index) {
            synchronized (lives) {
                Iterator<Map.Entry<Double, MnistAgent>> it = lives.entrySet().iterator();
                for (int i = 0; i < index; i++) {
                    it.next();
                }
                return it.next().getValue();
            }
        }
    }

Main function

    import java.util.Date;
    import lombok.extern.slf4j.Slf4j;

    @Slf4j
    public class Program {

        public static void main(String[] args) {
            UI.put("Start time", new Date().toLocaleString());
            MnistEnv env = new MnistEnv(128);
            env.evolution(Integer.MAX_VALUE);
        }
    }
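The parent selection in `getMother` and `getFather` maps a random number onto a position in a score-ordered `TreeMap`. The sketch below isolates that idea in plain Java so it can be run on its own. The class `RankSelectionSketch` and its method names are hypothetical, the random draws are passed in as arguments to keep the example deterministic, and the clamping to `size - 1` is a defensive assumption that the article's code does not contain.

```java
import java.util.Map;
import java.util.TreeMap;

/** Rank-based parent selection over a ranking ordered by ascending score. */
final class RankSelectionSketch {

    /** Maps r in [0, 1) to a rank in [0, size - 1]; rank 0 is the lowest score. */
    static int motherRank(double r, int size) {
        return Math.min(size - 1, (int) (r * size));
    }

    /** Pushes r towards 1 with a second draw u in [0, 1), so the father is never ranked below the mother. */
    static int fatherRank(double r, double u, int size) {
        double shifted = r + u * (1 - r);
        return Math.min(size - 1, (int) (shifted * size));
    }

    /** Walks the sorted entries and returns the value at the given rank. */
    static <V> V byRank(TreeMap<Double, V> ranking, int rank) {
        int i = 0;
        for (Map.Entry<Double, V> entry : ranking.entrySet()) {
            if (i++ == rank) {
                return entry.getValue();
            }
        }
        throw new IndexOutOfBoundsException("rank " + rank + " of " + ranking.size());
    }

    public static void main(String[] args) {
        TreeMap<Double, String> ranking = new TreeMap<>();
        ranking.put(0.90, "agent-1");
        ranking.put(0.90, "agent-2"); // same key: replaces agent-1 instead of adding an entry
        ranking.put(0.50, "agent-3");

        int size = ranking.size();
        System.out.println("entries: " + size);
        System.out.println("mother: " + byRank(ranking, motherRank(0.1, size)));
        System.out.println("father: " + byRank(ranking, fatherRank(0.1, 0.9, size)));
    }
}
```

One design point worth noting: a `TreeMap` keyed by score keeps only one agent per distinct score, so the ranking can hold fewer entries than the configured capacity. For that reason the sketch derives ranks from the map's actual size.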
