RNN หรือ recurrent neural network คือโครงข่ายประสาทเทียมที่สร้างมาสำหรับข้อมูลแบบลำดับ เช่น ข้อความ เสียงพูด หรืออนุกรมเวลา ในแต่ละขั้น มันจะรวมอินพุตปัจจุบันเข้ากับ hidden state จากขั้นก่อนหน้า ทำให้เอาต์พุตขึ้นอยู่กับสิ่งที่เกิดขึ้นก่อนหน้านั้นได้

นี่คือแนวคิดสำคัญ: RNN มีความจำที่ไหลต่อเนื่องอยู่ตลอด ส่วน LSTM เป็น RNN แบบมีเกตที่จัดการความจำนี้อย่างระมัดระวังมากขึ้น เมื่อข้อมูลสำคัญต้องคงอยู่ผ่านหลายขั้น

RNN ทำอะไรในแต่ละช่วงเวลา

ที่ช่วงเวลา tt RNN แบบง่ายจะอัปเดต hidden state ด้วยกฎประมาณนี้

ht=tanh(Wxxt+Whht1+b).h_t = \tanh(W_x x_t + W_h h_{t-1} + b).

ตรงนี้ xtx_t คืออินพุตปัจจุบัน, ht1h_{t-1} คือ hidden state ก่อนหน้า และ hth_t คือ hidden state ใหม่ เมทริกซ์ WxW_x และ WhW_h รวมถึงไบแอส bb จะถูกเรียนรู้ระหว่างการฝึก

ถ้าโมเดลสร้างเอาต์พุตในทุกขั้นด้วย รูปแบบที่พบบ่อยคือ

yt=Wyht+c.y_t = W_y h_t + c.

กฎของเอาต์พุตที่แน่นอนขึ้นอยู่กับงาน บางปัญหาต้องการเอาต์พุตทุกขั้น ขณะที่บางปัญหาใช้เฉพาะ hidden state สุดท้าย

ทำไม hidden state จึงสำคัญ

โครงข่ายแบบ feedforward เห็นอินพุตหนึ่งครั้งแล้วก็จบ แต่ RNN นำส่วนหนึ่งของการคำนวณก่อนหน้ากลับมาใช้ซ้ำ การใช้ซ้ำนี้เองที่ทำให้มันมีประโยชน์กับข้อความ เสียง อนุกรมเวลา และข้อมูลที่มีลำดับอื่น ๆ

คุณอาจมองว่า hidden state คือโน้ตสั้น ๆ ที่โมเดลเขียนไว้ให้ตัวเองหลังจบแต่ละขั้น ขั้นถัดไปจะอ่านโน้ตนั้น อัปเดตมัน แล้วส่งเวอร์ชันใหม่ต่อไปข้างหน้า

ถ้าคุณเปลี่ยนลำดับของอินพุตชุดเดิม hidden state ก็มักจะเปลี่ยนตามไปด้วย นั่นแปลว่าลำดับมีความสำคัญ

ตัวอย่าง RNN แบบคำนวณจริง

RNN จริงมักใช้เวกเตอร์และฟังก์ชันกระตุ้นแบบไม่เชิงเส้น เพื่อให้การคำนวณอ่านง่าย เราจะใช้ตัวอย่างสถานะที่มีเพียงตัวเลขเดียว:

ht=0.5ht1+xt,h0=0.h_t = 0.5 h_{t-1} + x_t, \quad h_0 = 0.

ตอนนี้ประมวลผลลำดับ x1=2x_1 = 2, x2=1x_2 = -1, x3=3x_3 = 3

ขั้นแรก:

h1=0.5(0)+2=2.h_1 = 0.5(0) + 2 = 2.

ขั้นที่สอง:

h2=0.5(2)+(1)=0.h_2 = 0.5(2) + (-1) = 0.

ขั้นที่สาม:

h3=0.5(0)+3=3.h_3 = 0.5(0) + 3 = 3.

สิ่งสำคัญตรงนี้ไม่ใช่สูตรที่แน่นอน แต่คือการพึ่งพาสถานะก่อนหน้า ในขั้นที่ 2 การอัปเดตไม่ได้ใช้แค่ x2x_2 เท่านั้น แต่ยังใช้สิ่งที่ถูกส่งต่อมาจากขั้นที่ 1 ด้วย นี่คือแก่นของแนวคิด RNN

ถ้าคุณสลับลำดับแล้วใช้ x1=1x_1 = -1, x2=2x_2 = 2, x3=3x_3 = 3 จะได้ว่า

h1=1,h2=0.5(1)+2=1.5,h3=0.5(1.5)+3=3.75.h_1 = -1, \quad h_2 = 0.5(-1) + 2 = 1.5, \quad h_3 = 0.5(1.5) + 3 = 3.75.

สถานะสุดท้ายต่างออกไป แม้ว่าจะใช้ตัวเลขชุดเดิมทั้งหมด นี่จึงเป็นเหตุผลที่ RNN เป็นโมเดลสำหรับข้อมูลแบบลำดับ ไม่ใช่โมเดลที่มองอินพุตเป็นเพียงกองข้อมูลที่ไม่สนลำดับ

ทำไม RNN พื้นฐานจึงลำบากกับลำดับยาว

ใน RNN พื้นฐาน ข้อมูลเก่าต้องคงอยู่ผ่านการอัปเดตซ้ำ ๆ หลายครั้ง ถ้าลำดับยาวมาก เรื่องนี้จะทำได้ยาก สัญญาณที่มีประโยชน์อาจค่อย ๆ จางหาย และระหว่างการฝึก gradient ก็อาจเล็กลงมากหรือใหญ่เกินไปเมื่อย้อนผ่านหลายขั้น

นั่นจึงเป็นเหตุผลที่ RNN แบบธรรมดามักมีปัญหาเมื่อโจทย์ต้องพึ่งข้อมูลจากตำแหน่งที่อยู่ไกลมากในลำดับ ปัญหาไม่ได้อยู่ที่แนวคิดการวนซ้ำผิด แต่เป็นเพราะการรักษาความจำระยะไกลด้วยการอัปเดต hidden state แบบง่ายนั้นทำได้ยาก

LSTM ช่วยให้ RNN จำได้ดีขึ้นอย่างไร

LSTM ย่อมาจาก long short-term memory และเป็น RNN แบบมีเกต มันเพิ่มเส้นทางความจำที่มีโครงสร้างมากขึ้น ซึ่งมักเรียกว่า cell state พร้อมกับเกตที่ควบคุมว่าข้อมูลใดควรถูกลืม ข้อมูลใหม่ใดควรถูกเขียนเข้าไป และส่วนใดควรถูกเปิดออกมาเป็นเอาต์พุต

คุณไม่จำเป็นต้องรู้สมการของเกตทั้งหมดเพื่อเข้าใจประเด็นสำคัญ โครงสร้างนี้ทำให้โมเดลควบคุมความจำได้มากขึ้น ถ้ารายละเอียดบางอย่างต้องอยู่รอดผ่านหลายขั้น LSTM ก็มีความพร้อมในการเก็บมันไว้มากกว่า RNN แบบธรรมดา

แต่นั่นไม่ได้แปลว่า LSTM จะจำทุกอย่างได้ตลอดไป ความหมายคือสถาปัตยกรรมนี้เรียนรู้ได้ดีกว่าว่าควรเก็บข้อมูลเมื่อไร และควรทิ้งข้อมูลเมื่อไร

RNN กับ LSTM แบบภาษาง่าย ๆ

RNN พื้นฐานมีสถานะต่อเนื่องเพียงหนึ่งชุดและอัปเดตมันซ้ำไปเรื่อย ๆ ส่วน LSTM เพิ่มกลไกความจำที่แข็งแรงกว่ารอบแนวคิดเดิมนี้

ถ้าลำดับสั้นและความสัมพันธ์อยู่ใกล้กัน RNN แบบธรรมดาอาจเพียงพอ แต่ถ้างานต้องพึ่งข้อมูลจากช่วงที่อยู่ก่อนหน้ามากในลำดับ LSTM มักเป็นตัวเลือกที่ปลอดภัยกว่า

ความเข้าใจผิดที่พบบ่อยเกี่ยวกับ RNN และ LSTM

คิดว่า RNN เห็นทั้งลำดับพร้อมกัน

โดยทั่วไปไม่ใช่ ภาพมาตรฐานคือการประมวลผลทีละขั้น โดยมีสถานะถูกส่งต่อไปข้างหน้า

คิดว่า LSTM แก้ปัญหาความจำได้สมบูรณ์แบบ

มันช่วยเรื่องความสัมพันธ์ระยะไกลได้ แต่ก็ยังเป็นโมเดลที่ต้องฝึก มีความจุจำกัด และมีข้อจำกัดในการใช้งานจริง

มองข้ามลำดับของข้อมูล

RNN ถูกสร้างมาสำหรับข้อมูลที่มีลำดับ การสลับตำแหน่งขององค์ประกอบในลำดับจะเปลี่ยนการคำนวณ

มองว่า hidden state เป็นความจำที่มนุษย์อ่านออกตรง ๆ

hidden state คือการแทนค่าตัวเลขที่โมเดลเรียนรู้ขึ้นมา ไม่ใช่สรุปแบบประโยคที่อ่านเข้าใจได้ชัดเจน

RNN และ LSTM ใช้เมื่อไร

ทั้งสองถูกใช้กับปัญหาแบบลำดับ เช่น language modeling เสียงพูด ลายมือ สตรีมข้อมูลจากเซนเซอร์ และการพยากรณ์อนุกรมเวลา ปัจจุบันงานด้านภาษาหลายอย่างหันไปใช้ transformer มากกว่า แต่ RNN และ LSTM ก็ยังสำคัญ เพราะช่วยอธิบายเรื่องความจำของลำดับได้ชัดเจน และยังมีประโยชน์ในงานขนาดเล็กหรืองานเฉพาะทางบางประเภท

ลองทำเวอร์ชันของคุณเอง

เขียนลำดับ 4 ขั้นของคุณเอง แล้วใช้กฎตัวอย่าง ht=0.5ht1+xth_t = 0.5 h_{t-1} + x_t จากนั้นสลับตำแหน่งอินพุตสองตัวแล้วเปรียบเทียบสถานะสุดท้าย การทดลองเล็ก ๆ นี้จะทำให้เห็นบทบาทของการวนซ้ำชัดกว่าการจำตัวย่อเพียงอย่างเดียว

ถ้าคุณอยากลองอีกกรณีหนึ่ง ให้เปรียบเทียบหน้านี้กับคำอธิบายเรื่อง transformer หรือ Markov chain แล้วสังเกตว่าแต่ละโมเดลจัดการกับข้อมูลในอดีตอย่างไร

ต้องการความช่วยเหลือในการแก้โจทย์?

อัปโหลดคำถามของคุณแล้วรับคำตอบแบบทีละขั้นตอนที่ผ่านการตรวจสอบในไม่กี่วินาที

เปิด GPAI Solver →